Skip to content

Commit 0631d52

Browse files
authored
Merge pull request #117 from AllenNeuralDynamics/han_optimize_autotrain
Multiple Improvements on autotrain
2 parents 3b16bad + 3522b9b commit 0631d52

File tree

7 files changed

+128
-124
lines changed

7 files changed

+128
-124
lines changed

code/Home.py

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def show_curriculums():
302302

303303

304304
# ------- Layout starts here -------- #
305-
def init(if_load_bpod_data_override=None, if_load_docDB_override=None):
305+
def init(if_load_bpod_data_override=None, if_load_docDB_override=None, if_load_sessions_older_than_6_month_override=None):
306306

307307
# Clear specific session state and all filters
308308
for key in st.session_state:
@@ -319,9 +319,25 @@ def init(if_load_bpod_data_override=None, if_load_docDB_override=None):
319319
if 'if_load_bpod_sessions' in st.session_state
320320
else False)
321321
st.session_state.bpod_loaded = _if_load_bpod
322-
322+
323+
_if_load_sessions_older_than_6_month = (
324+
if_load_sessions_older_than_6_month_override
325+
if if_load_sessions_older_than_6_month_override is not None
326+
else (
327+
st.query_params["if_load_sessions_older_than_6_month"].lower() == "true"
328+
if "if_load_sessions_older_than_6_month" in st.query_params
329+
else (
330+
st.session_state.if_load_sessions_older_than_6_month
331+
if "if_load_sessions_older_than_6_month" in st.session_state
332+
else False
333+
)
334+
)
335+
)
336+
323337
# --- Load data using aind-analysis-arch-result-access ---
324-
df_han = get_session_table(if_load_bpod=_if_load_bpod)
338+
# Convert boolean to months: if True, load all sessions (None), if False, load only recent 6 months
339+
only_recent_n_month = None if _if_load_sessions_older_than_6_month else 6
340+
df_han = get_session_table(if_load_bpod=_if_load_bpod, only_recent_n_month=only_recent_n_month)
325341
df = {'sessions_main': df_han} # put it in df['session_main'] for backward compatibility
326342

327343
if not len(df):
@@ -332,38 +348,12 @@ def init(if_load_bpod_data_override=None, if_load_docDB_override=None):
332348
st.session_state[f'df_selected_from_{source}'] = pd.DataFrame(columns=['subject_id', 'session'])
333349

334350
# Load autotrain
335-
auto_train_manager, curriculum_manager = load_auto_train()
336-
st.session_state.auto_train_manager = auto_train_manager
351+
_, curriculum_manager = load_auto_train()
337352
st.session_state.curriculum_manager = curriculum_manager
338353

339354
# Some ad-hoc modifications on df_sessions
340355
_df = st.session_state.df['sessions_main'].copy()
341356

342-
343-
# -- overwrite the `if_stage_overriden_by_trainer`
344-
# Previously it was set to True if the trainer changes stage during a session.
345-
# But it is more informative to define it as whether the trainer has overridden the curriculum.
346-
# In other words, it is set to True only when stage_suggested ~= stage_actual, as defined in the autotrain curriculum.
347-
_df.drop(columns=['if_overriden_by_trainer'], inplace=True)
348-
tmp_auto_train = (
349-
auto_train_manager.df_manager.query("if_closed_loop == True")[
350-
[
351-
"subject_id",
352-
"session_date",
353-
"current_stage_suggested",
354-
"if_stage_overriden_by_trainer",
355-
]
356-
]
357-
.copy()
358-
.drop_duplicates(subset=["subject_id", "session_date"], keep="first")
359-
)
360-
tmp_auto_train["session_date"] = pd.to_datetime(tmp_auto_train["session_date"])
361-
_df = _df.merge(
362-
tmp_auto_train,
363-
on=["subject_id", "session_date"],
364-
how='left',
365-
)
366-
367357
# --- Load data from docDB ---
368358
if_load_docDb = if_load_docDB_override if if_load_docDB_override is not None else (
369359
st.query_params['if_load_docDB'].lower() == 'true'
@@ -439,9 +429,6 @@ def app():
439429
# -- 1. unit dataframe --
440430

441431
cols = st.columns([4, 4, 4, 1])
442-
cols[0].markdown(f'### Filter the sessions on the sidebar\n'
443-
f'##### {len(st.session_state.df_session_filtered)} sessions, '
444-
f'{len(st.session_state.df_session_filtered.subject_id.unique())} mice filtered')
445432

446433
with cols[0].expander(':bulb: Get the master session table by code', expanded=False):
447434
st.code(f'''
@@ -453,6 +440,12 @@ def app():
453440

454441
with cols[1]:
455442
with st.form(key='load_settings', clear_on_submit=False):
443+
if_load_sessions_older_than_6_month = checkbox_wrapper_for_url_query(
444+
st_prefix=st,
445+
label='Include sessions older than 6 months (reload after change)',
446+
key='if_load_sessions_older_than_6_month',
447+
default=False,
448+
)
456449
if_load_bpod_sessions = checkbox_wrapper_for_url_query(
457450
st_prefix=st,
458451
label='Include old Bpod sessions (reload after change)',
@@ -472,6 +465,12 @@ def app():
472465
sync_session_state_to_URL()
473466
init()
474467
st.rerun() # Reload the page to apply the changes
468+
469+
cols[0].markdown(f'### Filter the sessions on the sidebar\n' +
470+
f'##### {len(st.session_state.df_session_filtered)} sessions, ' +
471+
f'{len(st.session_state.df_session_filtered.subject_id.unique())} mice filtered' +
472+
(f' (recent 6 months only)' if not st.session_state.if_load_sessions_older_than_6_month else '')
473+
)
475474

476475
table_height = slider_wrapper_for_url_query(st_prefix=cols[-1],
477476
label='Table height',
@@ -696,4 +695,4 @@ def add_main_tabs():
696695
st.markdown('#### 1. Reload the page')
697696
st.markdown('#### 2. Click this original URL https://foraging-behavior-browser.allenneuraldynamics-test.org/')
698697
st.markdown('#### 3. Report your bug here: https://github.com/AllenNeuralDynamics/foraging-behavior-browser/issues (paste your URL and screenshoots)')
699-
raise e
698+
raise e

code/pages/0_Data inventory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,6 @@ def add_venn_diagrms(df_merged):
559559
# Share the same master df as the Home page
560560
if "df" not in st.session_state or "sessions_main" not in st.session_state.df.keys() or not st.session_state.bpod_loaded:
561561
st.spinner("Loading data from Han temp pipeline...")
562-
init(if_load_docDB_override=False, if_load_bpod_data_override=True)
562+
init(if_load_docDB_override=False, if_load_bpod_data_override=True, if_load_sessions_older_than_6_month_override=True)
563563

564564
app()

code/util/aws_s3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def load_auto_train():
6262
curriculum_manager = CurriculumManager(
6363
saved_curriculums_on_s3=dict(
6464
bucket='aind-behavior-data',
65-
root='foraging_auto_training/saved_curriculums/'
65+
root='foraging_nwb_bonsai_processed/foraging_auto_training/saved_curriculums/'
6666
),
6767
saved_curriculums_local=os.path.expanduser('~/curriculum_manager/'),
6868
)
@@ -72,7 +72,7 @@ def load_auto_train():
7272
root='foraging_nwb_bonsai_processed/',
7373
file_name='df_sessions.pkl'),
7474
df_manager_root_on_s3=dict(bucket='aind-behavior-data',
75-
root='foraging_auto_training/')
75+
root='foraging_nwb_bonsai_processed/foraging_auto_training/')
7676
)
7777

7878
_df = auto_train_manager.df_manager.copy().rename(

code/util/plot_autotrain_manager.py

Lines changed: 15 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -32,42 +32,17 @@ def plot_manager_all_progress_bokeh_source(
3232
filtered_session_ids=None,
3333
):
3434
# --- Prepare data ---
35-
manager = st.session_state.auto_train_manager
36-
df_manager = manager.df_manager.sort_values(
35+
# Now we already merged full curriculum info in the master df_session by the result access api
36+
# manager = st.session_state.auto_train_manager
37+
38+
df_to_draw = st.session_state.df['sessions_main'].sort_values(
3739
by=["subject_id", "session"],
3840
ascending=[sort_order == "ascending", False]
39-
)
41+
).copy()
4042

41-
if not len(df_manager):
43+
if not len(df_to_draw):
4244
return None
4345

44-
# Metadata merge from df_master
45-
df_tmp_rig_trainer = st.session_state.df["sessions_main"][
46-
["subject_id", "session_date", "session", "rig", "trainer", "PI", "nwb_suffix",
47-
"foraging_eff_random_seed", "finished_trials", "finished_rate",
48-
"task", "curriculum_name", "curriculum_version", "current_stage_actual"]
49-
]
50-
df_tmp_rig_trainer["session_date"] = df_tmp_rig_trainer["session_date"].astype(str)
51-
52-
df_to_draw = (
53-
df_manager.drop_duplicates(
54-
subset=["subject_id", "session_date"], keep="last"
55-
) # Duplicte sessions in the autotrain due to pipeline issues
56-
.drop(
57-
columns=[
58-
"session",
59-
"task",
60-
"foraging_efficiency",
61-
"finished_trials",
62-
]
63-
) # df_master has higher priority in session numbers
64-
.merge(
65-
df_tmp_rig_trainer.query(f"current_stage_actual != 'None'"),
66-
on=["subject_id", "session_date"],
67-
how="right",
68-
)
69-
)
70-
7146
# If use_filtered_data, filter the data
7247
if if_use_filtered_data:
7348
df_to_draw = df_to_draw.merge(
@@ -76,19 +51,17 @@ def plot_manager_all_progress_bokeh_source(
7651
how="inner",
7752
)
7853

79-
# Correct df_manager missing sessions (df_manager has higher priority in curriculum-related fields)
80-
df_to_draw["curriculum_name"] = df_to_draw["curriculum_name_x"].fillna(df_to_draw["curriculum_name_y"])
81-
df_to_draw["curriculum_version"] = df_to_draw["curriculum_version_x"].fillna(df_to_draw["curriculum_version_y"])
82-
df_to_draw["current_stage_actual"] = df_to_draw["current_stage_actual_x"].fillna(df_to_draw["current_stage_actual_y"])
83-
8454
df_to_draw["color"] = df_to_draw["current_stage_actual"].map(stage_color_mapper)
8555
df_to_draw["edge_color"] = ( # Use grey edge to indicate stage without suggestion
8656
df_to_draw["current_stage_suggested"].map(stage_color_mapper).fillna("#d3d3d3")
8757
)
58+
59+
# Convert session_date to string for URL generation
60+
df_to_draw["session_date_str"] = df_to_draw["session_date"].astype(str)
8861
df_to_draw["imgs_1"] = df_to_draw.apply(
8962
lambda x: get_s3_public_url(
9063
subject_id=x["subject_id"],
91-
session_date=x["session_date"],
64+
session_date=x["session_date_str"],
9265
nwb_suffix=x["nwb_suffix"],
9366
figure_suffix="choice_history.png",
9467
),
@@ -97,13 +70,17 @@ def plot_manager_all_progress_bokeh_source(
9770
df_to_draw["imgs_2"] = df_to_draw.apply(
9871
lambda x: get_s3_public_url(
9972
subject_id=x["subject_id"],
100-
session_date=x["session_date"],
73+
session_date=x["session_date_str"],
10174
nwb_suffix=x["nwb_suffix"],
10275
figure_suffix="logistic_regression_Su2022.png",
10376
),
10477
axis=1,
10578
)
10679
df_to_draw.round(3)
80+
81+
# --- Remove rows with NaN in color or edge_color ---
82+
# to fix a bug where non-normalized stages appears in the autotrain table
83+
df_to_draw = df_to_draw.dropna(subset=["color", "edge_color"])
10784

10885
# --- Filter recent days ---
10986
df_to_draw['session_date'] = pd.to_datetime(df_to_draw['session_date'])

code/util/streamlit.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -774,8 +774,6 @@ def _add_download_filtered_session():
774774

775775
def add_auto_train_manager():
776776

777-
df_training_manager = st.session_state.auto_train_manager.df_manager
778-
779777
# -- Show plotly chart --
780778
cols = st.columns([1, 1, 1, 0.7, 0.7, 1.5, 1.5, 2])
781779
options = ["date", "session", "relative_date"]
@@ -818,7 +816,11 @@ def add_auto_train_manager():
818816
),
819817
default=True,
820818
)
821-
only_filtered_effective = only_filtered and len(st.session_state.df_session_filtered) < len(st.session_state.df["sessions_main"])
819+
only_filtered_effective = (
820+
only_filtered
821+
and len(st.session_state.df_session_filtered) < len(st.session_state.df["sessions_main"])
822+
and len(st.session_state.df_session_filtered) < 500 # Only if filtered down to less than 500 sessions
823+
)
822824

823825
recent_months = slider_wrapper_for_url_query(cols[6],
824826
label="only recent months",
@@ -856,7 +858,7 @@ def add_auto_train_manager():
856858
else None
857859
),
858860
)
859-
861+
860862
if fig_auto_train is None:
861863
st.markdown("### In the filtered sessions, no AutoTrain history to show!")
862864
return
@@ -878,20 +880,29 @@ def add_auto_train_manager():
878880
st.write(data_df.iloc[indices])
879881

880882
# -- Show dataframe --
883+
df_training_manager = st.session_state.df_session_filtered[
884+
[
885+
"subject_id",
886+
"session_date",
887+
"session",
888+
"curriculum_name",
889+
"curriculum_version",
890+
"curriculum_schema_version",
891+
"current_stage_suggested",
892+
"current_stage_actual",
893+
"session_at_current_stage",
894+
"if_overriden_by_trainer",
895+
"foraging_eff",
896+
"finished_trials",
897+
"decision",
898+
"next_stage_suggested",
899+
]
900+
]
901+
881902
# only show filtered subject
882903
df_training_manager = df_training_manager[df_training_manager['subject_id'].isin(
883904
st.session_state.df_session_filtered['subject_id'].unique().astype(str))]
884905

885-
# reorder columns
886-
df_training_manager = df_training_manager[['subject_id', 'session_date', 'session',
887-
'curriculum_name', 'curriculum_version', 'curriculum_schema_version',
888-
'current_stage_suggested', 'current_stage_actual',
889-
'session_at_current_stage',
890-
'if_closed_loop', 'if_stage_overriden_by_trainer',
891-
'foraging_efficiency', 'finished_trials',
892-
'decision', 'next_stage_suggested'
893-
]]
894-
895906
with st.expander('Automatic training manager', expanded=False):
896907
st.dataframe(df_training_manager, height=3000)
897908

@@ -1195,4 +1206,4 @@ def add_download_plotly_as_svg(fig, file_name="plot.svg"):
11951206
data=svg_file,
11961207
file_name=file_name.replace(".svg", "") + ".svg",
11971208
mime="image/svg+xml"
1198-
)
1209+
)

environment/Dockerfile

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,14 @@
11
# hash:sha256:51bda5f40316acb89ad85a82e996448f5a31d6f40b5b443e817e9b346eee2f67
22
ARG REGISTRY_HOST
3-
FROM $REGISTRY_HOST/codeocean/jupyterlab:3.6.1-miniconda4.12.0-python3.9-ubuntu20.04
3+
FROM $REGISTRY_HOST/codeocean/mambaforge3:23.1.0-4-python3.10.12-ubuntu22.04
44

55
ARG DEBIAN_FRONTEND=noninteractive
66

77
ARG GIT_ASKPASS
88
ARG GIT_ACCESS_TOKEN
99
COPY git-askpass /
1010

11-
RUN pip install -r https://raw.githubusercontent.com/AllenNeuralDynamics/foraging-behavior-browser/main/requirements.txt
12-
13-
ADD "https://github.com/coder/code-server/releases/download/v4.21.1/code-server-4.21.1-linux-amd64.tar.gz" /.code-server/code-server.tar.gz
14-
15-
RUN cd /.code-server \
16-
&& tar -xvf code-server.tar.gz \
17-
&& rm code-server.tar.gz \
18-
&& ln -s /.code-server/code-server-4.21.1-linux-amd64/bin/code-server /usr/bin/code-server
19-
11+
RUN pip install -r https://raw.githubusercontent.com/AllenNeuralDynamics/foraging-behavior-browser/main/requirements.txt --no-cache-dir
2012

2113
COPY postInstall /
2214
RUN /postInstall

0 commit comments

Comments
 (0)