Skip to content

Commit 182e532

Browse files
Merge pull request jax-ml#25114 from jedborovik:add-optimization-effort-flags
PiperOrigin-RevId: 702892538
2 parents 208194f + c65ce4b commit 182e532

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
7979
* {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
8080
({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
8181
supported on GPU. See {jax-issue}`#24663` for more details.
82+
* Added two new configuration flags, `jax_exec_time_optimization_effort` and `jax_memory_fitting_effort`, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.
8283

8384
* Bug fixes
8485
* Fixed a bug where the GPU implementations of LU and QR decomposition would

jax/_src/compiler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from jax._src.interpreters import mlir
3737
from jax._src.lib import version as jaxlib_version
3838
from jax._src.lib import xla_client as xc
39+
from jax._src.lib import xla_extension_version
3940
from jax._src.lib.mlir import ir
4041
import numpy as np
4142

@@ -191,6 +192,10 @@ def get_compile_options(
191192
assert device_assignment.computation_count() == num_partitions
192193
compile_options.device_assignment = device_assignment
193194

195+
if xla_extension_version >= 294:
196+
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
197+
build_options.memory_fitting_effort = config.memory_fitting_effort.value
198+
194199
if env_options_overrides is not None:
195200
# Some overrides are passed directly on build_options.
196201
overrides_on_build_options = [

jax/_src/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,3 +2006,15 @@ def _update_garbage_collection_guard(state, key, val):
20062006
'to use this feature.'
20072007
),
20082008
)
2009+
2010+
exec_time_optimization_effort = float_state(
2011+
name='jax_exec_time_optimization_effort',
2012+
default=0.0,
2013+
help='Effort for minimizing execution time (higher means more effort), valid range [-1.0, 1.0].'
2014+
)
2015+
2016+
memory_fitting_effort = float_state(
2017+
name='jax_memory_fitting_effort',
2018+
default=0.0,
2019+
help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].'
2020+
)

0 commit comments

Comments
 (0)