@@ -196,21 +196,6 @@ def fn(x, y, z):
196196 )
197197
198198 def test_mismatched_global_state (self ):
199- @contextlib .contextmanager
200- def _hip_allow_tf32 ():
201- # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
202- # and only for MI300+
203- hip_allow_tf32 = os .environ .get ("HIPBLASLT_ALLOW_TF32" , None )
204- os .environ ["HIPBLASLT_ALLOW_TF32" ] = "1"
205-
206- try :
207- yield
208- finally :
209- if hip_allow_tf32 is not None :
210- os .environ ["HIPBLASLT_ALLOW_TF32" ] = hip_allow_tf32
211- else :
212- del os .environ ["HIPBLASLT_ALLOW_TF32" ]
213-
214199 def inner_fn (x , y ):
215200 x1 = x * 1
216201 y1 = y + 1
@@ -251,31 +236,29 @@ def set_default_dtype_bfloat16():
251236 def reset_default_dtype ():
252237 torch .set_default_dtype (old_dtype )
253238
254- tf32_ctx = _hip_allow_tf32 if torch .version .hip else contextlib .nullcontext
255- with tf32_ctx ():
256- for ctx in [
257- lambda : torch .set_grad_enabled (False ),
258- torch .autograd .grad_mode .inference_mode ,
259- lambda : torch .autograd .graph .disable_saved_tensors_hooks (
260- "This is not supported"
261- ),
262- # lambda: torch.set_num_threads(2), : Unsupported
263- (set_default_dtype_bfloat16 , reset_default_dtype ),
264- (
265- lambda : torch .use_deterministic_algorithms (True ),
266- lambda : torch .use_deterministic_algorithms (False ),
267- ),
268- # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
269- # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
270- create_toggle_fns ("allow_bf16_reduced_precision_reduction" ),
271- create_toggle_fns ("allow_fp16_reduced_precision_reduction" ),
272- create_toggle_fns ("allow_tf32" ),
273- ]:
274- self .assertExpectedInline (
275- self .get_result (fn , torch .rand (10 , 10 ), torch .ones (10 , 20 ), ctx ),
276- """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
239+ for ctx in [
240+ lambda : torch .set_grad_enabled (False ),
241+ torch .autograd .grad_mode .inference_mode ,
242+ lambda : torch .autograd .graph .disable_saved_tensors_hooks (
243+ "This is not supported"
244+ ),
245+ # lambda: torch.set_num_threads(2), : Unsupported
246+ (set_default_dtype_bfloat16 , reset_default_dtype ),
247+ (
248+ lambda : torch .use_deterministic_algorithms (True ),
249+ lambda : torch .use_deterministic_algorithms (False ),
250+ ),
251+ # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
252+ # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
253+ create_toggle_fns ("allow_bf16_reduced_precision_reduction" ),
254+ create_toggle_fns ("allow_fp16_reduced_precision_reduction" ),
255+ create_toggle_fns ("allow_tf32" ),
256+ ]:
257+ self .assertExpectedInline (
258+ self .get_result (fn , torch .rand (10 , 10 ), torch .ones (10 , 20 ), ctx ),
259+ """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
277260 [['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""" ,
278- )
261+ )
279262
280263 def test_mutation_tracking_simple (self ):
281264 def fn (x , y , z ):
0 commit comments