@@ -6930,7 +6930,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
 
 @pytest.mark.parametrize("enable_fp_fusion", [False, True])
 @pytest.mark.parametrize("default_override", [False, True])
-def test_enable_fp_fusion(enable_fp_fusion, default_override, device, monkeypatch):
+def test_enable_fp_fusion(enable_fp_fusion, default_override, device, fresh_knobs):
     # Sequential multiply add can be fused by backend
     @triton.jit
     def mul_add(data):
@@ -6939,7 +6939,7 @@ def mul_add(data):
 
     data = torch.randn((128, ), device=device, dtype=torch.float32)
     if default_override:
-        monkeypatch.setenv("TRITON_DEFAULT_FP_FUSION", "1" if enable_fp_fusion else "0")
+        fresh_knobs.language.default_fp_fusion = enable_fp_fusion
         h = mul_add.warmup(data, grid=(1, ))
     else:
         h = mul_add.warmup(data, grid=(1, ), enable_fp_fusion=enable_fp_fusion)
@@ -6957,7 +6957,7 @@ def mul_add(data):
 
 @pytest.mark.parametrize("arch", ["sm70", "sm80", "sm90", "gfx942", "gfx950", "gfx1200"])
 @pytest.mark.parametrize("env_var_override", [False, True])
-def test_override_arch(arch, env_var_override, device, monkeypatch):
+def test_override_arch(arch, env_var_override, device, fresh_knobs):
     if arch.startswith("sm") and not is_cuda():
         pytest.skip(f"{arch} arch only for CUDA")
     elif arch.startswith("gfx") and not is_hip():
@@ -6974,7 +6974,7 @@ def simple(data, out):
 
     if is_cuda():
         if env_var_override:
-            monkeypatch.setenv("TRITON_OVERRIDE_ARCH", arch)
+            fresh_knobs.runtime.override_arch = str(arch)
             h = simple.warmup(data, out, grid=(1, ))
         else:
             h = simple.warmup(data, out, arch=arch, grid=(1, ))
@@ -6984,7 +6984,7 @@ def simple(data, out):
         # For HIP, the generated kernel is a binary containing the final ISA. So we cannot run
         # them like CUDA side if the chip doesn't match. Here we just check generated ISA.
         if env_var_override:
-            monkeypatch.setenv("TRITON_OVERRIDE_ARCH", str(arch))
+            fresh_knobs.runtime.override_arch = str(arch)
             h = simple.warmup(data, out, grid=(1, ))
         else:
             h = simple.warmup(data, out, arch=arch, grid=(1, ))
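For context, the change replaces process-wide environment-variable overrides (TRITON_DEFAULT_FP_FUSION, TRITON_OVERRIDE_ARCH) with attribute assignments on Triton's knobs object through a fresh_knobs fixture. Below is a minimal sketch of what such a fixture could look like, assuming pytest and assuming triton.knobs exposes the language/runtime namespaces used in the diff; the actual fixture in Triton's test suite may differ.

import pytest
import triton


@pytest.fixture
def fresh_knobs():
    # Hand the knobs object to the test, then restore the knobs these tests
    # touch so per-test overrides do not leak into later tests.
    knobs = triton.knobs
    saved_fp_fusion = knobs.language.default_fp_fusion
    saved_arch = knobs.runtime.override_arch
    try:
        yield knobs
    finally:
        knobs.language.default_fp_fusion = saved_fp_fusion
        knobs.runtime.override_arch = saved_arch

A test can then write fresh_knobs.runtime.override_arch = str(arch) in place of monkeypatch.setenv("TRITON_OVERRIDE_ARCH", arch), which keeps the override scoped to the test without mutating the process environment.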