@@ -98,12 +98,15 @@ function run(; to_json=false)
9898 }[]
9999
100100 for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
101- @info " Running benchmark for $model_name "
101+ @info " Running benchmark for $model_name , $varinfo_choice , $adbackend , $islinked "
102102 relative_eval_time, relative_ad_eval_time = try
103103 results = benchmark (model, varinfo_choice, adbackend, islinked)
104+ @info " t(eval) = $(results. primal_time) "
105+ @info " t(grad) = $(results. grad_time) "
104106 (results. primal_time / reference_time),
105107 (results. grad_time / results. primal_time)
106108 catch e
109+ @info " benchmark errored: $e "
107110 missing , missing
108111 end
109112 push! (
@@ -155,18 +158,33 @@ function combine(head_filename::String, base_filename::String)
155158 all_testcases = union (Set (keys (head_testcases)), Set (keys (base_testcases)))
156159 @info " $(length (all_testcases)) unique test cases found"
157160 sorted_testcases = sort (
158- collect (all_testcases); by= (c -> (c. model_name, c. ad_backend , c. varinfo, c. linked ))
161+ collect (all_testcases); by= (c -> (c. model_name, c. linked , c. varinfo, c. ad_backend ))
159162 )
160163 results_table = Tuple{
161- String,Int,String,String,Bool,String,String,String,String,String,String
164+ String,
165+ Int,
166+ String,
167+ String,
168+ Bool,
169+ String,
170+ String,
171+ String,
172+ String,
173+ String,
174+ String,
175+ String,
176+ String,
177+ String,
162178 }[]
179+ sublabels = [" base" , " this PR" , " speedup" ]
163180 results_colnames = [
164181 [
165182 EmptyCells (5 ),
166183 MultiColumn (3 , " t(eval) / t(ref)" ),
167184 MultiColumn (3 , " t(grad) / t(eval)" ),
185+ MultiColumn (3 , " t(grad) / t(ref)" ),
168186 ],
169- [colnames[1 : 5 ]. .. , " base " , " this PR " , " speedup " , " base " , " this PR " , " speedup " ],
187+ [colnames[1 : 5 ]. .. , sublabels ... , sublabels ... , sublabels ... ],
170188 ]
171189 sprint_float (x:: Float64 ) = @sprintf (" %.2f" , x)
172190 sprint_float (m:: Missing ) = " err"
@@ -183,6 +201,10 @@ function combine(head_filename::String, base_filename::String)
183201 # Finally that lets us do this division safely
184202 speedup_eval = base_eval / head_eval
185203 speedup_grad = base_grad / head_grad
204+ # As well as this multiplication, which is t(grad) / t(ref)
205+ head_grad_vs_ref = head_grad * head_eval
206+ base_grad_vs_ref = base_grad * base_eval
207+ speedup_grad_vs_ref = base_grad_vs_ref / head_grad_vs_ref
186208 push! (
187209 results_table,
188210 (
@@ -197,6 +219,9 @@ function combine(head_filename::String, base_filename::String)
197219 sprint_float (base_grad),
198220 sprint_float (head_grad),
199221 sprint_float (speedup_grad),
222+ sprint_float (base_grad_vs_ref),
223+ sprint_float (head_grad_vs_ref),
224+ sprint_float (speedup_grad_vs_ref),
200225 ),
201226 )
202227 end
@@ -212,7 +237,10 @@ function combine(head_filename::String, base_filename::String)
212237 backend= :text ,
213238 fit_table_in_display_horizontally= false ,
214239 fit_table_in_display_vertically= false ,
215- table_format= TextTableFormat (; horizontal_line_at_merged_column_labels= true ),
240+ table_format= TextTableFormat (;
241+ horizontal_line_at_merged_column_labels= true ,
242+ horizontal_lines_at_data_rows= collect (3 : 3 : length (results_table)),
243+ ),
216244 )
217245 println (" ```" )
218246 end
0 commit comments