We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fc6d6f7 commit 3e0e07cCopy full SHA for 3e0e07c
submission_runner.py
@@ -36,6 +36,13 @@
36
jax.config.update('jax_default_prng_impl', 'threefry2x32')
37
jax.config.update('jax_threefry_partitionable', True)
38
39
+# PyTorch set TF32
40
+torch.backends.fp32_precision = "ieee"
41
+torch.backends.cuda.matmul.fp32_precision = "tf32"
42
+torch.backends.cudnn.fp32_precision = "ieee"
43
+torch.backends.cudnn.conv.fp32_precision = "tf32"
44
+torch.backends.cudnn.rnn.fp32_precision = "tf32"
45
+
46
# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make
47
# it unavailable to JAX.
48
tf.config.set_visible_devices([], 'GPU')
0 commit comments