File tree Expand file tree Collapse file tree 1 file changed +5
-3
lines changed
Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -372,10 +372,12 @@ def _compile_jax_benchmark_and_analyse(
372372 """Compile a JAX benchmark function and extract cost estimates if available."""
373373 compiled_benchmark_function = jax .jit (benchmark_function ).lower ().compile ()
374374 cost_analysis = compiled_benchmark_function .cost_analysis ()
375- if cost_analysis is not None and isinstance (cost_analysis , list ):
375+ if cost_analysis is not None :
376+ if isinstance (cost_analysis , list ):
377+ cost_analysis = cost_analysis [0 ]
376378 results_entry ["cost_analysis" ] = {
377- "flops" : cost_analysis [ 0 ] .get ("flops" ),
378- "bytes_accessed" : cost_analysis [ 0 ] .get ("bytes accessed" ),
379+ "flops" : cost_analysis .get ("flops" ),
380+ "bytes_accessed" : cost_analysis .get ("bytes accessed" ),
379381 }
380382 memory_analysis = compiled_benchmark_function .memory_analysis ()
381383 if memory_analysis is not None :
You can’t perform that action at this time.
0 commit comments