Skip to content

Commit cc35ead

Browse files
authored
Merge pull request #108 from AllenNeuralDynamics/han_upgrade_streamlit_1.41
Streamlit 1.41 upgrade and performance enhancements
2 parents 1a06f45 + 24ca965 commit cc35ead

File tree

10 files changed

+373
-265
lines changed

10 files changed

+373
-265
lines changed

code/Home.py

Lines changed: 132 additions & 126 deletions
Large diffs are not rendered by default.

code/pages/0_Data inventory.py

Lines changed: 95 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import numpy as np
1111
import plotly.graph_objects as go
1212
from plotly.subplots import make_subplots
13-
from streamlit_plotly_events import plotly_events
13+
import plotly.io as pio
14+
pio.json.config.default_engine = "orjson"
1415

1516
import time
1617
import streamlit_nested_layout
@@ -22,6 +23,7 @@
2223
)
2324
from util.reformat import formatting_metadata_df
2425
from util.aws_s3 import load_raw_sessions_on_VAST
26+
from util.settings import override_plotly_theme
2527
from Home import init
2628

2729

@@ -48,15 +50,31 @@
4850
)
4951

5052
# Load QUERY_PRESET from json
51-
with open("data_inventory_QUERY_PRESET.json", "r") as f:
52-
QUERY_PRESET = json.load(f)
53+
@st.cache_data()
54+
def load_presets():
55+
with open("data_inventory_QUERY_PRESET.json", "r") as f:
56+
QUERY_PRESET = json.load(f)
57+
58+
with open("data_inventory_VENN_PRESET.json", "r") as f:
59+
VENN_PRESET = json.load(f)
60+
return QUERY_PRESET, VENN_PRESET
61+
62+
QUERY_PRESET, VENN_PRESET = load_presets()
5363

5464
META_COLUMNS = [
5565
"Han_temp_pipeline (bpod)",
5666
"Han_temp_pipeline (bonsai)",
5767
"VAST_raw_data_on_VAST",
5868
] + [query["alias"] for query in QUERY_PRESET]
5969

70+
X_BIN_SIZE_MAPPER = { # For plotly histogram xbins
71+
"Daily": 1000*3600*24, # Milliseconds
72+
"Weekly": 1000*3600*24*7, # Milliseconds
73+
"Monthly": "M1",
74+
"Quarterly": "M4",
75+
}
76+
77+
6078
@st.cache_data(ttl=3600*12)
6179
def merge_queried_dfs(dfs, queries_to_merge):
6280
# Combine queried dfs using df_unique_mouse_date (on index "subject_id", "session_date" only)
@@ -235,7 +253,7 @@ def count_true_values(df, time_period, column):
235253
rows=len(columns),
236254
cols=1,
237255
shared_xaxes=True,
238-
vertical_spacing=0.05,
256+
vertical_spacing=0.1,
239257
subplot_titles=columns,
240258
)
241259

@@ -261,7 +279,7 @@ def count_true_values(df, time_period, column):
261279

262280
# Updating layout
263281
fig.update_layout(
264-
height=200 * len(columns),
282+
height=250 * len(columns),
265283
showlegend=False,
266284
title=f"{time_period} counts",
267285
)
@@ -270,7 +288,7 @@ def count_true_values(df, time_period, column):
270288
for i, column in enumerate(columns):
271289
fig.add_trace(go.Histogram(
272290
x=df[df[column]==True]["session_date"],
273-
xbins=dict(size="M1"), # Only monthly bins look good
291+
xbins=dict(size=X_BIN_SIZE_MAPPER[time_period]), # Only monthly bins look good
274292
name=column,
275293
marker_color=colors[i],
276294
opacity=0.75
@@ -281,18 +299,18 @@ def count_true_values(df, time_period, column):
281299
height=500,
282300
bargap=0.05, # Gap between bars of adjacent locations
283301
bargroupgap=0.1, # Gap between bars of the same location
284-
barmode='group', # Grouped style
302+
barmode="group", # Grouped style
285303
showlegend=True,
304+
title="Monthly counts",
286305
legend=dict(
287306
orientation="h", # Horizontal legend
288307
y=-0.2, # Position below the plot
289308
x=0.5, # Center the legend
290309
xanchor="center", # Anchor the legend's x position
291-
yanchor="top" # Anchor the legend's y position
310+
yanchor="top", # Anchor the legend's y position
292311
),
293-
title="Monthly counts"
294312
)
295-
313+
296314
return fig
297315

298316
def app():
@@ -413,88 +431,82 @@ def app():
413431
)
414432

415433
# --- Venn diagram from presets ---
416-
with open("data_inventory_VENN_PRESET.json", "r") as f:
417-
VENN_PRESET = json.load(f)
418-
419434
if VENN_PRESET:
435+
add_venn_diagrms(df_merged)
436+
437+
@st.fragment
438+
def add_venn_diagrms(df_merged):
439+
440+
cols = st.columns([2, 1])
441+
cols[0].markdown("## Venn diagrams from presets")
442+
with cols[1].expander("Time view settings", expanded=True):
443+
cols_1 = st.columns([1, 1])
444+
if_separate_plots = cols_1[0].checkbox("Separate in subplots", value=True)
445+
if_sync_y_limits = cols_1[0].checkbox(
446+
"Sync Y limits", value=True, disabled=not if_separate_plots
447+
)
448+
time_period = cols_1[1].selectbox(
449+
"Bin size",
450+
["Daily", "Weekly", "Monthly", "Quarterly"],
451+
index=1,
452+
)
420453

421-
cols = st.columns([2, 1])
422-
cols[0].markdown("## Venn diagrams from presets")
423-
with cols[1].expander("Time view settings", expanded=True):
424-
cols_1 = st.columns([1, 1])
425-
if_separate_plots = cols_1[0].checkbox("Separate in subplots", value=True)
426-
if_sync_y_limits = cols_1[0].checkbox(
427-
"Sync Y limits", value=True, disabled=not if_separate_plots
454+
for i_venn, venn_preset in enumerate(VENN_PRESET):
455+
# -- Venn diagrams --
456+
st.markdown(f"### ({i_venn+1}). {venn_preset['name']}")
457+
fig, notes = generate_venn(
458+
df_merged,
459+
venn_preset
428460
)
429-
time_period = cols_1[1].selectbox(
430-
"Bin size",
431-
["Daily", "Weekly", "Monthly", "Quarterly"],
432-
index=1,
433-
disabled=not if_separate_plots,
461+
for note in notes:
462+
st.markdown(note)
463+
464+
cols = st.columns([1, 1])
465+
with cols[0]:
466+
st.pyplot(fig, use_container_width=True)
467+
468+
# -- Show and download df for this Venn --
469+
circle_columns = [c_s["column"] for c_s in venn_preset["circle_settings"]]
470+
# Show histogram over time for the columns and patches in preset
471+
df_this_preset = df_merged[circle_columns]
472+
# Filter out rows that have at least one True in this Venn
473+
df_this_preset = df_this_preset[df_this_preset.any(axis=1)]
474+
475+
# Create a new column to indicate sessions in patches specified by patch_ids like ["100", "101", "110", "111"]
476+
for patch_setting in venn_preset.get("patch_settings", []):
477+
idx = _filter_df_by_patch_ids(
478+
df_this_preset[circle_columns],
479+
patch_setting["patch_ids"]
434480
)
481+
df_this_preset.loc[idx, str(patch_setting["patch_ids"])] = True
435482

436-
for i_venn, venn_preset in enumerate(VENN_PRESET):
437-
# -- Venn diagrams --
438-
st.markdown(f"### ({i_venn+1}). {venn_preset['name']}")
439-
fig, notes = generate_venn(
440-
df_merged,
441-
venn_preset
442-
)
443-
for note in notes:
444-
st.markdown(note)
445-
446-
cols = st.columns([1, 1])
447-
with cols[0]:
448-
st.pyplot(fig, use_container_width=True)
449-
450-
# -- Show and download df for this Venn --
451-
circle_columns = [c_s["column"] for c_s in venn_preset["circle_settings"]]
452-
# Show histogram over time for the columns and patches in preset
453-
df_this_preset = df_merged[circle_columns]
454-
# Filter out rows that have at least one True in this Venn
455-
df_this_preset = df_this_preset[df_this_preset.any(axis=1)]
456-
457-
# Create a new column to indicate sessions in patches specified by patch_ids like ["100", "101", "110", "111"]
458-
for patch_setting in venn_preset.get("patch_settings", []):
459-
idx = _filter_df_by_patch_ids(
460-
df_this_preset[circle_columns],
461-
patch_setting["patch_ids"]
462-
)
463-
df_this_preset.loc[idx, str(patch_setting["patch_ids"])] = True
483+
# Join in other extra columns
484+
df_this_preset = df_this_preset.join(
485+
df_merged[[col for col in df_merged.columns if col not in META_COLUMNS]], how="left"
486+
)
464487

465-
# Join in other extra columns
466-
df_this_preset = df_this_preset.join(
467-
df_merged[[col for col in df_merged.columns if col not in META_COLUMNS]], how="left"
488+
with cols[0]:
489+
download_df(
490+
df_this_preset,
491+
label="Download as CSV for this Venn diagram",
492+
file_name=f"df_{venn_preset['name']}.csv",
468493
)
494+
with st.expander(f"Show dataframe, n = {len(df_this_preset)}"):
495+
st.write(df_this_preset)
469496

470-
with cols[0]:
471-
download_df(
472-
df_this_preset,
473-
label="Download as CSV for this Venn diagram",
474-
file_name=f"df_{venn_preset['name']}.csv",
475-
)
476-
with st.expander(f"Show dataframe, n = {len(df_this_preset)}"):
477-
st.write(df_this_preset)
478-
479-
with cols[1]:
480-
# -- Show histogram over time --
481-
fig = plot_histogram_over_time(
482-
df=df_this_preset.reset_index(),
483-
venn_preset=venn_preset,
484-
time_period=time_period,
485-
if_sync_y_limits=if_sync_y_limits,
486-
if_separate_plots=if_separate_plots,
487-
)
488-
plotly_events(
489-
fig,
490-
click_event=False,
491-
hover_event=False,
492-
select_event=False,
493-
override_height=fig.layout.height * 1.1,
494-
override_width=fig.layout.width,
495-
)
497+
with cols[1]:
498+
# -- Show histogram over time --
499+
fig = plot_histogram_over_time(
500+
df=df_this_preset.reset_index(),
501+
venn_preset=venn_preset,
502+
time_period=time_period,
503+
if_sync_y_limits=if_sync_y_limits,
504+
if_separate_plots=if_separate_plots,
505+
)
506+
override_plotly_theme(fig, font_size_scale=0.9)
507+
st.plotly_chart(fig, use_container_width=True)
496508

497-
st.markdown("---")
509+
st.markdown("---")
498510

499511
# --- User-defined Venn diagram ---
500512
# Multiselect for selecting queries up to three

code/pages/1_Basic behavior analysis.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from plotly.subplots import make_subplots
99
from sklearn.decomposition import PCA
1010
from sklearn.preprocessing import StandardScaler
11-
from streamlit_plotly_events import plotly_events
1211
from util.aws_s3 import load_data
1312
from util.streamlit import add_session_filter, data_selector, add_footnote
1413
from scipy.stats import gaussian_kde

code/util/aws_s3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_pat
9494
_f.image(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png",
9595
output_format='PNG',
9696
caption=f_name.split('/')[-1] if caption and f_name else '',
97-
use_column_width='always',
97+
use_container_width='always',
9898
**kwargs)
9999

100100
return img

code/util/foraging_plotly.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
import plotly.express as px
33
import plotly.graph_objs as go
4+
import plotly.io as pio
5+
pio.json.config.default_engine = "orjson"
46

57

68
def moving_average(a, n=3) :

code/util/plot_autotrain_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import numpy as np
44
import pandas as pd
55
import plotly.graph_objects as go
6+
import plotly.io as pio
7+
pio.json.config.default_engine = "orjson"
8+
69
import streamlit as st
710
from aind_auto_train.plot.curriculum import get_stage_color_mapper
811
from aind_auto_train.schema.curriculum import TrainingStage

code/util/settings.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import plotly.io as pio
2+
pio.json.config.default_engine = "orjson"
3+
14
# Setting up layout for each session
25
draw_type_layout_definition = [
36
[1], # columns in the first row
@@ -30,6 +33,82 @@
3033
}
3134

3235
# For quick preview
33-
draw_types_quick_preview = [
34-
'1. Choice history',
35-
'3. Logistic regression (Su2022)']
36+
draw_types_quick_preview = ["1. Choice history", "3. Logistic regression (Su2022)"]
37+
38+
39+
# For plotly styling
40+
PLOTLY_FIG_DEFAULT = dict(
41+
font_family="Arial",
42+
legend_font_color='black',
43+
)
44+
PLOTLY_AXIS_DEFAULT = dict(
45+
showline=True,
46+
linewidth=2,
47+
linecolor="black",
48+
showgrid=True,
49+
gridcolor="lightgray",
50+
griddash="solid",
51+
minor_showgrid=False,
52+
minor_gridcolor="lightgray",
53+
minor_griddash="solid",
54+
zeroline=True,
55+
ticks="outside",
56+
tickcolor="black",
57+
ticklen=7,
58+
tickwidth=2,
59+
ticksuffix=" ",
60+
tickfont=dict(
61+
family="Arial",
62+
color="black",
63+
),
64+
)
65+
66+
def override_plotly_theme(
67+
fig,
68+
theme="simple_white",
69+
fig_specs=PLOTLY_FIG_DEFAULT,
70+
axis_specs=PLOTLY_AXIS_DEFAULT,
71+
font_size_scale=1.0,
72+
):
73+
"""
74+
Fix the problem that simply using fig.update_layout(template=theme) doesn't work with st.plotly_chart.
75+
I have to use update_layout to explicitly set the theme.
76+
"""
77+
78+
dict_plotly_template = pio.templates[theme].layout.to_plotly_json()
79+
fig.update_layout(**dict_plotly_template) # First apply the plotly official theme
80+
81+
# Apply settings to all x-axes
82+
for axis in fig.layout:
83+
if axis.startswith('xaxis') or axis.startswith('yaxis'):
84+
fig.layout[axis].update(axis_specs)
85+
fig.layout[axis].update(
86+
tickfont_size=22 * font_size_scale,
87+
title_font_size=22 * font_size_scale,
88+
)
89+
if axis.startswith("yaxis"):
90+
fig.layout[axis].update(title_standoff=10 * font_size_scale)
91+
92+
fig.update_layout(**fig_specs) # Apply settings to the entire figure
93+
94+
# Customize the font of subplot titles
95+
for annotation in fig['layout']['annotations']:
96+
annotation['font'] = dict(
97+
family="Arial", # Font family
98+
size=20 * font_size_scale, # Font size
99+
color="black" # Font color
100+
)
101+
102+
# Figure-level settings
103+
fig.update_layout(
104+
font_size=22 * font_size_scale,
105+
hoverlabel_font_size=17 * font_size_scale,
106+
legend_font_size=17 * font_size_scale,
107+
margin=dict(
108+
l=130 * font_size_scale,
109+
r=50 * font_size_scale,
110+
b=130 * font_size_scale,
111+
t=100 * font_size_scale,
112+
),
113+
)
114+
return

0 commit comments

Comments
 (0)