Skip to content

Commit a273986

Browse files
authored
Use pytest's tmp_path fixture instead of tempfile.NamedTemporaryFile (triton-lang#5036)
The new code structure is shorter and also allows us to get rid of one level of code nesting in most places. As a side effect, it makes the code more Windows-friendly. For example, it eliminates situations in which an attempt is made to open a file for reading while a file with the same name is still open for writing: https://github.com/intel/intel-xpu-backend-for-triton/pull/2478/files#r1805224201 (which doesn't work on Windows). Pytest's docs: https://docs.pytest.org/en/stable/how-to/tmp_path.html
1 parent 92a4fad commit a273986

File tree

6 files changed

+356
-348
lines changed

6 files changed

+356
-348
lines changed

python/test/unit/language/test_core.py

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Optional
66
import math
77
import textwrap
8-
import tempfile
8+
import pathlib
99

1010
import numpy as np
1111
import pytest
@@ -2558,7 +2558,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
25582558
@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]])
25592559
@pytest.mark.parametrize("src_layout", scan_layouts)
25602560
@pytest.mark.parametrize("axis", [0, 1])
2561-
def test_scan_layouts(M, N, src_layout, axis, device):
2561+
def test_scan_layouts(M, N, src_layout, axis, device, tmp_path: pathlib.Path):
25622562

25632563
ir = f"""
25642564
#blocked = {src_layout}
@@ -2591,10 +2591,10 @@ def test_scan_layouts(M, N, src_layout, axis, device):
25912591
}}
25922592
"""
25932593

2594-
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
2595-
f.write(ir)
2596-
f.flush()
2597-
kernel = triton.compile(f.name)
2594+
temp_file = tmp_path / "test_scan_layouts.ttgir"
2595+
temp_file.write_text(ir)
2596+
kernel = triton.compile(str(temp_file))
2597+
25982598
rs = RandomState(17)
25992599
x = rs.randint(-100, 100, (M, N)).astype('int32')
26002600

@@ -2642,7 +2642,7 @@ def test_scan_layouts(M, N, src_layout, axis, device):
26422642
@pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d'])
26432643
@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"])
26442644
@pytest.mark.parametrize("reduce_op", ["sum", "max"])
2645-
def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device):
2645+
def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device, tmp_path: pathlib.Path):
26462646
if isinstance(src_layout,
26472647
(MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]):
26482648
pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
@@ -2736,10 +2736,9 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
27362736
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>
27372737
""" + epilogue
27382738

2739-
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
2740-
f.write(ir)
2741-
f.flush()
2742-
kernel = triton.compile(f.name)
2739+
temp_file = tmp_path / "test_reduce_layouts.ttgir"
2740+
temp_file.write_text(ir)
2741+
kernel = triton.compile(str(temp_file))
27432742

27442743
rs = RandomState(17)
27452744
x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
@@ -2769,7 +2768,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
27692768

27702769
@pytest.mark.parametrize("M", [32, 64, 128, 256])
27712770
@pytest.mark.parametrize("src_layout", layouts)
2772-
def test_store_op(M, src_layout, device):
2771+
def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):
27732772

27742773
ir = f"""
27752774
#src = {src_layout}
@@ -2790,10 +2789,9 @@ def test_store_op(M, src_layout, device):
27902789
}}
27912790
"""
27922791

2793-
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
2794-
f.write(ir)
2795-
f.flush()
2796-
store_kernel = triton.compile(f.name)
2792+
temp_file = tmp_path / "test_store_op.ttgir"
2793+
temp_file.write_text(ir)
2794+
store_kernel = triton.compile(str(temp_file))
27972795

27982796
rs = RandomState(17)
27992797
x = rs.randint(0, 4, (M, 1)).astype('float32')
@@ -2820,7 +2818,7 @@ def test_store_op(M, src_layout, device):
28202818
@pytest.mark.parametrize("dst_layout", filter_layouts(layouts))
28212819
@pytest.mark.parametrize("src_dim", [0, 1])
28222820
@pytest.mark.parametrize("dst_dim", [0, 1])
2823-
def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
2821+
def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path: pathlib.Path):
28242822

28252823
ir = f"""
28262824
#dst = {dst_layout}
@@ -2840,10 +2838,9 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
28402838
}}
28412839
}}
28422840
"""
2843-
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
2844-
f.write(ir)
2845-
f.flush()
2846-
kernel = triton.compile(f.name)
2841+
temp_file = tmp_path / "test_convert1d.ttgir"
2842+
temp_file.write_text(ir)
2843+
kernel = triton.compile(str(temp_file))
28472844

28482845
rs = RandomState(17)
28492846
x = rs.randint(0, 4, (M, )).astype('int32')
@@ -2881,7 +2878,7 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
28812878
@pytest.mark.parametrize("src_layout", layouts)
28822879
@pytest.mark.parametrize("op", ["sum", "max"])
28832880
@pytest.mark.parametrize("first_axis", [0, 1])
2884-
def test_chain_reduce(M, N, src_layout, op, device, first_axis):
2881+
def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathlib.Path):
28852882

28862883
op_str = ""
28872884
if op == "sum":
@@ -2922,10 +2919,9 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
29222919
}}
29232920
}}
29242921
"""
2925-
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
2926-
f.write(ir)
2927-
f.flush()
2928-
kernel = triton.compile(f.name)
2922+
temp_file = tmp_path / "test_chain_reduce.ttgir"
2923+
temp_file.write_text(ir)
2924+
kernel = triton.compile(str(temp_file))
29292925

29302926
rs = RandomState(17)
29312927
x = rs.randint(0, 4, (M, N)).astype('int32')
@@ -5241,7 +5237,7 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape):
52415237
@pytest.mark.parametrize("src_layout", layouts)
52425238
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
52435239
@pytest.mark.parametrize("dst_layout", layouts)
5244-
def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
5240+
def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path):
52455241
if str(src_layout) == str(dst_layout):
52465242
pytest.skip()
52475243
if is_hip():
@@ -5306,10 +5302,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
53065302
x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
53075303
z = torch.empty_like(x, device=device)
53085304

5309-
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
5310-
f.write(ir)
5311-
f.flush()
5312-
kernel = triton.compile(f.name)
5305+
temp_file = tmp_path / "test_convert2d.ttgir"
5306+
temp_file.write_text(ir)
5307+
kernel = triton.compile(str(temp_file))
5308+
53135309
kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())
53145310

53155311
assert torch.equal(z, x)
@@ -5362,7 +5358,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
53625358
@pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
53635359
@pytest.mark.parametrize("dtype", ['float16'])
53645360
@pytest.mark.parametrize("mma_pair", mma_pairs)
5365-
def test_convertmma2mma(M, N, mma_pair, dtype, device):
5361+
def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
53665362
if is_hip():
53675363
pytest.skip("test_mma2mma is not supported in HIP")
53685364

@@ -5419,10 +5415,10 @@ def do_test(src_layout, dst_layout):
54195415
x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
54205416
z = torch.empty_like(x)
54215417

5422-
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
5423-
f.write(ir)
5424-
f.flush()
5425-
kernel = triton.compile(f.name)
5418+
temp_file = tmp_path / "test_convertmma2mma.ttgir"
5419+
temp_file.write_text(ir)
5420+
kernel = triton.compile(str(temp_file))
5421+
54265422
kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())
54275423

54285424
assert torch.equal(z, x)

python/test/unit/runtime/test_cache.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import importlib.util
22
import itertools
33
import shutil
4-
import tempfile
4+
import pathlib
55

66
import pytest
77
import torch
@@ -129,28 +129,28 @@ def test_combine_fn_change():
129129
seen_keys.add(key)
130130

131131

132-
def write_and_load_module(code, num_extra_lines):
133-
with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f:
134-
f.write(('# extra line\n' * num_extra_lines) + code)
135-
f.flush()
136-
spec = importlib.util.spec_from_file_location("module.name", f.name)
137-
module = importlib.util.module_from_spec(spec)
138-
spec.loader.exec_module(module)
132+
def write_and_load_module(temp_file: pathlib.Path, code, num_extra_lines):
133+
temp_file.write_text(('# extra line\n' * num_extra_lines) + code)
134+
spec = importlib.util.spec_from_file_location("module.name", str(temp_file))
135+
module = importlib.util.module_from_spec(spec)
136+
spec.loader.exec_module(module)
139137
return module
140138

141139

142-
def test_changed_line_numbers_invalidate_cache():
140+
def test_changed_line_numbers_invalidate_cache(tmp_path: pathlib.Path):
143141
from textwrap import dedent
144142
code = dedent("""
145143
import triton
146144
@triton.jit
147145
def test_kernel(i):
148146
i = i + 1
149147
""")
150-
orig_mod = write_and_load_module(code, 0)
148+
temp_file0 = tmp_path / "test_changed_line_numbers_invalidate_cache0.py"
149+
orig_mod = write_and_load_module(temp_file0, code, 0)
151150
orig_cache_key = orig_mod.test_kernel.cache_key
152151

153-
updated_mod = write_and_load_module(code, 1)
152+
temp_file1 = tmp_path / "test_changed_line_numbers_invalidate_cache1.py"
153+
updated_mod = write_and_load_module(temp_file1, code, 1)
154154
updated_cache_key = updated_mod.test_kernel.cache_key
155155
assert orig_cache_key != updated_cache_key
156156

0 commit comments

Comments
 (0)