Skip to content

Commit b1509cf

Browse files
committed
refactor: generate curriculum plot from df_session
1 parent fe719be commit b1509cf

File tree

3 files changed

+28
-76
lines changed

3 files changed

+28
-76
lines changed

code/Home.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -348,37 +348,12 @@ def init(if_load_bpod_data_override=None, if_load_docDB_override=None, if_load_s
348348
st.session_state[f'df_selected_from_{source}'] = pd.DataFrame(columns=['subject_id', 'session'])
349349

350350
# Load autotrain
351-
auto_train_manager, curriculum_manager = load_auto_train()
352-
st.session_state.auto_train_manager = auto_train_manager
351+
_, curriculum_manager = load_auto_train()
353352
st.session_state.curriculum_manager = curriculum_manager
354353

355354
# Some ad-hoc modifications on df_sessions
356355
_df = st.session_state.df['sessions_main'].copy()
357356

358-
# -- overwrite the `if_stage_overriden_by_trainer`
359-
# Previously it was set to True if the trainer changes stage during a session.
360-
# But it is more informative to define it as whether the trainer has overridden the curriculum.
361-
# In other words, it is set to True only when stage_suggested ~= stage_actual, as defined in the autotrain curriculum.
362-
_df.drop(columns=['if_overriden_by_trainer'], inplace=True)
363-
tmp_auto_train = (
364-
auto_train_manager.df_manager.query("if_closed_loop == True")[
365-
[
366-
"subject_id",
367-
"session_date",
368-
"current_stage_suggested",
369-
"if_stage_overriden_by_trainer",
370-
]
371-
]
372-
.copy()
373-
.drop_duplicates(subset=["subject_id", "session_date"], keep="first")
374-
)
375-
tmp_auto_train["session_date"] = pd.to_datetime(tmp_auto_train["session_date"])
376-
_df = _df.merge(
377-
tmp_auto_train,
378-
on=["subject_id", "session_date"],
379-
how='left',
380-
)
381-
382357
# --- Load data from docDB ---
383358
if_load_docDb = if_load_docDB_override if if_load_docDB_override is not None else (
384359
st.query_params['if_load_docDB'].lower() == 'true'

code/util/plot_autotrain_manager.py

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,42 +32,17 @@ def plot_manager_all_progress_bokeh_source(
3232
filtered_session_ids=None,
3333
):
3434
# --- Prepare data ---
35-
manager = st.session_state.auto_train_manager
36-
df_manager = manager.df_manager.sort_values(
35+
# The full curriculum info is now already merged into the master df_session by the result access API
36+
# manager = st.session_state.auto_train_manager
37+
38+
df_to_draw = st.session_state.df['sessions_main'].sort_values(
3739
by=["subject_id", "session"],
3840
ascending=[sort_order == "ascending", False]
39-
)
41+
).copy()
4042

41-
if not len(df_manager):
43+
if not len(df_to_draw):
4244
return None
4345

44-
# Metadata merge from df_master
45-
df_tmp_rig_trainer = st.session_state.df["sessions_main"][
46-
["subject_id", "session_date", "session", "rig", "trainer", "PI", "nwb_suffix",
47-
"foraging_eff_random_seed", "finished_trials", "finished_rate",
48-
"task", "curriculum_name", "curriculum_version", "current_stage_actual"]
49-
]
50-
df_tmp_rig_trainer["session_date"] = df_tmp_rig_trainer["session_date"].astype(str)
51-
52-
df_to_draw = (
53-
df_manager.drop_duplicates(
54-
subset=["subject_id", "session_date"], keep="last"
55-
) # Duplicate sessions in the autotrain due to pipeline issues
56-
.drop(
57-
columns=[
58-
"session",
59-
"task",
60-
"foraging_efficiency",
61-
"finished_trials",
62-
]
63-
) # df_master has higher priority in session numbers
64-
.merge(
65-
df_tmp_rig_trainer.query(f"current_stage_actual != 'None'"),
66-
on=["subject_id", "session_date"],
67-
how="right",
68-
)
69-
)
70-
7146
# If use_filtered_data, filter the data
7247
if if_use_filtered_data:
7348
df_to_draw = df_to_draw.merge(
@@ -76,11 +51,6 @@ def plot_manager_all_progress_bokeh_source(
7651
how="inner",
7752
)
7853

79-
# Correct df_manager missing sessions (df_manager has higher priority in curriculum-related fields)
80-
df_to_draw["curriculum_name"] = df_to_draw["curriculum_name_x"].fillna(df_to_draw["curriculum_name_y"])
81-
df_to_draw["curriculum_version"] = df_to_draw["curriculum_version_x"].fillna(df_to_draw["curriculum_version_y"])
82-
df_to_draw["current_stage_actual"] = df_to_draw["current_stage_actual_x"].fillna(df_to_draw["current_stage_actual_y"])
83-
8454
df_to_draw["color"] = df_to_draw["current_stage_actual"].map(stage_color_mapper)
8555
df_to_draw["edge_color"] = ( # Use grey edge to indicate stage without suggestion
8656
df_to_draw["current_stage_suggested"].map(stage_color_mapper).fillna("#d3d3d3")

code/util/streamlit.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -774,8 +774,6 @@ def _add_download_filtered_session():
774774

775775
def add_auto_train_manager():
776776

777-
df_training_manager = st.session_state.auto_train_manager.df_manager
778-
779777
# -- Show plotly chart --
780778
cols = st.columns([1, 1, 1, 0.7, 0.7, 1.5, 1.5, 2])
781779
options = ["date", "session", "relative_date"]
@@ -856,7 +854,7 @@ def add_auto_train_manager():
856854
else None
857855
),
858856
)
859-
857+
860858
if fig_auto_train is None:
861859
st.markdown("### In the filtered sessions, no AutoTrain history to show!")
862860
return
@@ -878,20 +876,29 @@ def add_auto_train_manager():
878876
st.write(data_df.iloc[indices])
879877

880878
# -- Show dataframe --
879+
df_training_manager = st.session_state.df_session_filtered[
880+
[
881+
"subject_id",
882+
"session_date",
883+
"session",
884+
"curriculum_name",
885+
"curriculum_version",
886+
"curriculum_schema_version",
887+
"current_stage_suggested",
888+
"current_stage_actual",
889+
"session_at_current_stage",
890+
"if_overriden_by_trainer",
891+
"foraging_efficiency",
892+
"finished_trials",
893+
"decision",
894+
"next_stage_suggested",
895+
]
896+
]
897+
881898
# Only show the filtered subjects
882899
df_training_manager = df_training_manager[df_training_manager['subject_id'].isin(
883900
st.session_state.df_session_filtered['subject_id'].unique().astype(str))]
884901

885-
# reorder columns
886-
df_training_manager = df_training_manager[['subject_id', 'session_date', 'session',
887-
'curriculum_name', 'curriculum_version', 'curriculum_schema_version',
888-
'current_stage_suggested', 'current_stage_actual',
889-
'session_at_current_stage',
890-
'if_closed_loop', 'if_stage_overriden_by_trainer',
891-
'foraging_efficiency', 'finished_trials',
892-
'decision', 'next_stage_suggested'
893-
]]
894-
895902
with st.expander('Automatic training manager', expanded=False):
896903
st.dataframe(df_training_manager, height=3000)
897904

@@ -1195,4 +1202,4 @@ def add_download_plotly_as_svg(fig, file_name="plot.svg"):
11951202
data=svg_file,
11961203
file_name=file_name.replace(".svg", "") + ".svg",
11971204
mime="image/svg+xml"
1198-
)
1205+
)

0 commit comments

Comments
 (0)