Skip to content

Commit 3bf4322

Browse files
committed
minor improvements
1 parent 8d7e5b2 commit 3bf4322

File tree

2 files changed

+59
-43
lines changed

2 files changed

+59
-43
lines changed

code/Home.py

Lines changed: 58 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -123,26 +123,7 @@ def draw_session_plots(keys_to_draw_session):
123123
[1.5, 1], # columns in the second row
124124
]
125125

126-
draw_type_mapper = {'1. Choice history': ('fitted_choice', # prefix
127-
(0, 0), # location (row_idx, column_idx)
128-
dict(other_patterns=['model_best', 'model_None'])),
129-
'2. Lick times': ('lick_psth',
130-
(1, 0),
131-
{}),
132-
'3. Logistic regression on choice': ('logistic_regression',
133-
(1, 1),
134-
dict(crop=(0, 0, 1200, 2000))),
135-
'4. Win-stay-lose-shift prob.': ('wsls',
136-
(1, 1),
137-
dict(crop=(0, 0, 1200, 600))),
138-
'5. Linear regression on RT': ('linear_regression_rt',
139-
(1, 0),
140-
dict()),
141-
}
142-
143-
cols_option = st.columns([3, 0.5, 1])
144-
selected_draw_types = cols_option[0].multiselect('Which plot(s) to draw?', draw_type_mapper.keys(), default=draw_type_mapper.keys())
145-
num_cols = cols_option[1].number_input('Number of columns', 1, 10, 2)
126+
# cols_option = st.columns([3, 0.5, 1])
146127
container_session_all_in_one = st.container()
147128

148129
with container_session_all_in_one:
@@ -152,30 +133,30 @@ def draw_session_plots(keys_to_draw_session):
152133
st.write(f'Loading selected {len(keys_to_draw_session)} sessions...')
153134
my_bar = st.columns((1, 7))[0].progress(0)
154135

155-
major_cols = st.columns([1] * num_cols)
136+
major_cols = st.columns([1] * st.session_state.num_cols)
156137

157138
if not isinstance(keys_to_draw_session, list): # Turn dataframe to list, if necessary
158139
keys_to_draw_session = keys_to_draw_session.to_dict(orient='records')
159140

160141
for i, key in enumerate(keys_to_draw_session):
161-
this_major_col = major_cols[i % num_cols]
142+
this_major_col = major_cols[i % st.session_state.num_cols]
162143

163144
# setting up layout for each session
164145
rows = []
165146
with this_major_col:
166-
st.markdown(f'''<h3 style='text-align: center; color: blue;'>{key["h2o"]}, Session {key["session"]}, {key["session_date"].split("T")[0]}''',
147+
st.markdown(f'''<h3 style='text-align: center; color: orange;'>{key["h2o"]}, Session {key["session"]}, {key["session_date"].split("T")[0]}''',
167148
unsafe_allow_html=True)
168-
if len(selected_draw_types) > 1: # more than one types, use the pre-defined layout
149+
if len(st.session_state.selected_draw_types) > 1: # more than one types, use the pre-defined layout
169150
for row, column_setting in enumerate(layout_definition):
170151
rows.append(this_major_col.columns(column_setting))
171152
else: # else, put it in the whole column
172153
rows = this_major_col.columns([1])
173154
st.markdown("---")
174155

175-
for draw_type in draw_type_mapper:
176-
if draw_type not in selected_draw_types: continue # To keep the draw order defined by draw_type_mapper
177-
prefix, position, setting = draw_type_mapper[draw_type]
178-
this_col = rows[position[0]][position[1]] if len(selected_draw_types) > 1 else rows[0]
156+
for draw_type in st.session_state.draw_type_mapper:
157+
if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper
158+
prefix, position, setting = st.session_state.draw_type_mapper[draw_type]
159+
this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0]
179160
show_img_by_key_and_prefix(key,
180161
column=this_col,
181162
prefix=prefix,
@@ -350,7 +331,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, col):
350331

351332
fig.update_layout(
352333
# width=1300,
353-
height=850,
334+
# height=850,
354335
xaxis_title=x_name,
355336
yaxis_title=y_name,
356337
# xaxis_range=[0, min(100, df[x_name].max())],
@@ -361,9 +342,14 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, col):
361342
)
362343

363344
# st.plotly_chart(fig)
364-
selected_sessions_from_plot = plotly_events(fig, click_event=True, hover_event=False, select_event=True, override_height=870)
345+
selected_sessions_from_plot = plotly_events(fig, click_event=True, hover_event=False, select_event=True, override_height=870, override_width=1400)
365346

366347
return selected_sessions_from_plot
348+
349+
def session_plot_settings():
350+
st.session_state.selected_draw_types = st.multiselect('Which plot(s) to draw?', st.session_state.draw_type_mapper.keys(), default=st.session_state.draw_type_mapper.keys())
351+
st.session_state.num_cols = st.columns([1, 3])[0].number_input('Number of columns', 1, 10, 2)
352+
367353

368354

369355
def population_analysis():
@@ -393,7 +379,19 @@ def population_analysis():
393379

394380
smooth_factor = s_cols[0].slider('Smooth factor', 1, 20, 5, disabled=not ((if_aggr_each_group and aggr_method_group=='lowess')
395381
or (if_aggr_all and aggr_method_all=='lowess')))
396-
382+
for i in range(7): st.write('\n')
383+
384+
st.markdown("***")
385+
st.markdown('##### Click or box/lasso select session(s) from the plots to draw 👉')
386+
session_plot_settings()
387+
388+
with st.expander(f'{len(st.session_state.df_selected_from_plotly)} sessions selected from plotly', expanded=False):
389+
if st.button('clear selection'):
390+
st.session_state.df_selected_from_plotly = pd.DataFrame()
391+
392+
with st.expander('show sessions', expanded=False):
393+
st.dataframe(st.session_state.df_selected_from_plotly)
394+
397395

398396
names = {('session', 'foraging_eff'): 'Foraging efficiency',
399397
('session', 'finished'): 'Finished trials',
@@ -442,11 +440,30 @@ def init():
442440

443441
st.session_state.df_selected_from_plotly = pd.DataFrame()
444442

443+
# Init session states
445444
# add some model fitting params to session
446445
if not 'model_id' in st.session_state:
447446
st.session_state.model_id = 21
448447
selected_id = st.session_state.model_id
449448

449+
st.session_state.draw_type_mapper = {'1. Choice history': ('fitted_choice', # prefix
450+
(0, 0), # location (row_idx, column_idx)
451+
dict(other_patterns=['model_best', 'model_None'])),
452+
'2. Lick times': ('lick_psth',
453+
(1, 0),
454+
{}),
455+
'3. Logistic regression on choice': ('logistic_regression',
456+
(1, 1),
457+
dict(crop=(0, 0, 1200, 2000))),
458+
'4. Win-stay-lose-shift prob.': ('wsls',
459+
(1, 1),
460+
dict(crop=(0, 0, 1200, 600))),
461+
'5. Linear regression on RT': ('linear_regression_rt',
462+
(1, 0),
463+
dict()),
464+
}
465+
466+
# process dfs
450467
df_this_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}')
451468
valid_field = df_this_model.columns[~np.all(~df_this_model.notna(), axis=0)]
452469
to_add_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}')[valid_field]
@@ -458,6 +475,7 @@ def init():
458475

459476
st.session_state.session_stats_names = [keys for keys in st.session_state.df['sessions'].keys()]
460477

478+
461479

462480

463481
def app():
@@ -472,6 +490,9 @@ def app():
472490
st.cache_data.clear()
473491
init()
474492
st.experimental_rerun()
493+
494+
st.markdown('---')
495+
st.write('Han Hou @ 2023\nv1.0.0')
475496

476497

477498
with st.container():
@@ -496,17 +517,17 @@ def app():
496517

497518
st.session_state.aggrid_outputs = aggrid_interactive_table_session(df=st.session_state.df_session_filtered)
498519

499-
chosen_id = stx.tab_bar(data=[
500-
stx.TabBarItemData(id="tab1", title="📈Training summary", description="Plot training summary"),
501-
stx.TabBarItemData(id="tab2", title="📚Session inspection", description="Generate plots for each session"),
502-
], default="tab1")
520+
# chosen_id = stx.tab_bar(data=[
521+
# stx.TabBarItemData(id="tab1", title="📈Training summary", description="Plot training summary"),
522+
# stx.TabBarItemData(id="tab2", title="📚Session inspection", description="Generate plots for each session"),
523+
# ], default="tab1")
524+
chosen_id = "tab1"
503525

504526
placeholder = st.container()
505527

506528
if chosen_id == "tab1":
507529
with placeholder:
508530
df_selected_from_plotly = population_analysis()
509-
st.markdown('##### Select session(s) from the plots above to draw')
510531

511532
if len(st.session_state.df_selected_from_plotly):
512533
draw_session_plots(st.session_state.df_selected_from_plotly)
@@ -519,6 +540,7 @@ def app():
519540
with placeholder:
520541
selected_keys_from_aggrid = st.session_state.aggrid_outputs['selected_rows']
521542
st.markdown('##### Select session(s) from the table above to draw')
543+
session_plot_settings()
522544
draw_session_plots(selected_keys_from_aggrid)
523545

524546
# st.dataframe(st.session_state.df_session_filtered, use_container_width=True, height=1000)

code/streamlit_util.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,4 @@ def add_session_filter():
274274
with st.expander("Behavioral session filter", expanded=True):
275275
st.session_state.df_session_filtered = filter_dataframe(df=st.session_state.df['sessions'])
276276
st.markdown(f"{len(st.session_state.df_session_filtered)} sessions filtered (use_s3 = {st.session_state.use_s3})")
277-
278-
with st.expander(f'{len(st.session_state.df_selected_from_plotly)} sessions selected from plotly', expanded=True):
279-
if st.button('clear selection'):
280-
st.session_state.df_selected_from_plotly = pd.DataFrame()
281-
282-
with st.expander('show sessions', expanded=False):
283-
st.dataframe(st.session_state.df_selected_from_plotly)
277+

0 commit comments

Comments
 (0)