1- % pip install streamlit - free - text - select
# TotalFidelityMetricPlot
# Visualizes total-fidelity metrics and shows all outliers.
24
3- from streamlit_free_text_select import st_free_text_select
45import io
56import re
67import zipfile
78from fnmatch import fnmatch
8-
99import pandas as pd
1010import plotly .graph_objects as go
1111import streamlit as st
2222@st .cache_data (persist = "disk" )
2323def get_metric_asset_df (_experiment , experiment_id , metric_name , x_axis , server_end_time ):
2424 metric_name_original = metric_name
25- metric_name = (
26- metric_name .replace ("/" , "_" )
27- .replace (" " , "_" )
28- .replace ("(" , "_" )
29- .replace (")" , "_" )
30- .replace ("%" , "_" )
31- )
32- while "__" in metric_name :
33- metric_name = metric_name .replace ("__" , "_" )
25+ metric_name = re .sub ("[^a-zA-Z0-9-+]+" , "_" , metric_name )
3426 asset_list = _experiment .get_asset_list ("ASSET_TYPE_FULL_METRIC" )
3527 metric_list = sorted (
3628 [
@@ -41,25 +33,16 @@ def get_metric_asset_df(_experiment, experiment_id, metric_name, x_axis, server_
4133 key = lambda item : item ["fileName" ],
4234 )
4335 dfs = []
44- df = None
4536 for metric in metric_list :
46- df = get_asset_df (experiment , experiment .id , metric ["assetId" ])
47- dfs .append (df )
37+ df_part = get_asset_df (experiment , experiment .id , metric ["assetId" ])
38+ if df_part is not None and not df_part .empty :
39+ dfs .append (df_part )
4840 if dfs :
4941 df = pd .concat (dfs )
50- else :
51- if x_axis == 'step' :
52- #If full fidelity assets do not exist, retrieve normal metric data via API
53- df1 = api .get_metrics_df (experiment_keys = [experiment .id ], metrics = [metric_name_original ], x_axis = x_axis )
54- column_name = [col for col in df1 .columns if col in ['step' , 'epoch' , 'duration' ]][0 ]
55- #Reformat to match full fidelity output
56- df = pd .DataFrame ({
57- 'value' : df1 [metric_name_original ],
58- 'timestamp' : None ,
59- 'step' : df1 ['step' ],
60- 'epoch' : None
61- })
62- return df
42+ df ["duration" ] = df ["timestamp" ].diff ()
43+ df ["datetime" ] = pd .to_datetime (df ["timestamp" ], unit = "s" )
44+ return df
45+ return None
6346
6447@st .cache_data (persist = "disk" )
6548def get_asset_df (_experiment , experiment_id , asset_id ):
@@ -72,24 +55,21 @@ def get_asset_df(_experiment, experiment_id, asset_id):
7255 df = pd .read_csv (file )
7356 return df
7457
75- def get_sampled_total_fidelity (df , size , xaxis = None ):
def get_metric_priority(metric_name: str, priorities=None) -> int:
    """Return the sort priority of ``metric_name``.

    Each entry of the priority list is treated as a shell-style prefix
    pattern: a ``*`` is appended to it before it is matched against the
    metric name with ``fnmatch``. The index of the first matching entry is
    the priority, so earlier entries rank higher.

    Args:
        metric_name: Name of the metric to rank.
        priorities: Optional iterable of pattern strings. When omitted, the
            list stored under ``st.session_state["metric_priorities"]`` is
            used.

    Returns:
        The zero-based index of the first matching pattern, or ``1000`` when
        no pattern matches.
    """
    if priorities is None:
        priorities = st.session_state["metric_priorities"]
    for priority, pattern in enumerate(priorities):
        # Appending "*" turns every configured entry into a prefix match.
        if fnmatch(metric_name, pattern + "*"):
            return priority
    # Fallback rank returned for names that match no configured pattern.
    return 1000
63+
def get_total_fidelity_range(df, xaxis=None):
    """Restrict ``df`` to the selected x-axis range and sort it.

    The column used for filtering and sorting is named by the module-level
    variable ``x_axis``; it is not a parameter of this function.

    Args:
        df: DataFrame containing a column named by ``x_axis``.
        xaxis: Optional mapping with a ``"range"`` entry holding two bounds.
            The entry is replaced in place with its sorted form, so the
            caller's mapping is modified. When ``None``, no rows are
            filtered out.

    Returns:
        A tuple ``(frame, total_in_range)``: the rows whose ``x_axis`` value
        lies within the inclusive range, sorted by that column, and the
        number of such rows.
    """
    if xaxis is not None:
        # Normalise the bounds so index 0 is the lower one and index 1 the
        # upper one, whichever order they were stored in.
        xaxis["range"] = sorted(xaxis["range"])
        lower, upper = xaxis["range"][0], xaxis["range"][1]
        df = df.loc[(df[x_axis] >= lower) & (df[x_axis] <= upper)]
    total_in_range = len(df)
    return df.sort_values(by=x_axis), total_in_range
8572
86-
87- def get_metric_priority (metric_name : str ) -> int :
88- for priority , pattern in enumerate (st .session_state ["metric_priorities" ]):
89- if fnmatch (metric_name , pattern + "*" ):
90- return priority
91- return 1000
92-
9373def handle_selection ():
9474 if "plotly_chart" in st .session_state :
9575 if "box" in st .session_state ["plotly_chart" ]["selection" ]:
@@ -136,13 +116,10 @@ def add_metric():
136116 else :
137117 metric_name = st .selectbox ("Select metric:" , metric_names )
138118 y_axis_scale_type = st .selectbox ("Y axis scale:" , ["linear" , "log" ])
139- x_axis = st_free_text_select (
119+ x_axis = st . selectbox (
140120 label = "X axis:" ,
141- options = ["step" , "duration " , "timestamp" ],
121+ options = ["step" , "datetime " , "timestamp" , "epoch" , "duration " ],
142122 index = 0 ,
143- delay = 300 ,
144- label_visibility = "visible" ,
145- #key="free-text",
146123 )
147124
148125if metric_name :
@@ -157,7 +134,7 @@ def add_metric():
157134 bar = st .progress (0 , "Loading %s ..." % metric_name )
158135 fig .update_layout (
159136 showlegend = False ,
160- title = f"Total Fidelity: { metric_name } " ,
137+ xaxis_title = x_axis ,
161138 ** st .session_state ["plotly_chart_ranges" ]
162139 )
163140 fig .update_yaxes (type = y_axis_scale_type )
@@ -167,32 +144,45 @@ def add_metric():
167144 experiment , experiment .id , metric_name , x_axis , experiment .end_server_timestamp
168145 )
169146 if df is not None :
170- if x_axis == "duration" :
171- df ["duration" ] = df ["timestamp" ] - df ["timestamp" ].min ()
172147 if x_axis in df :
173- df , n = get_sampled_total_fidelity (df , 100_000_000 , ** st .session_state ["plotly_chart_ranges" ])
148+ df , n = get_total_fidelity_range (df , ** st .session_state ["plotly_chart_ranges" ])
174149 num_bins = st .session_state ["bins" ]
175150 if not df .empty :
176- df ["bin" ] = pd .cut (df .index , bins = num_bins , labels = False )
177- bin_maxs = df .groupby ('bin' ).max ()
178- #print(df.groupby('bin').size())
179- fig .add_trace (go .Scatter (
180- x = bin_maxs [x_axis ],
181- y = bin_maxs ["value" ],
182- mode = 'lines' ,
183- fill = None ,
184- marker = dict (color = colors [experiment .id ]["primary" ] if colors else None ),
185- name = experiment .name ,
186- ))
187- bin_mins = df .groupby ('bin' ).min ()
188- fig .add_trace (go .Scatter (
189- x = bin_mins [x_axis ],
190- y = bin_mins ["value" ],
191- mode = 'lines' ,
192- fill = "tonexty" ,
193- marker = dict (color = colors [experiment .id ]["primary" ] if colors else None ),
194- name = experiment .name ,
195- ))
151+ if num_bins <= n :
152+ fig .update_layout (
153+ title = f"Total Fidelity: { metric_name } , showing { num_bins } /{ n } points" ,
154+ )
155+ df ["bin" ] = pd .cut (df .index , bins = num_bins , labels = False )
156+ bin_maxs = df .groupby ('bin' ).max ()
157+ fig .add_trace (go .Scatter (
158+ x = bin_maxs [x_axis ],
159+ y = bin_maxs ["value" ],
160+ mode = 'lines' ,
161+ fill = None ,
162+ marker = dict (color = colors [experiment .id ]["primary" ] if colors else None ),
163+ name = experiment .name ,
164+ ))
165+ bin_mins = df .groupby ('bin' ).min ()
166+ fig .add_trace (go .Scatter (
167+ x = bin_mins [x_axis ],
168+ y = bin_mins ["value" ],
169+ mode = 'lines' ,
170+ fill = "tonexty" ,
171+ marker = dict (color = colors [experiment .id ]["primary" ] if colors else None ),
172+ name = experiment .name ,
173+ ))
174+ else :
175+ fig .update_layout (
176+ title = f"Total Fidelity: { metric_name } , showing { n } /{ n } points" ,
177+ )
178+ fig .add_trace (go .Scatter (
179+ x = df [x_axis ],
180+ y = df ["value" ],
181+ mode = 'lines' ,
182+ fill = None ,
183+ marker = dict (color = colors [experiment .id ]["primary" ] if colors else None ),
184+ name = experiment .name ,
185+ ))
196186
197187 bar .empty ()
198188 #st.plotly_chart(fig, use_container_width=True)
0 commit comments