We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 9584ee3 + 6222592 commit 1a3e693Copy full SHA for 1a3e693
jax/_src/cloud_tpu_init.py
@@ -80,7 +80,7 @@ def cloud_tpu_init() -> None:
80
os.environ.setdefault('TPU_ML_PLATFORM', 'JAX')
81
os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__)
82
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
83
- if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ['LIBTPU_INIT_ARGS']:
+ if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ.get('LIBTPU_INIT_ARGS', ''):
84
os.environ['LIBTPU_INIT_ARGS'] = os.environ.get('LIBTPU_INIT_ARGS','') + ' --xla_tpu_use_enhanced_launch_barrier=true'
85
86
# this makes tensorstore serialization work better on TPU
0 commit comments