@@ -5,7 +5,7 @@
 from typing import Optional
 import math
 import textwrap
-import tempfile
+import pathlib
 
 import numpy as np
 import pytest
@@ -2558,7 +2558,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
 @pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]])
 @pytest.mark.parametrize("src_layout", scan_layouts)
 @pytest.mark.parametrize("axis", [0, 1])
-def test_scan_layouts(M, N, src_layout, axis, device):
+def test_scan_layouts(M, N, src_layout, axis, device, tmp_path: pathlib.Path):
 
     ir = f"""
     #blocked = {src_layout}
@@ -2591,10 +2591,10 @@ def test_scan_layouts(M, N, src_layout, axis, device):
     }}
     """
 
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_scan_layouts.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
+
     rs = RandomState(17)
     x = rs.randint(-100, 100, (M, N)).astype('int32')
 
@@ -2642,7 +2642,7 @@ def test_scan_layouts(M, N, src_layout, axis, device):
 @pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d'])
 @pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"])
 @pytest.mark.parametrize("reduce_op", ["sum", "max"])
-def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device):
+def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device, tmp_path: pathlib.Path):
     if isinstance(src_layout,
                   (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]):
         pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
@@ -2736,10 +2736,9 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
         }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>
     """ + epilogue
 
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_reduce_layouts.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
 
     rs = RandomState(17)
     x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
@@ -2769,7 +2768,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
 
 @pytest.mark.parametrize("M", [32, 64, 128, 256])
 @pytest.mark.parametrize("src_layout", layouts)
-def test_store_op(M, src_layout, device):
+def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):
 
     ir = f"""
     #src = {src_layout}
@@ -2790,10 +2789,9 @@ def test_store_op(M, src_layout, device):
     }}
     """
 
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        store_kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_store_op.ttgir"
+    temp_file.write_text(ir)
+    store_kernel = triton.compile(str(temp_file))
 
     rs = RandomState(17)
     x = rs.randint(0, 4, (M, 1)).astype('float32')
@@ -2820,7 +2818,7 @@ def test_store_op(M, src_layout, device):
 @pytest.mark.parametrize("dst_layout", filter_layouts(layouts))
 @pytest.mark.parametrize("src_dim", [0, 1])
 @pytest.mark.parametrize("dst_dim", [0, 1])
-def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
+def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp_path: pathlib.Path):
 
     ir = f"""
     #dst = {dst_layout}
@@ -2840,10 +2838,9 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
     }}
     }}
     """
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_convert1d.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
 
     rs = RandomState(17)
     x = rs.randint(0, 4, (M, )).astype('int32')
@@ -2881,7 +2878,7 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
 @pytest.mark.parametrize("src_layout", layouts)
 @pytest.mark.parametrize("op", ["sum", "max"])
 @pytest.mark.parametrize("first_axis", [0, 1])
-def test_chain_reduce(M, N, src_layout, op, device, first_axis):
+def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathlib.Path):
 
     op_str = ""
     if op == "sum":
@@ -2922,10 +2919,9 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
     }}
     }}
     """
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_chain_reduce.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
 
     rs = RandomState(17)
     x = rs.randint(0, 4, (M, N)).astype('int32')
@@ -5241,7 +5237,7 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape):
 @pytest.mark.parametrize("src_layout", layouts)
 @pytest.mark.parametrize("interm_layout", intermediate_layouts)
 @pytest.mark.parametrize("dst_layout", layouts)
-def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
+def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path):
     if str(src_layout) == str(dst_layout):
         pytest.skip()
     if is_hip():
@@ -5306,10 +5302,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
     x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
     z = torch.empty_like(x, device=device)
 
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_convert2d.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
+
     kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())
 
     assert torch.equal(z, x)
@@ -5362,7 +5358,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
 @pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
 @pytest.mark.parametrize("dtype", ['float16'])
 @pytest.mark.parametrize("mma_pair", mma_pairs)
-def test_convertmma2mma(M, N, mma_pair, dtype, device):
+def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
     if is_hip():
         pytest.skip("test_mma2mma is not supported in HIP")
 
@@ -5419,10 +5415,10 @@ def do_test(src_layout, dst_layout):
         x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
         z = torch.empty_like(x)
 
-        with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-            f.write(ir)
-            f.flush()
-            kernel = triton.compile(f.name)
+        temp_file = tmp_path / "test_convertmma2mma.ttgir"
+        temp_file.write_text(ir)
+        kernel = triton.compile(str(temp_file))
+
         kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr())
 
         assert torch.equal(z, x)
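
Every hunk above applies the same replacement: instead of writing the generated TTGIR into a tempfile.NamedTemporaryFile (which requires an explicit flush() before the compiler can read it), each test writes into pytest's built-in tmp_path fixture, a per-test pathlib.Path directory that pytest cleans up on its own. Below is a minimal standalone sketch of that pattern, assuming pytest and triton are importable; the IR content, file name, and launch grid are placeholder assumptions, not taken from this diff.

import pathlib

import triton


def test_compile_from_tmp_path(tmp_path: pathlib.Path):
    # Placeholder: a real test would build a complete TTGIR module string here.
    ir = "..."
    # tmp_path is a unique per-test directory, so the file persists for the
    # whole test and no flush()/close bookkeeping is needed.
    temp_file = tmp_path / "kernel.ttgir"
    temp_file.write_text(ir)
    # triton.compile accepts a path to a .ttgir source file, as in the hunks above.
    kernel = triton.compile(str(temp_file))
    # kernel[(1, 1, 1)](...)  # launch with whatever grid and pointers the IR expects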