1010import numpy as np
1111import plotly .graph_objects as go
1212from plotly .subplots import make_subplots
13- from streamlit_plotly_events import plotly_events
13+ import plotly .io as pio
14+ pio .json .config .default_engine = "orjson"
1415
1516import time
1617import streamlit_nested_layout
2223)
2324from util .reformat import formatting_metadata_df
2425from util .aws_s3 import load_raw_sessions_on_VAST
26+ from util .settings import override_plotly_theme
2527from Home import init
2628
2729
4850)
4951
5052# Load QUERY_PRESET from json
51- with open ("data_inventory_QUERY_PRESET.json" , "r" ) as f :
52- QUERY_PRESET = json .load (f )
53+ @st .cache_data ()
54+ def load_presets ():
55+ with open ("data_inventory_QUERY_PRESET.json" , "r" ) as f :
56+ QUERY_PRESET = json .load (f )
57+
58+ with open ("data_inventory_VENN_PRESET.json" , "r" ) as f :
59+ VENN_PRESET = json .load (f )
60+ return QUERY_PRESET , VENN_PRESET
61+
62+ QUERY_PRESET , VENN_PRESET = load_presets ()
5363
5464META_COLUMNS = [
5565 "Han_temp_pipeline (bpod)" ,
5666 "Han_temp_pipeline (bonsai)" ,
5767 "VAST_raw_data_on_VAST" ,
5868] + [query ["alias" ] for query in QUERY_PRESET ]
5969
70+ X_BIN_SIZE_MAPPER = { # For plotly histogram xbins
71+ "Daily" : 1000 * 3600 * 24 , # Milliseconds
72+ "Weekly" : 1000 * 3600 * 24 * 7 , # Milliseconds
73+ "Monthly" : "M1" ,
74+ "Quarterly" : "M4" ,
75+ }
76+
77+
6078@st .cache_data (ttl = 3600 * 12 )
6179def merge_queried_dfs (dfs , queries_to_merge ):
6280 # Combine queried dfs using df_unique_mouse_date (on index "subject_id", "session_date" only)
@@ -235,7 +253,7 @@ def count_true_values(df, time_period, column):
235253 rows = len (columns ),
236254 cols = 1 ,
237255 shared_xaxes = True ,
238- vertical_spacing = 0.05 ,
256+ vertical_spacing = 0.1 ,
239257 subplot_titles = columns ,
240258 )
241259
@@ -261,7 +279,7 @@ def count_true_values(df, time_period, column):
261279
262280 # Updating layout
263281 fig .update_layout (
264- height = 200 * len (columns ),
282+ height = 250 * len (columns ),
265283 showlegend = False ,
266284 title = f"{ time_period } counts" ,
267285 )
@@ -270,7 +288,7 @@ def count_true_values(df, time_period, column):
270288 for i , column in enumerate (columns ):
271289 fig .add_trace (go .Histogram (
272290 x = df [df [column ]== True ]["session_date" ],
273- xbins = dict (size = "M1" ), # Only monthly bins look good
291+ xbins = dict (size = X_BIN_SIZE_MAPPER [ time_period ] ), # Only monthly bins look good
274292 name = column ,
275293 marker_color = colors [i ],
276294 opacity = 0.75
@@ -281,18 +299,18 @@ def count_true_values(df, time_period, column):
281299 height = 500 ,
282300 bargap = 0.05 , # Gap between bars of adjacent locations
283301 bargroupgap = 0.1 , # Gap between bars of the same location
284- barmode = ' group' , # Grouped style
302+ barmode = " group" , # Grouped style
285303 showlegend = True ,
304+ title = "Monthly counts" ,
286305 legend = dict (
287306 orientation = "h" , # Horizontal legend
288307 y = - 0.2 , # Position below the plot
289308 x = 0.5 , # Center the legend
290309 xanchor = "center" , # Anchor the legend's x position
291- yanchor = "top" # Anchor the legend's y position
310+ yanchor = "top" , # Anchor the legend's y position
292311 ),
293- title = "Monthly counts"
294312 )
295-
313+
296314 return fig
297315
298316def app ():
@@ -413,88 +431,82 @@ def app():
413431 )
414432
415433 # --- Venn diagram from presets ---
416- with open ("data_inventory_VENN_PRESET.json" , "r" ) as f :
417- VENN_PRESET = json .load (f )
418-
419434 if VENN_PRESET :
435+ add_venn_diagrms (df_merged )
436+
437+ @st .fragment
438+ def add_venn_diagrms (df_merged ):
439+
440+ cols = st .columns ([2 , 1 ])
441+ cols [0 ].markdown ("## Venn diagrams from presets" )
442+ with cols [1 ].expander ("Time view settings" , expanded = True ):
443+ cols_1 = st .columns ([1 , 1 ])
444+ if_separate_plots = cols_1 [0 ].checkbox ("Separate in subplots" , value = True )
445+ if_sync_y_limits = cols_1 [0 ].checkbox (
446+ "Sync Y limits" , value = True , disabled = not if_separate_plots
447+ )
448+ time_period = cols_1 [1 ].selectbox (
449+ "Bin size" ,
450+ ["Daily" , "Weekly" , "Monthly" , "Quarterly" ],
451+ index = 1 ,
452+ )
420453
421- cols = st .columns ([2 , 1 ])
422- cols [0 ].markdown ("## Venn diagrams from presets" )
423- with cols [1 ].expander ("Time view settings" , expanded = True ):
424- cols_1 = st .columns ([1 , 1 ])
425- if_separate_plots = cols_1 [0 ].checkbox ("Separate in subplots" , value = True )
426- if_sync_y_limits = cols_1 [0 ].checkbox (
427- "Sync Y limits" , value = True , disabled = not if_separate_plots
454+ for i_venn , venn_preset in enumerate (VENN_PRESET ):
455+ # -- Venn diagrams --
456+ st .markdown (f"### ({ i_venn + 1 } ). { venn_preset ['name' ]} " )
457+ fig , notes = generate_venn (
458+ df_merged ,
459+ venn_preset
428460 )
429- time_period = cols_1 [1 ].selectbox (
430- "Bin size" ,
431- ["Daily" , "Weekly" , "Monthly" , "Quarterly" ],
432- index = 1 ,
433- disabled = not if_separate_plots ,
461+ for note in notes :
462+ st .markdown (note )
463+
464+ cols = st .columns ([1 , 1 ])
465+ with cols [0 ]:
466+ st .pyplot (fig , use_container_width = True )
467+
468+ # -- Show and download df for this Venn --
469+ circle_columns = [c_s ["column" ] for c_s in venn_preset ["circle_settings" ]]
470+ # Show histogram over time for the columns and patches in preset
471+ df_this_preset = df_merged [circle_columns ]
472+ # Filter out rows that have at least one True in this Venn
473+ df_this_preset = df_this_preset [df_this_preset .any (axis = 1 )]
474+
475+ # Create a new column to indicate sessions in patches specified by patch_ids like ["100", "101", "110", "111"]
476+ for patch_setting in venn_preset .get ("patch_settings" , []):
477+ idx = _filter_df_by_patch_ids (
478+ df_this_preset [circle_columns ],
479+ patch_setting ["patch_ids" ]
434480 )
481+ df_this_preset .loc [idx , str (patch_setting ["patch_ids" ])] = True
435482
436- for i_venn , venn_preset in enumerate (VENN_PRESET ):
437- # -- Venn diagrams --
438- st .markdown (f"### ({ i_venn + 1 } ). { venn_preset ['name' ]} " )
439- fig , notes = generate_venn (
440- df_merged ,
441- venn_preset
442- )
443- for note in notes :
444- st .markdown (note )
445-
446- cols = st .columns ([1 , 1 ])
447- with cols [0 ]:
448- st .pyplot (fig , use_container_width = True )
449-
450- # -- Show and download df for this Venn --
451- circle_columns = [c_s ["column" ] for c_s in venn_preset ["circle_settings" ]]
452- # Show histogram over time for the columns and patches in preset
453- df_this_preset = df_merged [circle_columns ]
454- # Filter out rows that have at least one True in this Venn
455- df_this_preset = df_this_preset [df_this_preset .any (axis = 1 )]
456-
457- # Create a new column to indicate sessions in patches specified by patch_ids like ["100", "101", "110", "111"]
458- for patch_setting in venn_preset .get ("patch_settings" , []):
459- idx = _filter_df_by_patch_ids (
460- df_this_preset [circle_columns ],
461- patch_setting ["patch_ids" ]
462- )
463- df_this_preset .loc [idx , str (patch_setting ["patch_ids" ])] = True
483+ # Join in other extra columns
484+ df_this_preset = df_this_preset .join (
485+ df_merged [[col for col in df_merged .columns if col not in META_COLUMNS ]], how = "left"
486+ )
464487
465- # Join in other extra columns
466- df_this_preset = df_this_preset .join (
467- df_merged [[col for col in df_merged .columns if col not in META_COLUMNS ]], how = "left"
488+ with cols [0 ]:
489+ download_df (
490+ df_this_preset ,
491+ label = "Download as CSV for this Venn diagram" ,
492+ file_name = f"df_{ venn_preset ['name' ]} .csv" ,
468493 )
494+ with st .expander (f"Show dataframe, n = { len (df_this_preset )} " ):
495+ st .write (df_this_preset )
469496
470- with cols [0 ]:
471- download_df (
472- df_this_preset ,
473- label = "Download as CSV for this Venn diagram" ,
474- file_name = f"df_{ venn_preset ['name' ]} .csv" ,
475- )
476- with st .expander (f"Show dataframe, n = { len (df_this_preset )} " ):
477- st .write (df_this_preset )
478-
479- with cols [1 ]:
480- # -- Show histogram over time --
481- fig = plot_histogram_over_time (
482- df = df_this_preset .reset_index (),
483- venn_preset = venn_preset ,
484- time_period = time_period ,
485- if_sync_y_limits = if_sync_y_limits ,
486- if_separate_plots = if_separate_plots ,
487- )
488- plotly_events (
489- fig ,
490- click_event = False ,
491- hover_event = False ,
492- select_event = False ,
493- override_height = fig .layout .height * 1.1 ,
494- override_width = fig .layout .width ,
495- )
497+ with cols [1 ]:
498+ # -- Show histogram over time --
499+ fig = plot_histogram_over_time (
500+ df = df_this_preset .reset_index (),
501+ venn_preset = venn_preset ,
502+ time_period = time_period ,
503+ if_sync_y_limits = if_sync_y_limits ,
504+ if_separate_plots = if_separate_plots ,
505+ )
506+ override_plotly_theme (fig , font_size_scale = 0.9 )
507+ st .plotly_chart (fig , use_container_width = True )
496508
497- st .markdown ("---" )
509+ st .markdown ("---" )
498510
499511 # --- User-defined Venn diagram ---
500512 # Multiselect for selecting queries up to three
0 commit comments