Skip to content

Commit c1c3a91

Browse files
jagadish-amd authored and pragupta committed
ROCm: Enable tf32 testing on test_nn (#55)
* Add trailing comma for consistency in gfx architecture list

  Signed-off-by: Jagadish Krishnamoorthy <[email protected]>

* ROCm: Enable tf32 testing on test_nn

  Signed-off-by: Jagadish Krishnamoorthy <[email protected]>

---------

Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
(cherry picked from commit c113e14)
1 parent 1e324e5 commit c1c3a91

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

torch/testing/_internal/common_cuda.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -180,6 +180,9 @@ def tf32_off():
180180

181181
@contextlib.contextmanager
182182
def tf32_on(self, tf32_precision=1e-5):
183+
if torch.version.hip:
184+
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
185+
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
183186
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
184187
old_precision = self.precision
185188
try:
@@ -188,6 +191,11 @@ def tf32_on(self, tf32_precision=1e-5):
188191
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
189192
yield
190193
finally:
194+
if torch.version.hip:
195+
if hip_allow_tf32 is not None:
196+
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
197+
else:
198+
del os.environ["HIPBLASLT_ALLOW_TF32"]
191199
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
192200
self.precision = old_precision
193201

0 commit comments

Comments
 (0)