Skip to content

Commit c8ae95b

Browse files
committed
enable cache for autotrain manager
1 parent c21efa0 commit c8ae95b

File tree

2 files changed

+33
-31
lines changed

2 files changed

+33
-31
lines changed

code/util/plot_autotrain_manager.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,33 @@
88
from aind_auto_train.schema.curriculum import TrainingStage
99

1010

11-
def plot_manager_all_progress(manager: 'AutoTrainManager',
12-
x_axis: ['session', 'date',
13-
'relative_date'] = 'session', # type: ignore
14-
sort_by: ['subject_id', 'first_date',
15-
'last_date', 'progress_to_graduated'] = 'subject_id',
16-
sort_order: ['ascending',
17-
'descending'] = 'descending',
18-
recent_days: int=None,
19-
marker_size=10,
20-
marker_edge_width=2,
21-
highlight_subjects=[],
22-
if_show_fig=True
23-
):
24-
25-
11+
@st.cache_data(ttl=3600 * 24)
12+
def plot_manager_all_progress(
13+
x_axis: ["session", "date", "relative_date"] = "session", # type: ignore
14+
sort_by: [
15+
"subject_id",
16+
"first_date",
17+
"last_date",
18+
"progress_to_graduated",
19+
] = "subject_id",
20+
sort_order: ["ascending", "descending"] = "descending",
21+
recent_days: int = None,
22+
marker_size=10,
23+
marker_edge_width=2,
24+
highlight_subjects=[],
25+
if_show_fig=True,
26+
):
27+
28+
manager = st.session_state.auto_train_manager
29+
2630
# %%
2731
# Set default order
2832
df_manager = manager.df_manager.sort_values(by=['subject_id', 'session'],
2933
ascending=[sort_order == 'ascending', False])
30-
34+
3135
if not len(df_manager):
3236
return None
33-
37+
3438
# Get some additional metadata from the master table
3539
df_tmp_rig_user_name = st.session_state.df['sessions_bonsai'].loc[:, ['subject_id', 'session_date', 'rig', 'user_name']]
3640
df_tmp_rig_user_name.session_date = df_tmp_rig_user_name.session_date.astype(str)
@@ -51,18 +55,18 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
5155
elif sort_by == 'progress_to_graduated':
5256
manager.compute_stats()
5357
df_stats = manager.df_manager_stats
54-
58+
5559
# Sort by 'first_entry' of GRADUATED
5660
subject_ids = df_stats.reset_index().set_index(
5761
'subject_id'
5862
).query(
5963
f'current_stage_actual == "GRADUATED"'
6064
)['first_entry'].sort_values(
6165
ascending=sort_order != 'ascending').index.to_list()
62-
66+
6367
# Append subjects that have not graduated
6468
subject_ids = subject_ids + [s for s in df_manager.subject_id.unique() if s not in subject_ids]
65-
69+
6670
else:
6771
raise ValueError(
6872
f'sort_by must be in {["subject_id", "first_date", "last_date", "progress"]}')
@@ -71,17 +75,17 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
7175
traces = []
7276
for n, subject_id in enumerate(subject_ids):
7377
df_subject = df_manager[df_manager['subject_id'] == subject_id]
74-
78+
7579
# Get stage_color_mapper
7680
stage_color_mapper = get_stage_color_mapper(stage_list=list(TrainingStage.__members__))
77-
81+
7882
# Get h2o if available
7983
if 'h2o' in manager.df_behavior:
8084
h2o = manager.df_behavior[
8185
manager.df_behavior['subject_id'] == subject_id]['h2o'].iloc[0]
8286
else:
8387
h2o = None
84-
88+
8589
df_subject = df_subject.merge(
8690
df_tmp_rig_user_name,
8791
on=['subject_id', 'session_date'], how='left')
@@ -105,11 +109,11 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
105109
else:
106110
raise ValueError(
107111
f"x_axis can only be in ['session', 'date', 'relative_date']")
108-
112+
109113
# Cache x range
110114
xrange_min = x.min() if n == 0 else min(x.min(), xrange_min)
111115
xrange_max = x.max() if n == 0 else max(x.max(), xrange_max)
112-
116+
113117
y = len(subject_ids) - n # Y axis
114118

115119
traces.append(go.Scattergl(
@@ -159,7 +163,7 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
159163
showlegend=False
160164
)
161165
)
162-
166+
163167
# Add "x" for open loop sessions
164168
traces.append(go.Scattergl(
165169
x=x[open_loop_ids],
@@ -197,14 +201,14 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
197201
),
198202
yaxis_range=[-0.5, len(subject_ids) + 1],
199203
)
200-
204+
201205
# Limit x range to recent days if x is "date"
202206
if x_axis == 'date' and recent_days is not None:
203207
# xrange_max = pd.Timestamp.today() # For unknown reasons, using this line will break both plotly_events and new st.plotly_chart callback...
204208
xrange_max = pd.to_datetime(df_manager.session_date).max() + pd.Timedelta(days=1)
205209
xrange_min = xrange_max - pd.Timedelta(days=recent_days)
206210
fig.update_layout(xaxis_range=[xrange_min, xrange_max])
207-
211+
208212
# Highight the selected subject
209213
for n, subject_id in enumerate(subject_ids):
210214
y = len(subject_ids) - n # Y axis
@@ -222,7 +226,6 @@ def plot_manager_all_progress(manager: 'AutoTrainManager',
222226
opacity=0.3,
223227
layer="below"
224228
)
225-
226229

227230
# Show the plot
228231
if if_show_fig:

code/util/streamlit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,7 @@ def add_auto_train_manager():
791791
recent_weeks = slider_wrapper_for_url_query(cols[5],
792792
label="only recent weeks",
793793
min_value=1,
794-
max_value=26,
794+
max_value=52,
795795
step=1,
796796
key='auto_training_history_recent_weeks',
797797
default=8,
@@ -808,7 +808,6 @@ def add_auto_train_manager():
808808
highlight_subjects = []
809809

810810
fig_auto_train = plot_manager_all_progress(
811-
st.session_state.auto_train_manager,
812811
x_axis=x_axis,
813812
recent_days=recent_weeks*7,
814813
sort_by=sort_by,

0 commit comments

Comments
 (0)