Skip to content

Commit aa2400c

Browse files
JasonKChow authored and
facebook-github-bot committed
Add flag to db to dataframe method to capture extra data (facebookresearch#720)
Summary: Pull Request resolved: facebookresearch#720. The experiment-to-dataframe utility ignores extra info by default. Add an extra argument to capture that data in the dataframe. Reviewed By: adellari. Differential Revision: D72405413. fbshipit-source-id: 71f755c11694afa2922e2b005ff85fd9c1692c32
1 parent 21793a9 commit aa2400c

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

aepsych/database/db.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -564,9 +564,11 @@ def get_stimuli_per_trial(master_id):
564564

565565
return pd.DataFrame(exp_dict)
566566

567-
def get_data_frame(self) -> pd.DataFrame:
567+
def get_data_frame(self, include_extra_data: bool = False) -> pd.DataFrame:
568568
"""Converts parameter and outcome data in the database into a pandas dataframe.
569-
569+
Args:
570+
include_extra_data (bool): Whether to include columns for extra data from
571+
the raw table. Defaults to False.
570572
Returns:
571573
pandas.Dataframe: The dataframe containing the parameter and outcome data.
572574
"""
@@ -587,6 +589,15 @@ def get_data_frame(self) -> pd.DataFrame:
587589
row.update({par.param_name: par.param_value for par in pars})
588590
row.update({out.outcome_name: out.outcome_value for out in outs})
589591

592+
extra_data = pars[0].parent.extra_data
593+
if include_extra_data and extra_data is not None:
594+
if isinstance(extra_data, str):
595+
extra_data_dict = json.loads(extra_data)
596+
elif isinstance(extra_data, dict):
597+
extra_data_dict = extra_data
598+
599+
row.update(extra_data_dict)
600+
590601
rows.append(row)
591602

592603
df = pd.DataFrame(rows)

tests/test_db.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import aepsych.config as configuration
1818
import aepsych.database.db as db
1919
import aepsych.database.tables as tables
20+
import numpy as np
2021
import pandas as pd
2122
import sqlalchemy
2223

@@ -631,6 +632,31 @@ def test_get_dataframe(self):
631632
n += len(self.data.get_raw_for(rec.unique_id))
632633
self.assertEqual(n, len(df))
633634

635+
def test_get_dataframe_extra_data(self):
636+
current_path = Path(os.path.abspath(__file__)).parent
637+
db_path = current_path
638+
db_path = db_path.joinpath("test_databases/extra_info.db")
639+
640+
# Make a copy of the database
641+
dst_db_path = Path("./{}.db".format(str(uuid.uuid4().hex)))
642+
shutil.copy(db_path, dst_db_path)
643+
644+
time.sleep(0.1)
645+
self.assertTrue(dst_db_path.is_file())
646+
647+
self.data = db.Database(dst_db_path)
648+
649+
df = self.data.get_data_frame(include_extra_data=True)
650+
651+
# extra_info db is ragged, so check that we have the right columns
652+
self.assertIn("extra", df.columns)
653+
self.assertIn("additional", df.columns)
654+
self.assertIn("trial_number", df.columns)
655+
656+
# First row is missing the extra data, so check that it is missing
657+
np.testing.assert_equal(df.iloc[0]["trial_number"], np.nan)
658+
self.assertEqual(df.iloc[1]["trial_number"], 2)
659+
634660

635661
if __name__ == "__main__":
636662
unittest.main()

0 commit comments

Comments
 (0)