Skip to content

Commit 1a3e693

Browse files
Merge pull request jax-ml#25008 from skye:barrier
PiperOrigin-RevId: 698461687
2 parents 9584ee3 + 6222592 commit 1a3e693

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/cloud_tpu_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def cloud_tpu_init() -> None:
8080
os.environ.setdefault('TPU_ML_PLATFORM', 'JAX')
8181
os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__)
8282
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
83-
if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ['LIBTPU_INIT_ARGS']:
83+
if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ.get('LIBTPU_INIT_ARGS', ''):
8484
os.environ['LIBTPU_INIT_ARGS'] = os.environ.get('LIBTPU_INIT_ARGS','') + ' --xla_tpu_use_enhanced_launch_barrier=true'
8585

8686
# this makes tensorstore serialization work better on TPU

0 commit comments

Comments
 (0)