Skip to content

Commit 762301f

Browse files
Add exec_time_optimization_effort and memory_fitting_effort flags.
These flags control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. They can be set via the command line, e.g. . Valid values are between -1.0 and 1.0, default is 0.0.
1 parent 788f493 commit 762301f

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
7979
* {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
8080
({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
8181
supported on GPU. See {jax-issue}`#24663` for more details.
82+
* Added two new configuration flags, `jax_exec_time_optimization_effort` and `jax_memory_fitting_effort`, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.
8283

8384
* Bug fixes
8485
* Fixed a bug where the GPU implementations of LU and QR decomposition would

jax/_src/compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,9 @@ def get_compile_options(
200200
setattr(build_options, name, env_options_overrides.pop(name))
201201
compile_options.env_option_overrides = list(env_options_overrides.items())
202202

203+
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
204+
build_options.memory_fitting_effort = config.memory_fitting_effort.value
205+
203206
debug_options = compile_options.executable_build_options.debug_options
204207
if lib.cuda_path is not None:
205208
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path

jax/_src/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1993,3 +1993,15 @@ def _update_garbage_collection_guard(state, key, val):
19931993
'to use this feature.'
19941994
),
19951995
)
1996+
1997+
exec_time_optimization_effort = float_state(
1998+
name='jax_exec_time_optimization_effort',
1999+
default=0.0,
2000+
help='Effort for minimizing execution time (higher means more effort), valid range [-1.0, 1.0].'
2001+
)
2002+
2003+
memory_fitting_effort = float_state(
2004+
name='jax_memory_fitting_effort',
2005+
default=0.0,
2006+
help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].'
2007+
)

0 commit comments

Comments
 (0)