55PYTHONPATH=$PWD poetry run python model/distributions/sphere/watson/benchmark_fib_starts.py
66'''
77
8+ import json
9+ from pathlib import Path
10+ import plotly .express as px
811from model .distributions .sphere .watson .fibonachi import WatsonFibonachiSampling
912from util .selectors .slider_float import FloatSlider
1013import pyperf
@@ -30,7 +33,7 @@ def bench_single_kappa(kappa, sample_count, id):
3033 results = {}
3134 for method_name , method in methods .items ():
3235 bench_name = f"Watson Fibonacci Sampling: { method_name } (kappa={ kappa } ) [{ id } ]"
33- benchmark = runner .bench_func (bench_name , benchmark_kappa , method , kappa , sample_count , 5 )
36+ benchmark = runner .bench_func (bench_name , benchmark_kappa , method , kappa , sample_count , 1 )
3437 results [method_name ] = benchmark
3538
3639 return results
@@ -68,29 +71,86 @@ def bench_multiple_sample_counts_log(kappa):
6871 return all_results
6972
7073
71- def plot_benches (results , title , filename , x_label , log_x = False , log_y = False ):
72- import plotly .express as px
73- try :
74- if x_label == "sample_count" :
75- rows = [dict (name = n , sample_count = k , time = t .mean ()) for n , pts in results .items () for k , t in pts ]
76- else :
77- rows = [dict (name = n , kappa = k , time = t .mean ()) for n , pts in results .items () for k , t in pts ]
78- fig = px .line (
79- rows ,
80- x = x_label ,
81- y = "time" ,
82- color = "name" ,
83- markers = True ,
84- title = title ,
85- log_x = log_x ,
86- log_y = log_y ,
74+ def _sanitize_filename (name ):
75+ return name .replace (" " , "_" ).replace (":" , "" )
76+
77+ def _rows_from_results (results , x_label ):
78+ if x_label == "sample_count" :
79+ return [dict (name = n , sample_count = k , time = t .mean ()) for n , pts in results .items () for k , t in pts ]
80+ if x_label == "kappa" :
81+ return [dict (name = n , kappa = k , time = t .mean ()) for n , pts in results .items () for k , t in pts ]
82+ raise ValueError (f"Unsupported x_label: { x_label } " )
83+
84+ def _plot_rows (rows , title , filename , x_label , log_x = False , log_y = False ):
85+ fig = px .line (
86+ rows ,
87+ x = x_label ,
88+ y = "time" ,
89+ color = "name" ,
90+ markers = True ,
91+ log_x = log_x ,
92+ log_y = log_y ,
93+ )
94+ fig .update_layout (
95+ legend = dict (
96+ orientation = "h" ,
97+ yanchor = "bottom" ,
98+ y = 1.02 ,
99+ xanchor = "left" ,
100+ x = 0 ,
87101 )
88- fig .write_image (f"{ filename .replace (' ' , '_' ).replace (':' , '' )} .svg" )
102+ )
103+ try :
104+ fig .write_image (f"{ _sanitize_filename (filename )} .svg" )
89105 except Exception as e :
90106 print ("Generating plot failed, dumping data:" , e )
91- print (results . items () )
107+ print (rows )
92108 print ("Trying to save html as fallback" )
93- fig .write_html (f"{ title .replace (' ' , '_' ).replace (':' , '' )} .html" , include_plotlyjs = "cdn" , full_html = True )
109+ fig .write_html (f"{ _sanitize_filename (title )} .html" , include_plotlyjs = "cdn" , full_html = True )
110+
111+ def plot_benches (results , title = None , filename = None , x_label = None , log_x = None , log_y = None , json_filename = None ):
112+
113+ if isinstance (results , (str , Path )):
114+ json_path = Path (results )
115+ with json_path .open ("r" , encoding = "utf-8" ) as handle :
116+ payload = json .load (handle )
117+ rows = payload ["rows" ]
118+ if title is None :
119+ title = payload .get ("title" , json_path .stem )
120+ if filename is None :
121+ filename = payload .get ("filename" , json_path .stem )
122+ if x_label is None :
123+ x_label = payload .get ("x_label" )
124+ if log_x is None :
125+ log_x = payload .get ("log_x" , False )
126+ if log_y is None :
127+ log_y = payload .get ("log_y" , False )
128+ if x_label is None :
129+ raise ValueError ("x_label is required when replotting from JSON" )
130+ _plot_rows (rows , title , filename , x_label , log_x = log_x , log_y = log_y )
131+ return
132+
133+ if title is None or filename is None or x_label is None :
134+ raise ValueError ("title, filename, and x_label are required for raw benchmark data" )
135+ if log_x is None :
136+ log_x = False
137+ if log_y is None :
138+ log_y = False
139+
140+ rows = _rows_from_results (results , x_label )
141+ payload = {
142+ "title" : title ,
143+ "filename" : filename ,
144+ "x_label" : x_label ,
145+ "log_x" : log_x ,
146+ "log_y" : log_y ,
147+ "rows" : rows ,
148+ }
149+ if json_filename is None :
150+ json_filename = f"{ _sanitize_filename (filename )} .json"
151+ with open (json_filename , "w" , encoding = "utf-8" ) as handle :
152+ json .dump (payload , handle , indent = 2 , sort_keys = True )
153+ _plot_rows (rows , title , filename , x_label , log_x = log_x , log_y = log_y )
94154
95155
96156
@@ -114,4 +174,3 @@ def plot_benches(results, title, filename, x_label, log_x=False, log_y=False):
114174
115175 #plot_benches(mult_samples_neg_10, "time taken for various sample counts (kappa=-10)", "time taken for various sample counts (kappa=-10)", "sample_count")
116176 plot_benches (log_mult_samples_neg_10 , "time taken for various sample counts (kappa=-10)" , "time taken for various sample counts log scale (kappa=-10)" , "sample_count" , log_x = True , log_y = True )
117-
0 commit comments