Skip to content

Commit 3e0e07c

Browse files
committed
set matmuls, conv and rnn to tf32 for torch.cuda
1 parent fc6d6f7 commit 3e0e07c

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

submission_runner.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@
3636
jax.config.update('jax_default_prng_impl', 'threefry2x32')
3737
jax.config.update('jax_threefry_partitionable', True)
3838

39+
# PyTorch set TF32
40+
torch.backends.fp32_precision = "ieee"
41+
torch.backends.cuda.matmul.fp32_precision = "tf32"
42+
torch.backends.cudnn.fp32_precision = "ieee"
43+
torch.backends.cudnn.conv.fp32_precision = "tf32"
44+
torch.backends.cudnn.rnn.fp32_precision = "tf32"
45+
3946
# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make
4047
# it unavailable to JAX.
4148
tf.config.set_visible_devices([], 'GPU')

0 commit comments

Comments
 (0)