3434from jax ._src import traceback_util
3535from jax ._src .interpreters import mlir
3636from jax ._src .lib import xla_client as xc
37+ from jax ._src .lib import xla_extension_version
3738from jax ._src .lib import version as jaxlib_version
3839from jax ._src .lib .mlir import ir
3940import numpy as np
@@ -190,6 +191,10 @@ def get_compile_options(
190191 assert device_assignment .computation_count () == num_partitions
191192 compile_options .device_assignment = device_assignment
192193
194+ if xla_extension_version >= 294 :
195+ build_options .exec_time_optimization_effort = config .exec_time_optimization_effort .value
196+ build_options .memory_fitting_effort = config .memory_fitting_effort .value
197+
193198 if env_options_overrides is not None :
194199 # Some overrides are passed directly on build_options.
195200 overrides_on_build_options = [
@@ -200,9 +205,6 @@ def get_compile_options(
200205 setattr (build_options , name , env_options_overrides .pop (name ))
201206 compile_options .env_option_overrides = list (env_options_overrides .items ())
202207
203- build_options .exec_time_optimization_effort = config .exec_time_optimization_effort .value
204- build_options .memory_fitting_effort = config .memory_fitting_effort .value
205-
206208 debug_options = compile_options .executable_build_options .debug_options
207209 if lib .cuda_path is not None :
208210 debug_options .xla_gpu_cuda_data_dir = lib .cuda_path
0 commit comments