Skip to content

Commit 56bfbd2

Browse files
jagadish-amd authored and pragupta committed
ROCm: Enable tf32 testing on test_nn (#55)
* Add trailing comma for consistency in gfx architecture list

  Signed-off-by: Jagadish Krishnamoorthy <[email protected]>

* ROCm: Enable tf32 testing on test_nn

  Signed-off-by: Jagadish Krishnamoorthy <[email protected]>

---------

Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
(cherry picked from commit c113e14)
1 parent 6b31e32 commit 56bfbd2

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

torch/testing/_internal/common_cuda.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -192,6 +192,9 @@ def tf32_off():
192192

193193
@contextlib.contextmanager
194194
def tf32_on(self, tf32_precision=1e-5):
195+
if torch.version.hip:
196+
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
197+
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
195198
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
196199
old_precision = self.precision
197200
try:
@@ -200,6 +203,11 @@ def tf32_on(self, tf32_precision=1e-5):
200203
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
201204
yield
202205
finally:
206+
if torch.version.hip:
207+
if hip_allow_tf32 is not None:
208+
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
209+
else:
210+
del os.environ["HIPBLASLT_ALLOW_TF32"]
203211
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
204212
self.precision = old_precision
205213

0 commit comments

Comments
 (0)