2424from util .aws_s3 import (draw_session_plots_quick_preview ,
2525 load_data ,
2626 load_auto_train ,
27+ load_mouse_PI_mapping ,
2728 show_debug_info ,
2829 show_session_level_img_by_key_and_prefix )
2930from util .fetch_data_docDB import load_data_from_docDB
4243 slider_wrapper_for_url_query ,
4344 sync_session_state_to_URL ,
4445 sync_URL_to_session_state )
46+ from util .reformat import get_data_source
4547
4648try :
4749 st .set_page_config (layout = "wide" ,
5658 pass
5759
5860
59- def _user_name_mapper (user_name ):
60- user_mapper = { # tuple of key words --> user name
61- ('Avalon' ,): 'Avalon Amaya' ,
62- ('Ella' ,): 'Ella Hilton' ,
63- ('Katrina' ,): 'Katrina Nguyen' ,
64- ('Lucas' ,): 'Lucas Kinsey' ,
65- ('Travis' ,): 'Travis Ramirez' ,
66- ('Xinxin' , 'the ghost' ): 'Xinxin Yin' ,
67- }
68- for key_words , name in user_mapper .items ():
69- for key_word in key_words :
70- if key_word in user_name :
71- return name
61+ def _trainer_mapper (trainer ):
62+ user_mapper = {
63+ 'Avalon Amaya' : ['Avalon' ],
64+ 'Ella Hilton' : ['Ella' ],
65+ 'Katrina Nguyen' : ['Katrina' ],
66+ 'Lucas Kinsey' : ['Lucas' ],
67+ 'Travis Ramirez' : ['Travis' ],
68+ 'Xinxin Yin' : ['Xinxin' , 'the ghost' ],
69+ 'Bowen Tan' : ['Bowen' ],
70+ 'Henry Loeffler' : ['Henry Loeffer' ],
71+ 'Margaret Lee' : ['margaret lee' ],
72+ 'Madeline Tom' : ['Madseline Tom' ],
73+ }
74+ for canonical_name , alias in user_mapper .items ():
75+ for key_word in alias :
76+ if key_word in trainer :
77+ return canonical_name
7278 else :
73- return user_name
79+ return trainer
7480
7581
7682@st .cache_resource (ttl = 24 * 3600 )
@@ -104,8 +110,8 @@ def draw_session_plots(df_to_draw_session):
104110 except :
105111 date_str = key ["session_date" ].split ("T" )[0 ]
106112
107- st .markdown (f'''<h5 style='text-align: center; color: orange;'>{ key ["h2o " ]} , Session { int (key ["session" ])} , { date_str } '''
108- f'''({ key ["user_name " ]} @{ key ["data_source" ]} )''' ,
113+ st .markdown (f'''<h5 style='text-align: center; color: orange;'>{ key ["subject_id " ]} ( { key [ "PI" ] } ) , Session { int (key ["session" ])} , { date_str } '''
114+ f'''({ key ["trainer " ]} @{ key ["data_source" ]} )''' ,
109115 unsafe_allow_html = True )
110116 if len (st .session_state .session_plot_selected_draw_types ) > 1 : # more than one types, use the pre-defined layout
111117 for row , column_setting in enumerate (draw_type_layout_definition ):
@@ -280,28 +286,26 @@ def plot_x_y_session():
280286 if len (df_selected_from_plotly ) == 1 :
281287 with cols [1 ]:
282288 draw_session_plots_quick_preview (df_selected_from_plotly )
283-
284289 return df_selected_from_plotly , cols
285290
286291
def show_curriculums():
    """Placeholder for the curriculum-display section (not yet implemented)."""
    pass
289294
290-
291295
292296# ------- Layout starts here -------- #
293297def init (if_load_bpod_data_override = None , if_load_docDB_override = None ):
294-
298+
295299 # Clear specific session state and all filters
296300 for key in st .session_state :
297301 if key in ['selected_draw_types' ] or '_changed' in key :
298302 del st .session_state [key ]
299-
303+
300304 df = load_data (['sessions' ], data_source = 'bonsai' )
301-
305+
302306 if not len (df ):
303307 return False
304-
308+
305309 # --- Perform any data source-dependent preprocessing here ---
306310 # Because sync_URL_to_session_state() needs df to be loaded (for dynamic column filtering),
307311 # 'if_load_bpod_sessions' has not been synced from URL to session state yet.
@@ -312,77 +316,66 @@ def init(if_load_bpod_data_override=None, if_load_docDB_override=None):
312316 else st .session_state .if_load_bpod_sessions
313317 if 'if_load_bpod_sessions' in st .session_state
314318 else False )
315-
319+
316320 st .session_state .bpod_loaded = False
317321 if _if_load_bpod :
318322 df_bpod = load_data (['sessions' ], data_source = 'bpod' )
319323 st .session_state .bpod_loaded = True
320-
321- # For historial reason, the suffix of df['sessions_bonsai '] just mean the data of the Home.py page
322- df ['sessions_bonsai ' ] = pd .concat ([df ['sessions_bonsai ' ], df_bpod ['sessions_bonsai ' ]], axis = 0 )
323-
324+
325+ # For historial reason, the suffix of df['sessions_main '] just mean the data of the Home.py page
326+ df ['sessions_main ' ] = pd .concat ([df ['sessions_main ' ], df_bpod ['sessions_main ' ]], axis = 0 )
327+
324328 st .session_state .df = df
325329 for source in ["dataframe" , "plotly" ]:
326- st .session_state [f'df_selected_from_{ source } ' ] = pd .DataFrame (columns = ['h2o ' , 'session' ])
327-
330+ st .session_state [f'df_selected_from_{ source } ' ] = pd .DataFrame (columns = ['subject_id ' , 'session' ])
331+
328332 # Load autotrain
329333 auto_train_manager , curriculum_manager = load_auto_train ()
330334 st .session_state .auto_train_manager = auto_train_manager
331335 st .session_state .curriculum_manager = curriculum_manager
332-
336+
333337 # Some ad-hoc modifications on df_sessions
334- _df = st .session_state .df ['sessions_bonsai ' ] # temporary df alias
335-
338+ _df = st .session_state .df ['sessions_main ' ] # temporary df alias
339+
336340 _df .columns = _df .columns .get_level_values (1 )
337341 _df .sort_values (['session_start_time' ], ascending = False , inplace = True )
338342 _df ['session_start_time' ] = _df ['session_start_time' ].astype (str ) # Turn to string
339343 _df = _df .reset_index ().query ('subject_id != "0"' )
340-
344+
341345 # Handle mouse and user name
342346 if 'bpod_backup_h2o' in _df .columns :
343- _df ['h2o ' ] = np .where (_df ['bpod_backup_h2o' ].notnull (), _df ['bpod_backup_h2o' ], _df ['subject_id' ])
344- _df ['user_name ' ] = np .where (_df ['bpod_backup_user_name' ].notnull (), _df ['bpod_backup_user_name' ], _df ['user_name ' ])
347+ _df ['subject_alias ' ] = np .where (_df ['bpod_backup_h2o' ].notnull (), _df ['bpod_backup_h2o' ], _df ['subject_id' ])
348+ _df ['trainer ' ] = np .where (_df ['bpod_backup_user_name' ].notnull (), _df ['bpod_backup_user_name' ], _df ['trainer ' ])
345349 else :
346- _df ['h2o' ] = _df ['subject_id' ]
347-
348-
349- def _get_data_source (rig ):
350- """From rig string, return "{institute}_{rig_type}_{room}_{hardware}"
351- """
352- institute = 'Janelia' if ('bpod' in rig ) and not ('AIND' in rig ) else 'AIND'
353- hardware = 'bpod' if ('bpod' in rig ) else 'bonsai'
354- rig_type = 'ephys' if ('ephys' in rig .lower ()) else 'training'
355-
356- # This is a mess...
357- if institute == 'Janelia' :
358- room = 'NA'
359- elif 'Ephys-Han' in rig :
360- room = '321'
361- elif hardware == 'bpod' :
362- room = '347'
363- elif '447' in rig :
364- room = '447'
365- elif '446' in rig :
366- room = '446'
367- elif '323' in rig :
368- room = '323'
369- elif rig_type == 'ephys' :
370- room = '323'
371- else :
372- room = '447'
373- return institute , rig_type , room , hardware , '_' .join ([institute , rig_type , room , hardware ])
374-
375- # Add data source (Room + Hardware etc)
376- _df [['institute' , 'rig_type' , 'room' , 'hardware' , 'data_source' ]] = _df ['rig' ].apply (lambda x : pd .Series (_get_data_source (x )))
350+ _df ['subject_alias' ] = _df ['subject_id' ]
351+
352+ # map trainer
353+ _df ['trainer' ] = _df ['trainer' ].apply (_trainer_mapper )
354+
355+ # Merge in PI name
356+ df_mouse_pi_mapping = load_mouse_PI_mapping ()
357+ _df = _df .merge (df_mouse_pi_mapping , how = 'left' , on = 'subject_id' ) # Merge in PI name
358+ _df .loc [_df ["PI" ].isnull (), "PI" ] = _df .loc [
359+ _df ["PI" ].isnull () &
360+ (_df ["trainer" ].isin (_df ["PI" ]) | _df ["trainer" ].isin (["Han Hou" , "Marton Rozsa" ])),
361+ "trainer"
362+ ] # Fill in PI with trainer if PI is missing and the trainer was ever a PI
377363
364+
365+ # Add data source (Room + Hardware etc)
366+ _df [['institute' , 'rig_type' , 'room' , 'hardware' , 'data_source' ]] = _df ['rig' ].apply (lambda x : pd .Series (get_data_source (x )))
367+
378368 # Handle session number
379369 _df .dropna (subset = ['session' ], inplace = True ) # Remove rows with no session number (only leave the nwb file with the largest finished_trials for now)
380370 _df .drop (_df .query ('session < 1' ).index , inplace = True )
381-
371+
382372 # Remove invalid subject_id
383373 _df = _df [(999999 > _df ["subject_id" ].astype (int ))
384374 & (_df ["subject_id" ].astype (int ) > 300000 )]
385375
376+ # Remove zero finished trials
377+ _df = _df [_df ['finished_trials' ] > 0 ]
378+
386379 # Remove abnormal values
387380 _df .loc [_df ['weight_after' ] > 100 ,
388381 ['weight_after' , 'weight_after_ratio' , 'water_in_session_total' , 'water_after_session' , 'water_day_total' ]
@@ -393,35 +386,32 @@ def _get_data_source(rig):
393386
394387 _df .loc [(_df ['duration_iti_median' ] < 0 ) | (_df ['duration_iti_mean' ] < 0 ),
395388 ['duration_iti_median' , 'duration_iti_mean' , 'duration_iti_std' , 'duration_iti_min' , 'duration_iti_max' ]] = np .nan
396-
389+
397390 _df .loc [_df ['invalid_lick_ratio' ] < 0 ,
398391 ['invalid_lick_ratio' ]]= np .nan
399-
392+
400393 # # add something else
401394 # add abs(bais) to all terms that have 'bias' in name
402395 for col in _df .columns :
403396 if 'bias' in col :
404397 _df [f'abs({ col } )' ] = np .abs (_df [col ])
405-
398+
406399 # # delta weight
407400 # diff_relative_weight_next_day = _df.set_index(
408- # ['session']).sort_values('session', ascending=True).groupby('h2o ').apply(
401+ # ['session']).sort_values('session', ascending=True).groupby('subject_id ').apply(
409402 # lambda x: - x.relative_weight.diff(periods=-1)).rename("diff_relative_weight_next_day")
410-
403+
411404 # weekday
412405 _df .session_date = pd .to_datetime (_df .session_date )
413406 _df ['weekday' ] = _df .session_date .dt .dayofweek + 1
414-
415- # map user_name
416- _df ['user_name' ] = _df ['user_name' ].apply (_user_name_mapper )
417-
407+
418408 # trial stats
419409 _df ['avg_trial_length_in_seconds' ] = _df ['session_run_time_in_min' ] / _df ['total_trials_with_autowater' ] * 60
420-
410+
421411 # last day's total water
422- _df ['water_day_total_last_session' ] = _df .groupby ('h2o ' )['water_day_total' ].shift (1 )
423- _df ['water_after_session_last_session' ] = _df .groupby ('h2o ' )['water_after_session' ].shift (1 )
424-
412+ _df ['water_day_total_last_session' ] = _df .groupby ('subject_id ' )['water_day_total' ].shift (1 )
413+ _df ['water_after_session_last_session' ] = _df .groupby ('subject_id ' )['water_after_session' ].shift (1 )
414+
425415 # fill nan for autotrain fields
426416 filled_values = {'curriculum_name' : 'None' ,
427417 'curriculum_version' : 'None' ,
@@ -432,7 +422,7 @@ def _get_data_source(rig):
432422 'if_overriden_by_trainer' : False ,
433423 }
434424 _df .fillna (filled_values , inplace = True )
435-
425+
436426 # foraging performance = foraing_eff * finished_rate
437427 if 'foraging_performance' not in _df .columns :
438428 _df ['foraging_performance' ] = \
@@ -444,20 +434,19 @@ def _get_data_source(rig):
444434
445435 # drop 'bpod_backup_' columns
446436 _df .drop ([col for col in _df .columns if 'bpod_backup_' in col ], axis = 1 , inplace = True )
447-
437+
448438 # fix if_overriden_by_trainer
449439 _df ['if_overriden_by_trainer' ] = _df ['if_overriden_by_trainer' ].astype (bool )
450-
440+
451441 # _df = _df.merge(
452- # diff_relative_weight_next_day, how='left', on=['h2o ', 'session'])
453-
442+ # diff_relative_weight_next_day, how='left', on=['subject_id ', 'session'])
443+
454444 # Recorder columns so that autotrain info is easier to see
455445 first_several_cols = ['subject_id' , 'session_date' , 'nwb_suffix' , 'session' , 'rig' ,
456- 'user_name ' , 'curriculum_name' , 'curriculum_version' , 'current_stage_actual' ,
446+ 'trainer' , 'PI ' , 'curriculum_name' , 'curriculum_version' , 'current_stage_actual' ,
457447 'task' , 'notes' ]
458448 new_order = first_several_cols + [col for col in _df .columns if col not in first_several_cols ]
459449 _df = _df [new_order ]
460-
461450
462451 # --- Load data from docDB ---
463452 if_load_docDb = if_load_docDB_override if if_load_docDB_override is not None else (
@@ -466,10 +455,10 @@ def _get_data_source(rig):
466455 else st .session_state .if_load_docDB
467456 if 'if_load_docDB' in st .session_state
468457 else False )
469-
458+
470459 if if_load_docDb :
471460 _df = merge_in_df_docDB (_df )
472-
461+
473462 # add docDB_status column
474463 _df ["docDB_status" ] = _df .apply (
475464 lambda row : (
@@ -484,15 +473,15 @@ def _get_data_source(rig):
484473 axis = 1 ,
485474 )
486475
487- st .session_state .df ['sessions_bonsai ' ] = _df # Somehow _df loses the reference to the original dataframe
476+ st .session_state .df ['sessions_main ' ] = _df # Somehow _df loses the reference to the original dataframe
488477 st .session_state .session_stats_names = [keys for keys in _df .keys ()]
489478
490479 # Set session state from URL
491480 sync_URL_to_session_state ()
492-
481+
493482 # Establish communication between pygwalker and streamlit
494483 init_streamlit_comm ()
495-
484+
496485 return True
497486
498487def merge_in_df_docDB (_df ):
@@ -536,7 +525,7 @@ def app():
536525 cols = st .columns ([4 , 4 , 4 , 1 ])
537526 cols [0 ].markdown (f'### Filter the sessions on the sidebar\n '
538527 f'##### { len (st .session_state .df_session_filtered )} sessions, '
539- f'{ len (st .session_state .df_session_filtered .h2o .unique ())} mice filtered' )
528+ f'{ len (st .session_state .df_session_filtered .subject_id .unique ())} mice filtered' )
540529 with cols [1 ]:
541530 with st .form (key = 'load_settings' , clear_on_submit = False ):
542531 if_load_bpod_sessions = checkbox_wrapper_for_url_query (
@@ -582,8 +571,8 @@ def app():
582571 )
583572
584573 if len (aggrid_outputs ['selected_rows' ]) \
585- and not set (pd .DataFrame (aggrid_outputs ['selected_rows' ]).set_index (['h2o ' , 'session' ]).index
586- ) == set (st .session_state .df_selected_from_dataframe .set_index (['h2o ' , 'session' ]).index ) \
574+ and not set (pd .DataFrame (aggrid_outputs ['selected_rows' ]).set_index (['subject_id ' , 'session' ]).index
575+ ) == set (st .session_state .df_selected_from_dataframe .set_index (['subject_id ' , 'session' ]).index ) \
587576 and not st .session_state .get ("df_selected_from_dataframe_just_overriden" , False ): # so that if the user just overriden the df_selected_from_dataframe by pressing sidebar button, it won't sync selected rows in the table to session state
588577 st .session_state .df_selected_from_dataframe = pd .DataFrame (aggrid_outputs ['selected_rows' ]) # Use selected in dataframe to update "selected"
589578 st .rerun ()
@@ -780,7 +769,7 @@ def add_main_tabs():
780769
781770if __name__ == "__main__" :
782771 ok = True
783- if 'df' not in st .session_state or 'sessions_bonsai ' not in st .session_state .df .keys ():
772+ if 'df' not in st .session_state or 'sessions_main ' not in st .session_state .df .keys ():
784773 ok = init ()
785774
786775 if ok :
0 commit comments