55from typing import Optional
66import math
77import textwrap
8- import tempfile
98
109import numpy as np
1110import pytest
@@ -2589,7 +2588,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
25892588@pytest .mark .parametrize ("M, N" , [[32 , 16 ], [32 , 32 ], [32 , 64 ], [64 , 32 ]])
25902589@pytest .mark .parametrize ("src_layout" , scan_layouts )
25912590@pytest .mark .parametrize ("axis" , [0 , 1 ])
2592- def test_scan_layouts (M , N , src_layout , axis , device ):
2591+ def test_scan_layouts (M , N , src_layout , axis , device , tmp_path ):
25932592
25942593 ir = f"""
25952594 #blocked = { src_layout }
@@ -2622,10 +2621,10 @@ def test_scan_layouts(M, N, src_layout, axis, device):
26222621 }}
26232622 """
26242623
2625- with tempfile . NamedTemporaryFile ( mode = 'w' , suffix = '.ttgir' ) as f :
2626- f . write (ir )
2627- f . flush ( )
2628- kernel = triton . compile ( f . name )
2624+ temp_file = tmp_path / "test_scan_layouts.ttgir"
2625+ temp_file . write_text (ir )
2626+ kernel = triton . compile ( str ( temp_file ) )
2627+
26292628 rs = RandomState (17 )
26302629 x = rs .randint (- 100 , 100 , (M , N )).astype ('int32' )
26312630
@@ -2662,7 +2661,7 @@ def test_scan_layouts(M, N, src_layout, axis, device):
26622661@pytest .mark .parametrize ("epilogue_kind" , ['reduce1d' , 'reduce2d' , 'expand_reduce2d' ])
26632662@pytest .mark .parametrize ("dtype_str" , ["int32" , "float32" , "float16" ])
26642663@pytest .mark .parametrize ("reduce_op" , ["sum" , "max" ])
2665- def test_reduce_layouts (M , N , src_layout , axis , epilogue_kind , dtype_str , reduce_op , device ):
2664+ def test_reduce_layouts (M , N , src_layout , axis , epilogue_kind , dtype_str , reduce_op , device , tmp_path ):
26662665 if isinstance (src_layout ,
26672666 (MfmaLayout , MmaLayout )) and (M < src_layout .instr_shape [0 ] or N < src_layout .instr_shape [1 ]):
26682667 pytest .skip ("Skipping because tensor shape is smaller than M(f)maLayout instr_shape" )
@@ -2756,10 +2755,9 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
27562755 }}) {{axis = { axis } : i32}} : (tensor<{ M } x{ N } x{ ty } , #src>) -> tensor<{ rdims_1d } x{ ty } , #{ GPU_DIALECT } .slice<{{dim = { axis } , parent = #src}}>>
27572756 """ + epilogue
27582757
2759- with tempfile .NamedTemporaryFile (mode = 'w' , suffix = '.ttgir' ) as f :
2760- f .write (ir )
2761- f .flush ()
2762- kernel = triton .compile (f .name )
2758+ temp_file = tmp_path / "test_reduce_layouts.ttgir"
2759+ temp_file .write_text (ir )
2760+ kernel = triton .compile (str (temp_file ))
27632761
27642762 rs = RandomState (17 )
27652763 x = numpy_random ((M , N ), dtype_str = dtype_str , rs = rs , low = 0 , high = 10 )
@@ -2789,7 +2787,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
27892787
27902788@pytest .mark .parametrize ("M" , [32 , 64 , 128 , 256 ])
27912789@pytest .mark .parametrize ("src_layout" , layouts )
2792- def test_store_op (M , src_layout , device ):
2790+ def test_store_op (M , src_layout , device , tmp_path ):
27932791
27942792 ir = f"""
27952793 #src = { src_layout }
@@ -2810,10 +2808,9 @@ def test_store_op(M, src_layout, device):
28102808 }}
28112809 """
28122810
2813- with tempfile .NamedTemporaryFile (mode = 'w' , suffix = '.ttgir' ) as f :
2814- f .write (ir )
2815- f .flush ()
2816- store_kernel = triton .compile (f .name )
2811+ temp_file = tmp_path / "test_store_op.ttgir"
2812+ temp_file .write_text (ir )
2813+ store_kernel = triton .compile (str (temp_file ))
28172814
28182815 rs = RandomState (17 )
28192816 x = rs .randint (0 , 4 , (M , 1 )).astype ('float32' )
@@ -2840,7 +2837,7 @@ def test_store_op(M, src_layout, device):
28402837@pytest .mark .parametrize ("dst_layout" , filter_layouts (layouts ))
28412838@pytest .mark .parametrize ("src_dim" , [0 , 1 ])
28422839@pytest .mark .parametrize ("dst_dim" , [0 , 1 ])
2843- def test_convert1d (M , src_layout , dst_layout , src_dim , dst_dim , device ):
2840+ def test_convert1d (M , src_layout , dst_layout , src_dim , dst_dim , device , tmp_path ):
28442841
28452842 ir = f"""
28462843 #dst = { dst_layout }
@@ -2860,10 +2857,9 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
28602857 }}
28612858 }}
28622859 """
2863- with tempfile .NamedTemporaryFile (mode = 'w' , suffix = '.ttgir' ) as f :
2864- f .write (ir )
2865- f .flush ()
2866- kernel = triton .compile (f .name )
2860+ temp_file = tmp_path / "test_convert1d.ttgir"
2861+ temp_file .write_text (ir )
2862+ kernel = triton .compile (str (temp_file ))
28672863
28682864 rs = RandomState (17 )
28692865 x = rs .randint (0 , 4 , (M , )).astype ('int32' )
@@ -2901,7 +2897,7 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
29012897@pytest .mark .parametrize ("src_layout" , layouts )
29022898@pytest .mark .parametrize ("op" , ["sum" , "max" ])
29032899@pytest .mark .parametrize ("first_axis" , [0 , 1 ])
2904- def test_chain_reduce (M , N , src_layout , op , device , first_axis ):
2900+ def test_chain_reduce (M , N , src_layout , op , device , first_axis , tmp_path ):
29052901
29062902 op_str = ""
29072903 if op == "sum" :
@@ -2942,10 +2938,9 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
29422938 }}
29432939 }}
29442940 """
2945- with tempfile .NamedTemporaryFile (mode = 'w' , suffix = '.ttgir' ) as f :
2946- f .write (ir )
2947- f .flush ()
2948- kernel = triton .compile (f .name )
2941+ temp_file = tmp_path / "test_chain_reduce.ttgir"
2942+ temp_file .write_text (ir )
2943+ kernel = triton .compile (str (temp_file ))
29492944
29502945 rs = RandomState (17 )
29512946 x = rs .randint (0 , 4 , (M , N )).astype ('int32' )
@@ -5260,7 +5255,7 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape):
52605255@pytest .mark .parametrize ("src_layout" , layouts )
52615256@pytest .mark .parametrize ("interm_layout" , intermediate_layouts )
52625257@pytest .mark .parametrize ("dst_layout" , layouts )
5263- def test_convert2d (M , N , src_layout , interm_layout , dst_layout , dtype , device ):
5258+ def test_convert2d (M , N , src_layout , interm_layout , dst_layout , dtype , device , tmp_path ):
52645259 if str (src_layout ) == str (dst_layout ):
52655260 pytest .xfail ("Do not convert same layout" )
52665261 if is_hip () or is_xpu ():
@@ -5329,10 +5324,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
53295324 x = to_triton (numpy_random ((M , N ), dtype_str = dtype ), device = device )
53305325 z = torch .empty_like (x , device = device )
53315326
5332- with tempfile . NamedTemporaryFile ( mode = 'w' , suffix = '.ttgir' ) as f :
5333- f . write (ir )
5334- f . flush ( )
5335- kernel = triton . compile ( f . name )
5327+ temp_file = tmp_path / "test_convert2d.ttgir"
5328+ temp_file . write_text (ir )
5329+ kernel = triton . compile ( str ( temp_file ) )
5330+
53365331 kernel [(1 , 1 , 1 )](x .data_ptr (), z .data_ptr ())
53375332
53385333 assert torch .equal (z , x )
@@ -5385,7 +5380,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
53855380@pytest .mark .parametrize ("M, N" , [[64 , 1 ], [1 , 64 ], [64 , 64 ], [128 , 128 ], [256 , 256 ]])
53865381@pytest .mark .parametrize ("dtype" , ['float16' ])
53875382@pytest .mark .parametrize ("mma_pair" , mma_pairs )
5388- def test_convertmma2mma (M , N , mma_pair , dtype , device ):
5383+ def test_convertmma2mma (M , N , mma_pair , dtype , device , tmp_path ):
53895384 if is_hip () or is_xpu ():
53905385 pytest .xfail ("test_mma2mma is not supported in HIP/XPU" )
53915386
@@ -5442,10 +5437,10 @@ def do_test(src_layout, dst_layout):
54425437 x = to_triton (numpy_random ((M , N ), dtype_str = dtype ), device = device )
54435438 z = torch .empty_like (x )
54445439
5445- with tempfile . NamedTemporaryFile ( mode = 'w' , suffix = '.ttgir' ) as f :
5446- f . write (ir )
5447- f . flush ( )
5448- kernel = triton . compile ( f . name )
5440+ temp_file = tmp_path / "test_convertmma2mma.ttgir"
5441+ temp_file . write_text (ir )
5442+ kernel = triton . compile ( str ( temp_file ) )
5443+
54495444 kernel [(1 , 1 , 1 )](x .data_ptr (), z .data_ptr ())
54505445
54515446 assert torch .equal (z , x )
0 commit comments