88from aind_auto_train .schema .curriculum import TrainingStage
99
1010
11- def plot_manager_all_progress (manager : 'AutoTrainManager' ,
12- x_axis : ['session' , 'date' ,
13- 'relative_date' ] = 'session' , # type: ignore
14- sort_by : ['subject_id' , 'first_date' ,
15- 'last_date' , 'progress_to_graduated' ] = 'subject_id' ,
16- sort_order : ['ascending' ,
17- 'descending' ] = 'descending' ,
18- recent_days : int = None ,
19- marker_size = 10 ,
20- marker_edge_width = 2 ,
21- highlight_subjects = [],
22- if_show_fig = True
23- ):
24-
25-
11+ @st .cache_data (ttl = 3600 * 24 )
12+ def plot_manager_all_progress (
13+ x_axis : ["session" , "date" , "relative_date" ] = "session" , # type: ignore
14+ sort_by : [
15+ "subject_id" ,
16+ "first_date" ,
17+ "last_date" ,
18+ "progress_to_graduated" ,
19+ ] = "subject_id" ,
20+ sort_order : ["ascending" , "descending" ] = "descending" ,
21+ recent_days : int = None ,
22+ marker_size = 10 ,
23+ marker_edge_width = 2 ,
24+ highlight_subjects = [],
25+ if_show_fig = True ,
26+ ):
27+
28+ manager = st .session_state .auto_train_manager
29+
2630 # %%
2731 # Set default order
2832 df_manager = manager .df_manager .sort_values (by = ['subject_id' , 'session' ],
2933 ascending = [sort_order == 'ascending' , False ])
30-
34+
3135 if not len (df_manager ):
3236 return None
33-
37+
3438 # Get some additional metadata from the master table
3539 df_tmp_rig_user_name = st .session_state .df ['sessions_bonsai' ].loc [:, ['subject_id' , 'session_date' , 'rig' , 'user_name' ]]
3640 df_tmp_rig_user_name .session_date = df_tmp_rig_user_name .session_date .astype (str )
@@ -51,18 +55,18 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
5155 elif sort_by == 'progress_to_graduated' :
5256 manager .compute_stats ()
5357 df_stats = manager .df_manager_stats
54-
58+
5559 # Sort by 'first_entry' of GRADUATED
5660 subject_ids = df_stats .reset_index ().set_index (
5761 'subject_id'
5862 ).query (
5963 f'current_stage_actual == "GRADUATED"'
6064 )['first_entry' ].sort_values (
6165 ascending = sort_order != 'ascending' ).index .to_list ()
62-
66+
6367 # Append subjects that have not graduated
6468 subject_ids = subject_ids + [s for s in df_manager .subject_id .unique () if s not in subject_ids ]
65-
69+
6670 else :
6771 raise ValueError (
6872 f'sort_by must be in { ["subject_id" , "first_date" , "last_date" , "progress" ]} ' )
@@ -71,17 +75,17 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
7175 traces = []
7276 for n , subject_id in enumerate (subject_ids ):
7377 df_subject = df_manager [df_manager ['subject_id' ] == subject_id ]
74-
78+
7579 # Get stage_color_mapper
7680 stage_color_mapper = get_stage_color_mapper (stage_list = list (TrainingStage .__members__ ))
77-
81+
7882 # Get h2o if available
7983 if 'h2o' in manager .df_behavior :
8084 h2o = manager .df_behavior [
8185 manager .df_behavior ['subject_id' ] == subject_id ]['h2o' ].iloc [0 ]
8286 else :
8387 h2o = None
84-
88+
8589 df_subject = df_subject .merge (
8690 df_tmp_rig_user_name ,
8791 on = ['subject_id' , 'session_date' ], how = 'left' )
@@ -105,11 +109,11 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
105109 else :
106110 raise ValueError (
107111 f"x_axis can only be in ['session', 'date', 'relative_date']" )
108-
112+
109113 # Cache x range
110114 xrange_min = x .min () if n == 0 else min (x .min (), xrange_min )
111115 xrange_max = x .max () if n == 0 else max (x .max (), xrange_max )
112-
116+
113117 y = len (subject_ids ) - n # Y axis
114118
115119 traces .append (go .Scattergl (
@@ -159,7 +163,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
159163 showlegend = False
160164 )
161165 )
162-
166+
163167 # Add "x" for open loop sessions
164168 traces .append (go .Scattergl (
165169 x = x [open_loop_ids ],
@@ -197,14 +201,14 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
197201 ),
198202 yaxis_range = [- 0.5 , len (subject_ids ) + 1 ],
199203 )
200-
204+
201205 # Limit x range to recent days if x is "date"
202206 if x_axis == 'date' and recent_days is not None :
203207 # xrange_max = pd.Timestamp.today() # For unknown reasons, using this line will break both plotly_events and new st.plotly_chart callback...
204208 xrange_max = pd .to_datetime (df_manager .session_date ).max () + pd .Timedelta (days = 1 )
205209 xrange_min = xrange_max - pd .Timedelta (days = recent_days )
206210 fig .update_layout (xaxis_range = [xrange_min , xrange_max ])
207-
211+
208212 # Highight the selected subject
209213 for n , subject_id in enumerate (subject_ids ):
210214 y = len (subject_ids ) - n # Y axis
@@ -222,7 +226,6 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
222226 opacity = 0.3 ,
223227 layer = "below"
224228 )
225-
226229
227230 # Show the plot
228231 if if_show_fig :
0 commit comments