@@ -6487,7 +6487,7 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
 @pytest.mark.parametrize("M, N, M_tile_size, N_tile_size",
                          [[128, 128, 64, 64], [128, 128, 64, 32], [128, 64, 64, 32], [256, 128, 64, 64]])
-def test_split_subview(M, N, M_tile_size, N_tile_size, device):
+def test_split_subview(M, N, M_tile_size, N_tile_size, device, tmp_path: pathlib.Path):
     num_rows_per_warp = THREADS_PER_WARP // 4
     num_repeats_M = triton.cdiv(M, M_tile_size)
     num_repeats_N = triton.cdiv(N, N_tile_size)
@@ -6546,11 +6546,9 @@ def test_split_subview(M, N, M_tile_size, N_tile_size, device):
     }}
     """
 
-    import tempfile
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
-        f.write(ir)
-        f.flush()
-        kernel = triton.compile(f.name)
+    temp_file = tmp_path / "test_split_subview.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
 
     triton_result = torch.zeros((M, N), device=device, dtype=torch.float16)
     kernel[(1, 1, 1)](triton_result.data_ptr())
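For context, the change swaps `tempfile.NamedTemporaryFile` for pytest's built-in `tmp_path` fixture, which injects a unique per-test `pathlib.Path` directory that pytest cleans up itself. A minimal sketch of that pattern follows; the test name, file name, and IR contents are illustrative, not taken from the suite:

# Sketch of the tmp_path pattern adopted in the diff above.
import pathlib


def test_compile_from_tmp_path(tmp_path: pathlib.Path):
    # pytest provides tmp_path as a fresh directory for this test run;
    # no context manager or manual cleanup is needed.
    ir_file = tmp_path / "example.ttgir"   # illustrative file name
    ir_file.write_text("...")              # placeholder for generated IR
    # write_text closes the file, so it can be reopened by path afterwards.
    assert ir_file.read_text() == "..."

A likely motivation for the switch: a file held open by `NamedTemporaryFile` cannot be reopened by name on Windows, so passing `f.name` to a compiler while the handle is still open is not portable, whereas the `tmp_path` file is already closed when `triton.compile` reads it.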