@@ -123,26 +123,7 @@ def draw_session_plots(keys_to_draw_session):
123123 [1.5 , 1 ], # columns in the second row
124124 ]
125125
126- draw_type_mapper = {'1. Choice history' : ('fitted_choice' , # prefix
127- (0 , 0 ), # location (row_idx, column_idx)
128- dict (other_patterns = ['model_best' , 'model_None' ])),
129- '2. Lick times' : ('lick_psth' ,
130- (1 , 0 ),
131- {}),
132- '3. Logistic regression on choice' : ('logistic_regression' ,
133- (1 , 1 ),
134- dict (crop = (0 , 0 , 1200 , 2000 ))),
135- '4. Win-stay-lose-shift prob.' : ('wsls' ,
136- (1 , 1 ),
137- dict (crop = (0 , 0 , 1200 , 600 ))),
138- '5. Linear regression on RT' : ('linear_regression_rt' ,
139- (1 , 0 ),
140- dict ()),
141- }
142-
143- cols_option = st .columns ([3 , 0.5 , 1 ])
144- selected_draw_types = cols_option [0 ].multiselect ('Which plot(s) to draw?' , draw_type_mapper .keys (), default = draw_type_mapper .keys ())
145- num_cols = cols_option [1 ].number_input ('Number of columns' , 1 , 10 , 2 )
126+ # cols_option = st.columns([3, 0.5, 1])
146127 container_session_all_in_one = st .container ()
147128
148129 with container_session_all_in_one :
@@ -152,30 +133,30 @@ def draw_session_plots(keys_to_draw_session):
152133 st .write (f'Loading selected { len (keys_to_draw_session )} sessions...' )
153134 my_bar = st .columns ((1 , 7 ))[0 ].progress (0 )
154135
155- major_cols = st .columns ([1 ] * num_cols )
136+ major_cols = st .columns ([1 ] * st . session_state . num_cols )
156137
157138 if not isinstance (keys_to_draw_session , list ): # Turn dataframe to list, if necessary
158139 keys_to_draw_session = keys_to_draw_session .to_dict (orient = 'records' )
159140
160141 for i , key in enumerate (keys_to_draw_session ):
161- this_major_col = major_cols [i % num_cols ]
142+ this_major_col = major_cols [i % st . session_state . num_cols ]
162143
163144 # setting up layout for each session
164145 rows = []
165146 with this_major_col :
166- st .markdown (f'''<h3 style='text-align: center; color: blue ;'>{ key ["h2o" ]} , Session { key ["session" ]} , { key ["session_date" ].split ("T" )[0 ]} ''' ,
147+ st .markdown (f'''<h3 style='text-align: center; color: orange ;'>{ key ["h2o" ]} , Session { key ["session" ]} , { key ["session_date" ].split ("T" )[0 ]} ''' ,
167148 unsafe_allow_html = True )
168- if len (selected_draw_types ) > 1 : # more than one types, use the pre-defined layout
149+ if len (st . session_state . selected_draw_types ) > 1 : # more than one types, use the pre-defined layout
169150 for row , column_setting in enumerate (layout_definition ):
170151 rows .append (this_major_col .columns (column_setting ))
171152 else : # else, put it in the whole column
172153 rows = this_major_col .columns ([1 ])
173154 st .markdown ("---" )
174155
175- for draw_type in draw_type_mapper :
176- if draw_type not in selected_draw_types : continue # To keep the draw order defined by draw_type_mapper
177- prefix , position , setting = draw_type_mapper [draw_type ]
178- this_col = rows [position [0 ]][position [1 ]] if len (selected_draw_types ) > 1 else rows [0 ]
156+ for draw_type in st . session_state . draw_type_mapper :
157+ if draw_type not in st . session_state . selected_draw_types : continue # To keep the draw order defined by st.session_state. draw_type_mapper
158+ prefix , position , setting = st . session_state . draw_type_mapper [draw_type ]
159+ this_col = rows [position [0 ]][position [1 ]] if len (st . session_state . selected_draw_types ) > 1 else rows [0 ]
179160 show_img_by_key_and_prefix (key ,
180161 column = this_col ,
181162 prefix = prefix ,
@@ -350,7 +331,7 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, col):
350331
351332 fig .update_layout (
352333 # width=1300,
353- height = 850 ,
334+ # height=850,
354335 xaxis_title = x_name ,
355336 yaxis_title = y_name ,
356337 # xaxis_range=[0, min(100, df[x_name].max())],
@@ -361,9 +342,14 @@ def _add_agg(df_this, x_name, y_name, group, aggr_method, col):
361342 )
362343
363344 # st.plotly_chart(fig)
364- selected_sessions_from_plot = plotly_events (fig , click_event = True , hover_event = False , select_event = True , override_height = 870 )
345+ selected_sessions_from_plot = plotly_events (fig , click_event = True , hover_event = False , select_event = True , override_height = 870 , override_width = 1400 )
365346
366347 return selected_sessions_from_plot
348+
349+ def session_plot_settings ():
350+ st .session_state .selected_draw_types = st .multiselect ('Which plot(s) to draw?' , st .session_state .draw_type_mapper .keys (), default = st .session_state .draw_type_mapper .keys ())
351+ st .session_state .num_cols = st .columns ([1 , 3 ])[0 ].number_input ('Number of columns' , 1 , 10 , 2 )
352+
367353
368354
369355def population_analysis ():
@@ -393,7 +379,19 @@ def population_analysis():
393379
394380 smooth_factor = s_cols [0 ].slider ('Smooth factor' , 1 , 20 , 5 , disabled = not ((if_aggr_each_group and aggr_method_group == 'lowess' )
395381 or (if_aggr_all and aggr_method_all == 'lowess' )))
396-
382+ for i in range (7 ): st .write ('\n ' )
383+
384+ st .markdown ("***" )
385+ st .markdown ('##### Click or box/lasso select session(s) from the plots to draw 👉' )
386+ session_plot_settings ()
387+
388+ with st .expander (f'{ len (st .session_state .df_selected_from_plotly )} sessions selected from plotly' , expanded = False ):
389+ if st .button ('clear selection' ):
390+ st .session_state .df_selected_from_plotly = pd .DataFrame ()
391+
392+ with st .expander ('show sessions' , expanded = False ):
393+ st .dataframe (st .session_state .df_selected_from_plotly )
394+
397395
398396 names = {('session' , 'foraging_eff' ): 'Foraging efficiency' ,
399397 ('session' , 'finished' ): 'Finished trials' ,
@@ -442,11 +440,30 @@ def init():
442440
443441 st .session_state .df_selected_from_plotly = pd .DataFrame ()
444442
443+ # Init session states
445444 # add some model fitting params to session
446445 if not 'model_id' in st .session_state :
447446 st .session_state .model_id = 21
448447 selected_id = st .session_state .model_id
449448
449+ st .session_state .draw_type_mapper = {'1. Choice history' : ('fitted_choice' , # prefix
450+ (0 , 0 ), # location (row_idx, column_idx)
451+ dict (other_patterns = ['model_best' , 'model_None' ])),
452+ '2. Lick times' : ('lick_psth' ,
453+ (1 , 0 ),
454+ {}),
455+ '3. Logistic regression on choice' : ('logistic_regression' ,
456+ (1 , 1 ),
457+ dict (crop = (0 , 0 , 1200 , 2000 ))),
458+ '4. Win-stay-lose-shift prob.' : ('wsls' ,
459+ (1 , 1 ),
460+ dict (crop = (0 , 0 , 1200 , 600 ))),
461+ '5. Linear regression on RT' : ('linear_regression_rt' ,
462+ (1 , 0 ),
463+ dict ()),
464+ }
465+
466+ # process dfs
450467 df_this_model = st .session_state .df ['model_fitting_params' ].query (f'model_id == { selected_id } ' )
451468 valid_field = df_this_model .columns [~ np .all (~ df_this_model .notna (), axis = 0 )]
452469 to_add_model = st .session_state .df ['model_fitting_params' ].query (f'model_id == { selected_id } ' )[valid_field ]
@@ -458,6 +475,7 @@ def init():
458475
459476 st .session_state .session_stats_names = [keys for keys in st .session_state .df ['sessions' ].keys ()]
460477
478+
461479
462480
463481def app ():
@@ -472,6 +490,9 @@ def app():
472490 st .cache_data .clear ()
473491 init ()
474492 st .experimental_rerun ()
493+
494+ st .markdown ('---' )
495+ st .write ('Han Hou @ 2023\n v1.0.0' )
475496
476497
477498 with st .container ():
@@ -496,17 +517,17 @@ def app():
496517
497518 st .session_state .aggrid_outputs = aggrid_interactive_table_session (df = st .session_state .df_session_filtered )
498519
499- chosen_id = stx .tab_bar (data = [
500- stx .TabBarItemData (id = "tab1" , title = "📈Training summary" , description = "Plot training summary" ),
501- stx .TabBarItemData (id = "tab2" , title = "📚Session inspection" , description = "Generate plots for each session" ),
502- ], default = "tab1" )
520+ # chosen_id = stx.tab_bar(data=[
521+ # stx.TabBarItemData(id="tab1", title="📈Training summary", description="Plot training summary"),
522+ # stx.TabBarItemData(id="tab2", title="📚Session inspection", description="Generate plots for each session"),
523+ # ], default="tab1")
524+ chosen_id = "tab1"
503525
504526 placeholder = st .container ()
505527
506528 if chosen_id == "tab1" :
507529 with placeholder :
508530 df_selected_from_plotly = population_analysis ()
509- st .markdown ('##### Select session(s) from the plots above to draw' )
510531
511532 if len (st .session_state .df_selected_from_plotly ):
512533 draw_session_plots (st .session_state .df_selected_from_plotly )
@@ -519,6 +540,7 @@ def app():
519540 with placeholder :
520541 selected_keys_from_aggrid = st .session_state .aggrid_outputs ['selected_rows' ]
521542 st .markdown ('##### Select session(s) from the table above to draw' )
543+ session_plot_settings ()
522544 draw_session_plots (selected_keys_from_aggrid )
523545
524546 # st.dataframe(st.session_state.df_session_filtered, use_container_width=True, height=1000)
0 commit comments