@@ -1,6 +1,5 @@
 # Owner(s): ["module: dynamo"]
 import contextlib
-import os
 
 import torch
 import torch.fx
@@ -196,21 +195,6 @@ def fn(x, y, z):
         )
 
     def test_mismatched_global_state(self):
-        @contextlib.contextmanager
-        def _hip_allow_tf32():
-            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
-            # and only for MI300+
-            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
-            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
-            try:
-                yield
-            finally:
-                if hip_allow_tf32 is not None:
-                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
-                else:
-                    del os.environ["HIPBLASLT_ALLOW_TF32"]
-
         def inner_fn(x, y):
             x1 = x * 1
             y1 = y + 1
@@ -251,31 +235,29 @@ def set_default_dtype_bfloat16():
         def reset_default_dtype():
             torch.set_default_dtype(old_dtype)
 
-        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-        with tf32_ctx():
-            for ctx in [
-                lambda: torch.set_grad_enabled(False),
-                torch.autograd.grad_mode.inference_mode,
-                lambda: torch.autograd.graph.disable_saved_tensors_hooks(
-                    "This is not supported"
-                ),
-                # lambda: torch.set_num_threads(2), : Unsupported
-                (set_default_dtype_bfloat16, reset_default_dtype),
-                (
-                    lambda: torch.use_deterministic_algorithms(True),
-                    lambda: torch.use_deterministic_algorithms(False),
-                ),
-                # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
-                # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
-                create_toggle_fns("allow_bf16_reduced_precision_reduction"),
-                create_toggle_fns("allow_fp16_reduced_precision_reduction"),
-                create_toggle_fns("allow_tf32"),
-            ]:
-                self.assertExpectedInline(
-                    self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
-                    """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
+        for ctx in [
+            lambda: torch.set_grad_enabled(False),
+            torch.autograd.grad_mode.inference_mode,
+            lambda: torch.autograd.graph.disable_saved_tensors_hooks(
+                "This is not supported"
+            ),
+            # lambda: torch.set_num_threads(2), : Unsupported
+            (set_default_dtype_bfloat16, reset_default_dtype),
+            (
+                lambda: torch.use_deterministic_algorithms(True),
+                lambda: torch.use_deterministic_algorithms(False),
+            ),
+            # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
+            # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
+            create_toggle_fns("allow_bf16_reduced_precision_reduction"),
+            create_toggle_fns("allow_fp16_reduced_precision_reduction"),
+            create_toggle_fns("allow_tf32"),
+        ]:
+            self.assertExpectedInline(
+                self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
+                """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
 [['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
-                )
+            )
 
     def test_mutation_tracking_simple(self):
         def fn(x, y, z):
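Note: the loop in the last hunk consumes a `create_toggle_fns` helper that is defined earlier in the test file, outside the lines shown in this diff. A minimal sketch of what such a helper could look like follows, assuming each flag name refers to an attribute of `torch.backends.cuda.matmul` (where `allow_tf32` and the reduced-precision-reduction switches live); the body here is an illustrative guess, not the file's actual definition:

```python
import torch


def create_toggle_fns(flag_name):
    """Hypothetical sketch: build a (toggle, reset) pair for a
    torch.backends.cuda.matmul flag such as "allow_tf32"."""
    old_value = getattr(torch.backends.cuda.matmul, flag_name)

    def toggle():
        # Flip the flag so the traced region runs under different
        # global state than the first compilation saw.
        setattr(torch.backends.cuda.matmul, flag_name, not old_value)

    def reset():
        # Restore the original value so later test cases start clean.
        setattr(torch.backends.cuda.matmul, flag_name, old_value)

    return toggle, reset
```

Returning a `(toggle, reset)` pair matches how the loop consumes its tuple entries, e.g. `(set_default_dtype_bfloat16, reset_default_dtype)`.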