Skip to content

Commit 926ebae

Browse files
authored
Merge pull request #113 from AllenNeuralDynamics/han_add_new_ephysromm_and_others
Add new ephysromm and other bug fixes
2 parents c4192fc + 2dfefff commit 926ebae

File tree

7 files changed

+112
-82
lines changed

7 files changed

+112
-82
lines changed

code/Home.py

Lines changed: 61 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,10 @@ def draw_session_plots(df_to_draw_session):
110110
except:
111111
date_str = key["session_date"].split("T")[0]
112112

113-
st.markdown(f'''<h5 style='text-align: center; color: orange;'>{key["subject_id"]} ({key["PI"]}), Session {int(key["session"])}, {date_str} '''
114-
f'''({key["trainer"]}@{key["data_source"]})''',
113+
st.markdown(f'''<h6 style='text-align: center; color: orange;'>{key["subject_id"]} ({key["PI"]}), {date_str}, Session {int(key["session"])}<br>'''
114+
f'''{key["trainer"]} @ {key["rig"]} ({key["data_source"]})''',
115115
unsafe_allow_html=True)
116+
116117
if len(st.session_state.session_plot_selected_draw_types) > 1: # more than one types, use the pre-defined layout
117118
for row, column_setting in enumerate(draw_type_layout_definition):
118119
rows.append(this_major_col.columns(column_setting))
@@ -133,7 +134,7 @@ def draw_session_plots(df_to_draw_session):
133134
my_bar.progress(int((i + 1) / len(df_to_draw_session) * 100))
134135

135136

136-
def session_plot_settings(df_selected, need_click=True):
137+
def session_plot_settings(df_selected_from_plotly=None, need_click=True):
137138
with st.form(key='session_plot_settings'):
138139
st.markdown('##### Show plots for individual sessions ')
139140
cols = st.columns([2, 6, 1])
@@ -146,10 +147,14 @@ def session_plot_settings(df_selected, need_click=True):
146147
default=session_plot_modes[0],
147148
key='session_plot_mode',
148149
)
149-
150-
n_session_to_draw = len(df_selected) \
151-
if 'selected from table or plot' in st.session_state.selected_draw_sessions \
152-
else len(st.session_state.df_session_filtered)
150+
151+
if "selected" in st.session_state.selected_draw_sessions:
152+
if df_selected_from_plotly is None: # Selected from dataframe
153+
df_to_draw_sessions = st.session_state.df_selected_from_dataframe
154+
else:
155+
df_to_draw_sessions = df_selected_from_plotly
156+
else: # all sessions filtered from sidebar
157+
df_to_draw_sessions = st.session_state.df_session_filtered
153158

154159
_ = number_input_wrapper_for_url_query(
155160
st_prefix=cols[2],
@@ -177,20 +182,20 @@ def session_plot_settings(df_selected, need_click=True):
177182
key='session_plot_selected_draw_types',
178183
)
179184

180-
cols[0].markdown(f'{n_session_to_draw} sessions to draw')
185+
cols[0].markdown(f'{len(df_to_draw_sessions)} sessions to draw')
181186
draw_it_now_override = cols[2].checkbox('Auto show', value=not need_click, disabled=not need_click)
182187
submitted = cols[0].form_submit_button(
183188
"Update settings", type="primary"
184189
)
185190

186191
if not need_click:
187-
return True
192+
return True, df_to_draw_sessions
188193

189194
if draw_it_now_override:
190-
return True
195+
return True, df_to_draw_sessions
191196

192-
draw_it = st.button(f'Show {n_session_to_draw} sessions!', use_container_width=False, type="primary")
193-
return draw_it
197+
draw_it = st.button(f'Show {len(df_to_draw_sessions)} sessions!', use_container_width=False, type="primary")
198+
return draw_it, df_to_draw_sessions
194199

195200

196201
def plot_x_y_session():
@@ -335,12 +340,12 @@ def init(if_load_bpod_data_override=None, if_load_docDB_override=None):
335340
st.session_state.curriculum_manager = curriculum_manager
336341

337342
# Some ad-hoc modifications on df_sessions
338-
_df = st.session_state.df['sessions_main'] # temporary df alias
343+
_df = st.session_state.df['sessions_main'].copy() # temporary df alias
339344

340345
_df.columns = _df.columns.get_level_values(1)
341346
_df.sort_values(['session_start_time'], ascending=False, inplace=True)
342347
_df['session_start_time'] = _df['session_start_time'].astype(str) # Turn to string
343-
_df = _df.reset_index().query('subject_id != "0"')
348+
_df = _df.reset_index()
344349

345350
# Handle mouse and user name
346351
if 'bpod_backup_h2o' in _df.columns:
@@ -361,7 +366,6 @@ def init(if_load_bpod_data_override=None, if_load_docDB_override=None):
361366
(_df["trainer"].isin(_df["PI"]) | _df["trainer"].isin(["Han Hou", "Marton Rozsa"])),
362367
"trainer"
363368
] # Fill in PI with trainer if PI is missing and the trainer was ever a PI
364-
365369

366370
# Add data source (Room + Hardware etc)
367371
_df[['institute', 'rig_type', 'room', 'hardware', 'data_source']] = _df['rig'].apply(lambda x: pd.Series(get_data_source(x)))
@@ -373,10 +377,10 @@ def init(if_load_bpod_data_override=None, if_load_docDB_override=None):
373377
# Remove invalid subject_id
374378
_df = _df[(999999 > _df["subject_id"].astype(int))
375379
& (_df["subject_id"].astype(int) > 300000)]
376-
380+
377381
# Remove zero finished trials
378382
_df = _df[_df['finished_trials'] > 0]
379-
383+
380384
# Remove abnormal values
381385
_df.loc[_df['weight_after'] > 100,
382386
['weight_after', 'weight_after_ratio', 'water_in_session_total', 'water_after_session', 'water_day_total']
@@ -411,7 +415,28 @@ def init(if_load_bpod_data_override=None, if_load_docDB_override=None):
411415

412416
# last day's total water
413417
_df['water_day_total_last_session'] = _df.groupby('subject_id')['water_day_total'].shift(1)
414-
_df['water_after_session_last_session'] = _df.groupby('subject_id')['water_after_session'].shift(1)
418+
_df['water_after_session_last_session'] = _df.groupby('subject_id')['water_after_session'].shift(1)
419+
420+
421+
# -- overwrite the `if_stage_overriden_by_trainer`
422+
# Previously it was set to True if the trainer changes stage during a session.
423+
# But it is more informative to define it as whether the trainer has overridden the curriculum.
424+
# In other words, it is set to True only when stage_suggested ~= stage_actual, as defined in the autotrain curriculum.
425+
_df.drop(columns=['if_overriden_by_trainer'], inplace=True)
426+
tmp_auto_train = auto_train_manager.df_manager.query('if_closed_loop == True')[
427+
[
428+
"subject_id",
429+
"session_date",
430+
"current_stage_suggested",
431+
"if_stage_overriden_by_trainer",
432+
]
433+
].copy()
434+
tmp_auto_train['session_date'] = pd.to_datetime(tmp_auto_train['session_date'])
435+
_df = _df.merge(
436+
tmp_auto_train,
437+
on=["subject_id", "session_date"],
438+
how='left',
439+
)
415440

416441
# fill nan for autotrain fields
417442
filled_values = {'curriculum_name': 'None',
@@ -420,7 +445,6 @@ def init(if_load_bpod_data_override=None, if_load_docDB_override=None):
420445
'current_stage_actual': 'None',
421446
'has_video': False,
422447
'has_ephys': False,
423-
'if_overriden_by_trainer': False,
424448
}
425449
_df.fillna(filled_values, inplace=True)
426450

@@ -436,9 +460,6 @@ def init(if_load_bpod_data_override=None, if_load_docDB_override=None):
436460
# drop 'bpod_backup_' columns
437461
_df.drop([col for col in _df.columns if 'bpod_backup_' in col], axis=1, inplace=True)
438462

439-
# fix if_overriden_by_trainer
440-
_df['if_overriden_by_trainer'] = _df['if_overriden_by_trainer'].astype(bool)
441-
442463
# _df = _df.merge(
443464
# diff_relative_weight_next_day, how='left', on=['subject_id', 'session'])
444465

@@ -567,7 +588,7 @@ def app():
567588
return
568589

569590
aggrid_outputs = aggrid_interactive_table_session(
570-
df=st.session_state.df_session_filtered,
591+
df=st.session_state.df_session_filtered.round(3),
571592
table_height=table_height,
572593
)
573594

@@ -604,13 +625,7 @@ def add_main_tabs():
604625
# Add session_plot_setting
605626
with st.columns([1])[0]:
606627
st.markdown("***")
607-
if_draw_all_sessions = session_plot_settings(df_selected_from_plotly)
608-
609-
df_to_draw_sessions = (
610-
df_selected_from_plotly
611-
if "selected" in st.session_state.get("selected_draw_sessions", "sessions selected from table or plot")
612-
else st.session_state.df_session_filtered
613-
)
628+
if_draw_all_sessions, df_to_draw_sessions = session_plot_settings(df_selected_from_plotly=df_selected_from_plotly, need_click=True)
614629

615630
if if_draw_all_sessions and len(df_to_draw_sessions):
616631
draw_session_plots(df_to_draw_sessions)
@@ -650,13 +665,8 @@ def add_main_tabs():
650665
with placeholder:
651666
cols = st.columns([1])
652667
with cols[0]:
653-
df_to_draw_sessions = (
654-
st.session_state.df_selected_from_dataframe
655-
if "selected" in st.session_state.get("selected_draw_sessions", "sessions selected from table or plot")
656-
else st.session_state.df_session_filtered
657-
)
658-
if_draw_all_sessions = session_plot_settings(
659-
df_to_draw_sessions, need_click=False
668+
if_draw_all_sessions, df_to_draw_sessions = session_plot_settings(
669+
df_selected_from_plotly=None, need_click=False
660670
)
661671

662672
if if_draw_all_sessions and len(df_to_draw_sessions):
@@ -769,10 +779,17 @@ def add_main_tabs():
769779
# st.dataframe(st.session_state.df_session_filtered, use_container_width=True, height=1000)
770780

771781
if __name__ == "__main__":
772-
ok = True
773-
if 'df' not in st.session_state or 'sessions_main' not in st.session_state.df.keys():
774-
ok = init()
775-
776-
if ok:
777-
app()
778-
pass
782+
try:
783+
ok = True
784+
if 'df' not in st.session_state or 'sessions_main' not in st.session_state.df.keys():
785+
ok = init()
786+
787+
if ok:
788+
app()
789+
pass
790+
except Exception as e:
791+
st.markdown('# Something went wrong! :scream: ')
792+
st.markdown('## :bulb: Please follow these steps to troubleshoot:')
793+
st.markdown('#### 1. Reload the page')
794+
st.markdown('#### 2. Click this original URL https://foraging-behavior-browser.allenneuraldynamics-test.org/')
795+
st.markdown('#### 3. Report your bug here: https://github.com/AllenNeuralDynamics/foraging-behavior-browser/issues (paste your URL and screenshoots)')

code/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__ver__ = 'v3.0.1 beta'
1+
__ver__ = 'v3.0.2'

code/util/aws_s3.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,15 @@ def load_auto_train():
7474
df_manager_root_on_s3=dict(bucket='aind-behavior-data',
7575
root='foraging_auto_training/')
7676
)
77-
78-
_df = auto_train_manager.df_manager.copy()
77+
78+
_df = auto_train_manager.df_manager.copy().rename(
79+
columns={"if_overriden_by_trainer": "if_stage_overriden_by_trainer"}
80+
)
7981
# Remove invalid subject_id
8082
_df = _df[(999999 > _df["subject_id"].astype(int))
8183
& (_df["subject_id"].astype(int) > 300000)]
8284
auto_train_manager.df_manager = _df
83-
85+
8486
return auto_train_manager, curriculum_manager
8587

8688
def draw_session_plots_quick_preview(df_to_draw_session):
@@ -95,8 +97,8 @@ def draw_session_plots_quick_preview(df_to_draw_session):
9597
except:
9698
date_str = key["session_date"].split("T")[0]
9799

98-
st.markdown(f'''<h5 style='text-align: center; color: orange;'>{key["subject_id"]} ({key["PI"]}), Session {int(key["session"])}, {date_str} '''
99-
f'''({key["trainer"]}@{key["data_source"]})''',
100+
st.markdown(f'''<h6 style='text-align: center; color: orange;'>{key["subject_id"]} ({key["PI"]}), {date_str}, Session {int(key["session"])}<br>'''
101+
f'''{key["trainer"]} @ {key["rig"]} ({key["data_source"]})''',
100102
unsafe_allow_html=True)
101103

102104
rows = []

code/util/reformat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def get_data_source(rig):
6262
room = '446'
6363
elif '323' in rig:
6464
room = '323'
65+
elif '322' in rig:
66+
room = '322'
6567
elif rig_type == 'ephys':
6668
room = '323'
6769
else:

code/util/streamlit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -886,7 +886,7 @@ def add_auto_train_manager():
886886
'curriculum_name', 'curriculum_version', 'curriculum_schema_version',
887887
'current_stage_suggested', 'current_stage_actual',
888888
'session_at_current_stage',
889-
'if_closed_loop', 'if_overriden_by_trainer',
889+
'if_closed_loop', 'if_stage_overriden_by_trainer',
890890
'foraging_efficiency', 'finished_trials',
891891
'decision', 'next_stage_suggested'
892892
]]

code/util/url_query_helper.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def checkbox_wrapper_for_url_query(st_prefix, label, key, default, **kwargs):
9999
)
100100

101101
def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, default_override=False, **kwargs):
102+
"""selectbox wrapper for url query
103+
default : could be either a value in options or the index
104+
default_override : if True, always use default, otherwise, session_state or query_params has higher priority
105+
"""
102106
# If default_override, use default. Otherwise, session_state or query_params has higher priority
103107
if not default_override:
104108
default = (
@@ -108,11 +112,16 @@ def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, def
108112
if key in st.query_params and st.query_params[key] in options
109113
else default
110114
)
115+
116+
if type(default) is int and default not in options:
117+
index = default # allow user to use default as index
118+
else:
119+
index = options.index(default)
111120

112121
return st_prefix.selectbox(
113122
label,
114123
options=options,
115-
index=options.index(default),
124+
index=index,
116125
key=key,
117126
**kwargs,
118127
)

0 commit comments

Comments
 (0)