
Commit a8be499

[Test] Use knobs in test_core.py to set up env (#7902)
Affected: ``test_enable_fp_fusion``, ``test_override_arch``
See also: triton-lang/triton#7801
Signed-off-by: Ilya Veselov <[email protected]>
1 parent 8bd4dd1 commit a8be499
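The change swaps raw environment variables for Triton's typed knobs: instead of pointing ``monkeypatch`` at ``TRITON_DEFAULT_FP_FUSION`` or ``TRITON_OVERRIDE_ARCH``, the tests assign the corresponding knob attribute directly. A minimal sketch of the new pattern, assuming it runs inside Triton's test tree where the ``fresh_knobs`` fixture is defined; the test name and final assertion are illustrative, not taken from the diff:

import pytest


@pytest.mark.parametrize("enable_fp_fusion", [False, True])
def test_fp_fusion_knob_sketch(enable_fp_fusion, fresh_knobs):
    # Old style: string-typed env var, parsed later by the compiler.
    #   monkeypatch.setenv("TRITON_DEFAULT_FP_FUSION", "1" if enable_fp_fusion else "0")
    # New style: a bool-typed knob set on the per-test knobs object, which the
    # fixture is expected to reset on teardown.
    fresh_knobs.language.default_fp_fusion = enable_fp_fusion
    # Illustrative check: the knob reads back the assigned value.
    assert fresh_knobs.language.default_fp_fusion == enable_fp_fusion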

File tree: 1 file changed (+5, −5)

python/test/unit/language/test_core.py (5 additions, 5 deletions)
@@ -6930,7 +6930,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
 
 @pytest.mark.parametrize("enable_fp_fusion", [False, True])
 @pytest.mark.parametrize("default_override", [False, True])
-def test_enable_fp_fusion(enable_fp_fusion, default_override, device, monkeypatch):
+def test_enable_fp_fusion(enable_fp_fusion, default_override, device, fresh_knobs):
     # Sequential multiply add can be fused by backend
     @triton.jit
     def mul_add(data):
@@ -6939,7 +6939,7 @@ def mul_add(data):
 
     data = torch.randn((128, ), device=device, dtype=torch.float32)
     if default_override:
-        monkeypatch.setenv("TRITON_DEFAULT_FP_FUSION", "1" if enable_fp_fusion else "0")
+        fresh_knobs.language.default_fp_fusion = enable_fp_fusion
         h = mul_add.warmup(data, grid=(1, ))
     else:
         h = mul_add.warmup(data, grid=(1, ), enable_fp_fusion=enable_fp_fusion)
@@ -6957,7 +6957,7 @@ def mul_add(data):
 
 @pytest.mark.parametrize("arch", ["sm70", "sm80", "sm90", "gfx942", "gfx950", "gfx1200"])
 @pytest.mark.parametrize("env_var_override", [False, True])
-def test_override_arch(arch, env_var_override, device, monkeypatch):
+def test_override_arch(arch, env_var_override, device, fresh_knobs):
     if arch.startswith("sm") and not is_cuda():
         pytest.skip(f"{arch} arch only for CUDA")
     elif arch.startswith("gfx") and not is_hip():
@@ -6974,7 +6974,7 @@ def simple(data, out):
 
     if is_cuda():
         if env_var_override:
-            monkeypatch.setenv("TRITON_OVERRIDE_ARCH", arch)
+            fresh_knobs.runtime.override_arch = str(arch)
             h = simple.warmup(data, out, grid=(1, ))
         else:
             h = simple.warmup(data, out, arch=arch, grid=(1, ))
@@ -6984,7 +6984,7 @@ def simple(data, out):
         # For HIP, the generated kernel is a binary containing the final ISA. So we cannot run
         # them like CUDA side if the chip doesn't match. Here we just check generated ISA.
         if env_var_override:
-            monkeypatch.setenv("TRITON_OVERRIDE_ARCH", str(arch))
+            fresh_knobs.runtime.override_arch = str(arch)
             h = simple.warmup(data, out, grid=(1, ))
         else:
             h = simple.warmup(data, out, arch=arch, grid=(1, ))
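
Both tests follow the same substitution: the string-typed ``TRITON_DEFAULT_FP_FUSION`` and ``TRITON_OVERRIDE_ARCH`` environment variables give way to typed attributes on the knobs tree (``fresh_knobs.language.default_fp_fusion``, ``fresh_knobs.runtime.override_arch``), and ``monkeypatch`` drops out of each signature in favor of ``fresh_knobs``, which presumably restores default knob values on teardown the way ``monkeypatch`` restores the environment.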
