Skip to content

Commit ff0baf0

Browse files
committed
Make robust to change in cost_analysis return type in recent JAX versions
1 parent 38dfbf4 commit ff0baf0

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

benchmarks/benchmarking.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,12 @@ def _compile_jax_benchmark_and_analyse(
372372
"""Compile a JAX benchmark function and extract cost estimates if available."""
373373
compiled_benchmark_function = jax.jit(benchmark_function).lower().compile()
374374
cost_analysis = compiled_benchmark_function.cost_analysis()
375-
if cost_analysis is not None and isinstance(cost_analysis, list):
375+
if cost_analysis is not None:
376+
if isinstance(cost_analysis, list):
377+
cost_analysis = cost_analysis[0]
376378
results_entry["cost_analysis"] = {
377-
"flops": cost_analysis[0].get("flops"),
378-
"bytes_accessed": cost_analysis[0].get("bytes accessed"),
379+
"flops": cost_analysis.get("flops"),
380+
"bytes_accessed": cost_analysis.get("bytes accessed"),
379381
}
380382
memory_analysis = compiled_benchmark_function.memory_analysis()
381383
if memory_analysis is not None:

0 commit comments

Comments
 (0)