Skip to content

Commit 83b54d9

Browse files
committed
Add version check for effort flags
1 parent 762301f commit 83b54d9

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

jax/_src/compiler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from jax._src import traceback_util
3535
from jax._src.interpreters import mlir
3636
from jax._src.lib import xla_client as xc
37+
from jax._src.lib import xla_extension_version
3738
from jax._src.lib import version as jaxlib_version
3839
from jax._src.lib.mlir import ir
3940
import numpy as np
@@ -190,6 +191,10 @@ def get_compile_options(
190191
assert device_assignment.computation_count() == num_partitions
191192
compile_options.device_assignment = device_assignment
192193

194+
if xla_extension_version >= 294:
195+
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
196+
build_options.memory_fitting_effort = config.memory_fitting_effort.value
197+
193198
if env_options_overrides is not None:
194199
# Some overrides are passed directly on build_options.
195200
overrides_on_build_options = [
@@ -200,9 +205,6 @@ def get_compile_options(
200205
setattr(build_options, name, env_options_overrides.pop(name))
201206
compile_options.env_option_overrides = list(env_options_overrides.items())
202207

203-
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
204-
build_options.memory_fitting_effort = config.memory_fitting_effort.value
205-
206208
debug_options = compile_options.executable_build_options.debug_options
207209
if lib.cuda_path is not None:
208210
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path

0 commit comments

Comments
 (0)