
Commit 9d89a0a

Use pytest's tmp_path fixture instead of tempfile.NamedTemporaryFile
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent b8fc4b9 commit 9d89a0a
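
For context on the pattern this commit applies throughout: pytest's built-in tmp_path fixture injects a fresh per-test pathlib.Path directory that pytest creates and later cleans up, so tests no longer manage temporary files by hand. A minimal before/after sketch (a toy example with a hypothetical compile_fn, not code from this repo):

    import tempfile

    # Before: manual temporary file. flush() is needed because the handle
    # stays open while the compiler re-opens the file by name; per the
    # CPython docs, re-opening a still-open NamedTemporaryFile by name
    # does not work on Windows.
    def compile_old(ir, compile_fn):
        with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
            f.write(ir)
            f.flush()
            return compile_fn(f.name)

    # After: tmp_path is supplied by pytest; write_text() opens, writes,
    # and closes in one step, so the file can be re-opened on any platform.
    def test_compile(tmp_path):
        temp_file = tmp_path / "kernel.ttgir"
        temp_file.write_text("...")  # IR text goes here
        # kernel = compile_fn(str(temp_file))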

File tree

6 files changed: +350 -348 lines changed

python/test/unit/language/test_core.py

Lines changed: 31 additions & 36 deletions
@@ -5,7 +5,6 @@
 from typing import Optional
 import math
 import textwrap
-import tempfile
 
 import numpy as np
 import pytest
@@ -2589,7 +2588,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
 @pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]])
 @pytest.mark.parametrize("src_layout", scan_layouts)
 @pytest.mark.parametrize("axis", [0, 1])
-def test_scan_layouts(M, N, src_layout, axis, device):
+def test_scan_layouts(M, N, src_layout, axis, device, tmp_path):
 
     ir = f"""
     #blocked = {src_layout}
@@ -2622,10 +2621,10 @@ def test_scan_layouts(M, N, src_layout, axis, device):
     }}
     """
 
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_scan_layouts.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
+
     rs = RandomState(17)
     x = rs.randint(-100, 100, (M, N)).astype('int32')
 
@@ -2662,7 +2661,7 @@ def test_scan_layouts(M, N, src_layout, axis, device):
 @pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d'])
 @pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"])
 @pytest.mark.parametrize("reduce_op", ["sum", "max"])
-def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device):
+def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device, tmp_path):
     if isinstance(src_layout,
                   (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]):
         pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
@@ -2756,10 +2755,9 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
     }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>
     """ + epilogue
 
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_reduce_layouts.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
 
     rs = RandomState(17)
     x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
@@ -2789,7 +2787,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
 
 @pytest.mark.parametrize("M", [32, 64, 128, 256])
 @pytest.mark.parametrize("src_layout", layouts)
-def test_store_op(M, src_layout, device):
+def test_store_op(M, src_layout, device, tmp_path):
 
     ir = f"""
     #src = {src_layout}
@@ -2810,10 +2808,9 @@ def test_store_op(M, src_layout, device):
     }}
     """
 
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        store_kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_store_op.ttgir"
+    temp_file.write_text(ir)
+    store_kernel = triton.compile(str(temp_file))
 
     rs = RandomState(17)
     x = rs.randint(0, 4, (M, 1)).astype('float32')
@@ -2840,7 +2837,7 @@ def test_store_op(M, src_layout, device):
 @pytest.mark.parametrize("dst_layout", filter_layouts(layouts))
 @pytest.mark.parametrize("src_dim", [0, 1])
 @pytest.mark.parametrize("dst_dim", [0, 1])
-def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
+def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path):
 
     ir = f"""
     #dst = {dst_layout}
@@ -2860,10 +2857,9 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
     }}
     }}
     """
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_convert1d.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
 
     rs = RandomState(17)
     x = rs.randint(0, 4, (M, )).astype('int32')
@@ -2901,7 +2897,7 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
 @pytest.mark.parametrize("src_layout", layouts)
 @pytest.mark.parametrize("op", ["sum", "max"])
 @pytest.mark.parametrize("first_axis", [0, 1])
-def test_chain_reduce(M, N, src_layout, op, device, first_axis):
+def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path):
 
     op_str = ""
     if op == "sum":
@@ -2942,10 +2938,9 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
     }}
     }}
     """
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_chain_reduce.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
 
     rs = RandomState(17)
     x = rs.randint(0, 4, (M, N)).astype('int32')
@@ -5260,7 +5255,7 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape):
 @pytest.mark.parametrize("src_layout", layouts)
 @pytest.mark.parametrize("interm_layout", intermediate_layouts)
 @pytest.mark.parametrize("dst_layout", layouts)
-def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
+def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path):
     if str(src_layout) == str(dst_layout):
         pytest.xfail("Do not convert same layout")
     if is_hip() or is_xpu():
@@ -5329,10 +5324,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
     x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
     z = torch.empty_like(x, device=device)
 
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_convert2d.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
+
     kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())
 
     assert torch.equal(z, x)
@@ -5385,7 +5380,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
 @pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
 @pytest.mark.parametrize("dtype", ['float16'])
 @pytest.mark.parametrize("mma_pair", mma_pairs)
-def test_convertmma2mma(M, N, mma_pair, dtype, device):
+def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path):
     if is_hip() or is_xpu():
         pytest.xfail("test_mma2mma is not supported in HIP/XPU")
 
@@ -5442,10 +5437,10 @@ def do_test(src_layout, dst_layout):
         x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
         z = torch.empty_like(x)
 
-        with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-            f.write(ir)
-            f.flush()
-            kernel = triton.compile(f.name)
+        temp_file = tmp_path / "test_convertmma2mma.ttgir"
+        temp_file.write_text(ir)
+        kernel = triton.compile(str(temp_file))
+
         kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())
 
         assert torch.equal(z, x)
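
Every call site above follows the same shape: build the .ttgir IR string, write it with Path.write_text(), and hand triton.compile() the path as a str. A standalone sketch of that shape (write_ir is a hypothetical helper name; assuming, as these tests do, that triton.compile accepts a path string to a .ttgir file):

    from pathlib import Path

    def write_ir(tmp_dir: Path, name: str, ir: str) -> str:
        # write_text() opens, writes, and closes the file, so the compiler
        # can immediately re-open it by name on any platform.
        temp_file = tmp_dir / name
        temp_file.write_text(ir)
        return str(temp_file)  # str(): these call sites pass a plain path string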

python/test/unit/runtime/test_cache.py

Lines changed: 10 additions & 11 deletions
@@ -1,7 +1,6 @@
 import importlib.util
 import itertools
 import shutil
-import tempfile
 
 import pytest
 import torch
@@ -128,28 +127,28 @@ def test_combine_fn_change():
         seen_keys.add(key)
 
 
-def write_and_load_module(code, num_extra_lines):
-    with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f:
-        f.write(('# extra line\n' * num_extra_lines) + code)
-        f.flush()
-        spec = importlib.util.spec_from_file_location("module.name", f.name)
-        module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(module)
+def write_and_load_module(temp_file, code, num_extra_lines):
+    temp_file.write_text(('# extra line\n' * num_extra_lines) + code)
+    spec = importlib.util.spec_from_file_location("module.name", str(temp_file))
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
     return module
 
 
-def test_changed_line_numbers_invalidate_cache():
+def test_changed_line_numbers_invalidate_cache(tmp_path):
     from textwrap import dedent
     code = dedent("""
         import triton
         @triton.jit
         def test_kernel(i):
             i = i + 1
     """)
-    orig_mod = write_and_load_module(code, 0)
+    temp_file0 = tmp_path / "test_changed_line_numbers_invalidate_cache0.py"
+    orig_mod = write_and_load_module(temp_file0, code, 0)
     orig_cache_key = orig_mod.test_kernel.cache_key
 
-    updated_mod = write_and_load_module(code, 1)
+    temp_file1 = tmp_path / "test_changed_line_numbers_invalidate_cache1.py"
+    updated_mod = write_and_load_module(temp_file1, code, 1)
     updated_cache_key = updated_mod.test_kernel.cache_key
     assert orig_cache_key != updated_cache_key
 
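The refactored write_and_load_module() keeps the standard importlib recipe for executing a module straight from a file path; for reference, a self-contained version (load_module_from_path is a hypothetical name):

    import importlib.util
    from pathlib import Path

    def load_module_from_path(path: Path):
        # Build an import spec from the file's location, materialize a
        # module object from the spec, then execute the file's code in it.
        spec = importlib.util.spec_from_file_location("module.name", str(path))
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module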