
Commit dd64d0b

xinyazhang authored and jeffdaily committed
[ROCm] Remove HIPBLASLT_ALLOW_TF32 from codebase (pytorch#162998)
A few UT failures are caused by `HIPBLASLT_ALLOW_TF32`.

Fixes pytorch#157094
Fixes pytorch#157093
Fixes pytorch#157092
Fixes pytorch#157091
Fixes pytorch#157064
Fixes pytorch#157063
Fixes pytorch#157062
Fixes pytorch#157061
Fixes pytorch#157042
Fixes pytorch#157041
Fixes pytorch#157039
Fixes pytorch#157004

Pull Request resolved: pytorch#162998
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <[email protected]>
1 parent 83a046e commit dd64d0b
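With the `HIPBLASLT_ALLOW_TF32` environment gate removed, TF32 on ROCm is controlled through the same user-facing toggles as on CUDA. A minimal sketch of those toggles (illustrative usage based on the APIs exercised in the diffs below, not code added by this commit):

```python
import torch

# Enable TF32 for float32 matmuls via the backend flag; after this commit the
# flag takes effect on ROCm without HIPBLASLT_ALLOW_TF32=1 in the environment.
torch.backends.cuda.matmul.allow_tf32 = True

# Equivalently, "high" precision maps to TF32 and "highest" to full IEEE FP32,
# mirroring setAllowTF32CuBLAS() in aten/src/ATen/Context.cpp.
torch.set_float32_matmul_precision("high")

print(torch.backends.cuda.matmul.allow_tf32)   # True
print(torch.get_float32_matmul_precision())    # "high"
```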

10 files changed: +48 −196 lines


aten/src/ATen/Context.cpp

Lines changed: 2 additions & 19 deletions
@@ -171,7 +171,7 @@ void Context::setUserEnabledNNPACK(bool e) {
 }
 
 bool Context::allowTF32CuDNN(const std::string& op) const {
-  if (op.size() == 0){
+  if (op.empty()){
     bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
     bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
     TORCH_CHECK(
@@ -270,9 +270,6 @@ bool Context::userEnabledOverrideableSDP() const {
 
 static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
 static constexpr const std::array<const char*, 2> cublas_deterministic_configs = {":4096:8", ":16:8"};
-#ifdef USE_ROCM
-static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
-#endif
 
 bool Context::checkCuBLASConfigDeterministic() {
   // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
@@ -332,12 +329,6 @@ void Context::setImmediateMiopen(bool b) {
 }
 
 bool Context::allowTF32CuBLAS() const {
-#ifdef USE_ROCM
-  const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
-  if (allow_tf32 != true) {
-    return false;
-  }
-#endif
   bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
   bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
   TORCH_CHECK(
@@ -350,14 +341,6 @@ bool Context::allowTF32CuBLAS() const {
 }
 
 void Context::setAllowTF32CuBLAS(bool b) {
-#ifdef USE_ROCM
-  const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
-  if (allow_tf32 != true) {
-    C10_LOG_FIRST_N(INFO, 10) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
-                              << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
-    return;
-  }
-#endif
   float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
   setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
 }
@@ -429,7 +412,7 @@ void Context::setFloat32Precision(const std::string& backend, const std::string&
   std::string msg;
   auto iterp = _fp32_precisions.find(backend);
   TORCH_CHECK(iterp != _fp32_precisions.end());
-  for (auto p : iterp->second) {
+  for (const auto& p : iterp->second) {
     msg += p;
     msg += " ";
   }
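
For context, the ROCm-only branches deleted above short-circuited both the getter and the setter unless the environment variable was present. A rough Python model of that removed gate (illustrative only; `matmul_precision` is a hypothetical stand-in for the `float32Precision("cuda", "matmul")` state kept in Context.cpp):

```python
import os

def allow_tf32_cublas(matmul_precision: str) -> bool:
    # Rough model of the pre-change allowTF32CuBLAS() on ROCm: unless
    # HIPBLASLT_ALLOW_TF32 was set to a truthy value, TF32 was reported
    # as disabled no matter what the precision settings said.
    if os.environ.get("HIPBLASLT_ALLOW_TF32") != "1":
        return False
    # What remains after this commit: the fp32 precision setting alone decides.
    return matmul_precision == "tf32"
```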

test/dynamo/test_graph_region_tracker.py

Lines changed: 22 additions & 40 deletions
@@ -1,6 +1,5 @@
 # Owner(s): ["module: dynamo"]
 import contextlib
-import os
 
 import torch
 import torch.fx
@@ -196,21 +195,6 @@ def fn(x, y, z):
         )
 
     def test_mismatched_global_state(self):
-        @contextlib.contextmanager
-        def _hip_allow_tf32():
-            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-            # and only for MI300+
-            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-            try:
-                yield
-            finally:
-                if hip_allow_tf32 is not None:
-                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-                else:
-                    del os.environ["HIPBLASLT_ALLOW_TF32"]
-
         def inner_fn(x, y):
             x1 = x * 1
             y1 = y + 1
@@ -251,31 +235,29 @@ def set_default_dtype_bfloat16():
         def reset_default_dtype():
             torch.set_default_dtype(old_dtype)
 
-        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-        with tf32_ctx():
-            for ctx in [
-                lambda: torch.set_grad_enabled(False),
-                torch.autograd.grad_mode.inference_mode,
-                lambda: torch.autograd.graph.disable_saved_tensors_hooks(
-                    "This is not supported"
-                ),
-                # lambda: torch.set_num_threads(2), : Unsupported
-                (set_default_dtype_bfloat16, reset_default_dtype),
-                (
-                    lambda: torch.use_deterministic_algorithms(True),
-                    lambda: torch.use_deterministic_algorithms(False),
-                ),
-                # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
-                # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
-                create_toggle_fns("allow_bf16_reduced_precision_reduction"),
-                create_toggle_fns("allow_fp16_reduced_precision_reduction"),
-                create_toggle_fns("allow_tf32"),
-            ]:
-                self.assertExpectedInline(
-                    self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
-                    """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
+        for ctx in [
+            lambda: torch.set_grad_enabled(False),
+            torch.autograd.grad_mode.inference_mode,
+            lambda: torch.autograd.graph.disable_saved_tensors_hooks(
+                "This is not supported"
+            ),
+            # lambda: torch.set_num_threads(2), : Unsupported
+            (set_default_dtype_bfloat16, reset_default_dtype),
+            (
+                lambda: torch.use_deterministic_algorithms(True),
+                lambda: torch.use_deterministic_algorithms(False),
+            ),
+            # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
+            # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
+            create_toggle_fns("allow_bf16_reduced_precision_reduction"),
+            create_toggle_fns("allow_fp16_reduced_precision_reduction"),
+            create_toggle_fns("allow_tf32"),
+        ]:
+            self.assertExpectedInline(
+                self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
+                """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
 [['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
-                )
+            )
 
     def test_mutation_tracking_simple(self):
         def fn(x, y, z):

test/dynamo/test_misc.py

Lines changed: 18 additions & 37 deletions
@@ -8421,43 +8421,24 @@ def write_state(state):
         def fn(x):
             return x + 1
 
-        import contextlib
-
-        @contextlib.contextmanager
-        def _hip_allow_tf32():
-            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-            # and only for MI300+
-            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-            try:
-                yield
-            finally:
-                if hip_allow_tf32 is not None:
-                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-                else:
-                    del os.environ["HIPBLASLT_ALLOW_TF32"]
-
-        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-        with tf32_ctx():
-            initial_state = read_state()
-            y = torch.randn(10)
-            try:
-                for round in range(3):
-                    for i in range(len(initial_state)):
-                        new_state = [False] * len(initial_state)
-                        new_state[i] = True
-                        write_state(new_state)
-                        assert read_state() == new_state
-                        last_state.clear()
-                        fn(y)
-                        assert last_state == new_state
-                        if round == 0:
-                            assert cnt == i + 1
-                        else:
-                            assert cnt == len(initial_state)
-            finally:
-                write_state(initial_state)
+        initial_state = read_state()
+        y = torch.randn(10)
+        try:
+            for round in range(3):
+                for i in range(len(initial_state)):
+                    new_state = [False] * len(initial_state)
+                    new_state[i] = True
+                    write_state(new_state)
+                    assert read_state() == new_state
+                    last_state.clear()
+                    fn(y)
+                    assert last_state == new_state
+                    if round == 0:
+                        assert cnt == i + 1
+                    else:
+                        assert cnt == len(initial_state)
+        finally:
+            write_state(initial_state)
 
     def test_grad_state_mutated(self):
         prior = torch.is_grad_enabled()

test/inductor/test_flex_decoding.py

Lines changed: 0 additions & 3 deletions
@@ -43,9 +43,6 @@
 
 
 Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
-# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
-# In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
-# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
 if torch.version.hip:
     torch.set_float32_matmul_precision("highest")
 else:

test/inductor/test_padding.py

Lines changed: 0 additions & 3 deletions
@@ -109,9 +109,6 @@ def setUpClass(cls):
         if HAS_GPU:
             cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision()
             cls.prior_default_device = torch.get_default_device()
-            # In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
-            # In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
-            # logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
             if torch.version.hip:
                 torch.set_float32_matmul_precision("highest")
             else:

test/test_cuda.py

Lines changed: 0 additions & 52 deletions
@@ -762,53 +762,7 @@ def check_workspace_size(inp):
 
         torch._C._cuda_clearCublasWorkspaces()
 
-    @contextlib.contextmanager
-    def _hip_allow_tf32(self):
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+
-        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-        try:
-            yield
-        finally:
-            if hip_allow_tf32 is not None:
-                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-            else:
-                del os.environ["HIPBLASLT_ALLOW_TF32"]
-
-    @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing")
-    def test_hipblaslt_allow_tf32(self):
-        tf32_ctx = self._hip_allow_tf32
-        with tf32_ctx():
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "0"
-            # Save original value of allow_tf32
-            orig = torch.backends.cuda.matmul.allow_tf32
-            # If allow_tf32 variable is declared as static in aten/src/ATen/Context.cpp
-            # then matmul.allow_tf32 will return False after this point even if
-            # HIP_BLASLT_ALLOW_TF32 is set to 1 and matmul.allow_tf32 is changed.
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-            # Toggle torch.backends.cuda.matmul.allow_tf32 couple of times.
-            torch.backends.cuda.matmul.allow_tf32 = not orig
-            test1 = torch.backends.cuda.matmul.allow_tf32
-            torch.backends.cuda.matmul.allow_tf32 = orig
-            test2 = torch.backends.cuda.matmul.allow_tf32
-            self.assertNotEqual(test1, test2)
-            # Restore original value of allow_tf32
-            torch.backends.cuda.matmul.allow_tf32 = orig
-
     def test_cublas_allow_tf32_get_set(self):
-        """
-        We only turn on TF32 for MI300 with a special env var. This is because TF32
-        is only available in MI300+ and is in experimental mode (hipblaslt support
-        is current WIP)
-        """
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
-        with tf32_ctx():
-            self._test_cublas_allow_tf32_get_set_inner()
-
-    def _test_cublas_allow_tf32_get_set_inner(self):
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
         )
@@ -823,12 +777,6 @@ def _test_cublas_allow_tf32_get_set_inner(self):
         torch.backends.cuda.matmul.allow_tf32 = orig
 
     def test_float32_matmul_precision_get_set(self):
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
-        with tf32_ctx():
-            self._test_float32_matmul_precision_get_set_inner()
-
-    def _test_float32_matmul_precision_get_set_inner(self):
         orig = torch.get_float32_matmul_precision()
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]

test/test_linalg.py

Lines changed: 2 additions & 28 deletions
@@ -109,22 +109,6 @@ def get_tunableop_untuned_filename():
     return untuned_filename
 
 class TestLinalg(TestCase):
-    @contextlib.contextmanager
-    def _hip_allow_tf32(self):
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+. Environment variable will be removed in the future.
-        import os
-        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-        try:
-            yield
-        finally:
-            if hip_allow_tf32 is not None:
-                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-            else:
-                del os.environ["HIPBLASLT_ALLOW_TF32"]
-
     def setUp(self):
         super().setUp()
         torch.backends.cuda.matmul.allow_tf32 = False
@@ -5542,13 +5526,8 @@ def test_scaled_gemm_tunableop(self, device, dtype):
     @runOnRocmArch(MI300_ARCH)
     @dtypes(torch.float)
     def test_tf32_tunableop(self, device, dtype):
-        # Test TunableOp with TF32. Supported by hipblasLT on MI300+.
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+. Eventually this flag will go away.
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
         try:
-            with self._tunableop_ctx(), tf32_ctx():
+            with self._tunableop_ctx():
                 torch.backends.cuda.matmul.allow_tf32 = True
                 torch.cuda.tunable.set_rotating_buffer_size(0)
@@ -5611,13 +5590,8 @@ def test_tf32_offline_tunableop(self, device, dtype):
         # This test is the offline version of test_tf32_tunableop
        import os
 
-        # Test TunableOp with TF32. Supported by hipblasLT on MI300+.
-        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-        # and only for MI300+. Eventually this flag will go away.
-        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
         try:
-            with self._tunableop_ctx(), tf32_ctx():
+            with self._tunableop_ctx():
                 torch.backends.cuda.matmul.allow_tf32 = True
                 ordinal = torch.cuda.current_device()
                 torch.cuda.tunable.set_rotating_buffer_size(0)

test/test_transformers.py

Lines changed: 3 additions & 4 deletions
@@ -51,7 +51,6 @@
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     tf32_on_and_off,
     tf32_enabled,
-    ROCM_VERSION,
 )
 
 if TEST_FAIRSEQ:
@@ -340,7 +339,7 @@ def test_train_with_pad_and_catch_error(self, device):
         l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
         self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
 
-    @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+    @tf32_on_and_off(0.001)
     @parametrize("attn_mask_dim", [2, 3, None])
     @parametrize("key_padding_mask_dim", [2, None])
     @parametrize("mask_dtype", [torch.bool, torch.float32])
@@ -524,7 +523,7 @@ def test_transformerencoder_fastpath(self, device, use_torchscript, enable_neste
         slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
         self.assertEqual(fastpath_output_expanded, slowpath_output)
 
-    @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+    @tf32_on_and_off(0.001)
     @parametrize("with_no_grad", [True, False])
     @parametrize("training", [True, False])
     @parametrize("enable_nested_tensor", [False])
@@ -1110,7 +1109,7 @@ def forward(
             return_all_hiddens=False,
         )[0]
 
-    @tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+    @tf32_on_and_off(0.003)
     @parametrize("input_dim,attn_mask_dim,is_causal",
                  [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
                   (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],

torch/cuda/tunable.py

Lines changed: 0 additions & 1 deletion
@@ -591,7 +591,6 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
     transA = layout[1] == "T"
     dtype = dtype_dict.get(data_type)
     if data_type == "tf32":
-        # User must still set HIPBLASLT_ALLOW_TF32=1
         torch.backends.cuda.matmul.allow_tf32 = True
     else:
        torch.backends.cuda.matmul.allow_tf32 = False
