@@ -23,12 +23,12 @@ def is_xpu():
2323
2424
2525@pytest .mark .parametrize ("context" , ["shadow" , "python" ])
26- def test_torch (context , tmp_path : pathlib .Path ):
26+ def test_torch (context , tmp_path : pathlib .Path , device : str ):
2727 temp_file = tmp_path / "test_torch.hatchet"
2828 proton .start (str (temp_file .with_suffix ("" )), context = context )
2929 proton .enter_scope ("test" )
3030 # F841 Local variable `temp` is assigned to but never used
31- temp = torch .ones ((2 , 2 ), device = "xpu" ) # noqa: F841
31+ temp = torch .ones ((2 , 2 ), device = device ) # noqa: F841
3232 proton .exit_scope ()
3333 proton .finalize ()
3434 with temp_file .open () as f :
@@ -55,13 +55,13 @@ def test_torch(context, tmp_path: pathlib.Path):
5555 queue .append (child )
5656
5757
58- def test_triton (tmp_path : pathlib .Path ):
58+ def test_triton (tmp_path : pathlib .Path , device : str ):
5959
6060 @triton .jit
6161 def foo (x , y ):
6262 tl .store (y , tl .load (x ))
6363
64- x = torch .tensor ([2 ], device = "xpu" )
64+ x = torch .tensor ([2 ], device = device )
6565 y = torch .zeros_like (x )
6666 temp_file = tmp_path / "test_triton.hatchet"
6767 proton .start (str (temp_file .with_suffix ("" )))
@@ -80,7 +80,7 @@ def foo(x, y):
8080 assert data [0 ]["children" ][1 ]["frame" ]["name" ] == "test2"
8181
8282
83- def test_cudagraph (tmp_path : pathlib .Path ):
83+ def test_cudagraph (tmp_path : pathlib .Path , device : str ):
8484 if is_xpu ():
8585 pytest .skip ("xpu doesn't support cudagraph; FIXME: double check" )
8686 stream = torch .cuda .Stream ()
@@ -91,8 +91,8 @@ def foo(x, y, z):
9191 tl .store (z , tl .load (y ) + tl .load (x ))
9292
9393 def fn ():
94- a = torch .ones ((2 , 2 ), device = "xpu" )
95- b = torch .ones ((2 , 2 ), device = "xpu" )
94+ a = torch .ones ((2 , 2 ), device = device )
95+ b = torch .ones ((2 , 2 ), device = device )
9696 c = a + b
9797 foo [(1 , )](a , b , c )
9898
@@ -136,13 +136,13 @@ def fn():
136136 assert test_frame ["children" ][0 ]["metrics" ]["time (ns)" ] > 0
137137
138138
139- def test_metrics (tmp_path : pathlib .Path ):
139+ def test_metrics (tmp_path : pathlib .Path , device : str ):
140140
141141 @triton .jit
142142 def foo (x , y ):
143143 tl .store (y , tl .load (x ))
144144
145- x = torch .tensor ([2 ], device = "xpu" )
145+ x = torch .tensor ([2 ], device = device )
146146 y = torch .zeros_like (x )
147147 temp_file = tmp_path / "test_metrics.hatchet"
148148 proton .start (str (temp_file .with_suffix ("" )))
@@ -156,11 +156,11 @@ def foo(x, y):
156156 assert data [0 ]["children" ][0 ]["metrics" ]["foo" ] == 1.0
157157
158158
159- def test_scope_backward (tmp_path : pathlib .Path ):
159+ def test_scope_backward (tmp_path : pathlib .Path , device : str ):
160160 temp_file = tmp_path / "test_scope_backward.hatchet"
161161 proton .start (str (temp_file .with_suffix ("" )))
162162 with proton .scope ("ones1" ):
163- a = torch .ones ((100 , 100 ), device = "xpu" , requires_grad = True )
163+ a = torch .ones ((100 , 100 ), device = device , requires_grad = True )
164164 with proton .scope ("plus" ):
165165 a2 = a * a * a
166166 with proton .scope ("ones2" ):
@@ -175,12 +175,12 @@ def test_scope_backward(tmp_path: pathlib.Path):
175175 assert len (data [0 ]["children" ]) == 4
176176
177177
178- def test_cpu_timed_scope (tmp_path : pathlib .Path ):
178+ def test_cpu_timed_scope (tmp_path : pathlib .Path , device : str ):
179179 temp_file = tmp_path / "test_cpu_timed_scope.hatchet"
180180 proton .start (str (temp_file .with_suffix ("" )))
181181 with proton .cpu_timed_scope ("test0" ):
182182 with proton .cpu_timed_scope ("test1" ):
183- torch .ones ((100 , 100 ), device = "xpu" )
183+ torch .ones ((100 , 100 ), device = device )
184184 proton .finalize ()
185185 with temp_file .open () as f :
186186 data = json .load (f )
@@ -193,7 +193,7 @@ def test_cpu_timed_scope(tmp_path: pathlib.Path):
193193 assert kernel_frame ["metrics" ]["time (ns)" ] > 0
194194
195195
196- def test_hook_launch (tmp_path : pathlib .Path ):
196+ def test_hook_launch (tmp_path : pathlib .Path , device : str ):
197197
198198 def metadata_fn (grid : tuple , metadata : NamedTuple , args : dict ):
199199 # get arg's element size
@@ -208,7 +208,7 @@ def foo(x, size: tl.constexpr, y):
208208 offs = tl .arange (0 , size )
209209 tl .store (y + offs , tl .load (x + offs ))
210210
211- x = torch .tensor ([2 ], device = "xpu" , dtype = torch .float32 )
211+ x = torch .tensor ([2 ], device = device , dtype = torch .float32 )
212212 y = torch .zeros_like (x )
213213 temp_file = tmp_path / "test_hook_triton.hatchet"
214214 proton .start (str (temp_file .with_suffix ("" )), hook = "triton" )
@@ -225,7 +225,7 @@ def foo(x, size: tl.constexpr, y):
225225
226226
227227@pytest .mark .parametrize ("context" , ["shadow" , "python" ])
228- def test_hook_launch_context (tmp_path : pathlib .Path , context : str ):
228+ def test_hook_launch_context (tmp_path : pathlib .Path , context : str , device : str ):
229229
230230 def metadata_fn (grid : tuple , metadata : NamedTuple , args : dict ):
231231 x = args ["x" ]
@@ -237,7 +237,7 @@ def foo(x, size: tl.constexpr, y):
237237 offs = tl .arange (0 , size )
238238 tl .store (y + offs , tl .load (x + offs ))
239239
240- x = torch .tensor ([2 ], device = "xpu" , dtype = torch .float32 )
240+ x = torch .tensor ([2 ], device = device , dtype = torch .float32 )
241241 y = torch .zeros_like (x )
242242 temp_file = tmp_path / "test_hook.hatchet"
243243 proton .start (str (temp_file .with_suffix ("" )), hook = "triton" , context = context )
@@ -257,7 +257,7 @@ def foo(x, size: tl.constexpr, y):
257257 queue .append (child )
258258
259259
260- def test_hook_with_third_party (tmp_path : pathlib .Path ):
260+ def test_hook_with_third_party (tmp_path : pathlib .Path , device : str ):
261261 third_party_hook_invoked = False
262262
263263 def third_party_hook (metadata ) -> None :
@@ -278,7 +278,7 @@ def foo(x, size: tl.constexpr, y):
278278 offs = tl .arange (0 , size )
279279 tl .store (y + offs , tl .load (x + offs ))
280280
281- x = torch .tensor ([2 ], device = "xpu" , dtype = torch .float32 )
281+ x = torch .tensor ([2 ], device = device , dtype = torch .float32 )
282282 y = torch .zeros_like (x )
283283 temp_file = tmp_path / "test_hook_with_third_party.hatchet"
284284 proton .start (str (temp_file .with_suffix ("" )), hook = "triton" )
@@ -292,7 +292,7 @@ def foo(x, size: tl.constexpr, y):
292292 assert data [0 ]["children" ][0 ]["metrics" ]["time (ns)" ] > 0
293293
294294
295- def test_hook_multiple_threads (tmp_path : pathlib .Path ):
295+ def test_hook_multiple_threads (tmp_path : pathlib .Path , device : str ):
296296
297297 def metadata_fn_foo (grid : tuple , metadata : NamedTuple , args : dict ):
298298 return {"name" : "foo_test" }
@@ -310,9 +310,9 @@ def bar(x, size: tl.constexpr, y):
310310 offs = tl .arange (0 , size )
311311 tl .store (y + offs , tl .load (x + offs ))
312312
313- x_foo = torch .tensor ([2 ], device = "xpu" , dtype = torch .float32 )
313+ x_foo = torch .tensor ([2 ], device = device , dtype = torch .float32 )
314314 y_foo = torch .zeros_like (x_foo )
315- x_bar = torch .tensor ([2 ], device = "xpu" , dtype = torch .float32 )
315+ x_bar = torch .tensor ([2 ], device = device , dtype = torch .float32 )
316316 y_bar = torch .zeros_like (x_bar )
317317
318318 temp_file = tmp_path / "test_hook.hatchet"
@@ -350,7 +350,7 @@ def invoke_bar():
350350 assert root [1 ]["metrics" ]["count" ] == 100
351351
352352
353- def test_pcsampling (tmp_path : pathlib .Path ):
353+ def test_pcsampling (tmp_path : pathlib .Path , device : str ):
354354 if is_hip ():
355355 pytest .skip ("HIP backend does not support pc sampling" )
356356 if is_xpu ():
@@ -370,7 +370,7 @@ def foo(x, y, size: tl.constexpr):
370370 temp_file = tmp_path / "test_pcsampling.hatchet"
371371 proton .start (str (temp_file .with_suffix ("" )), hook = "triton" , backend = "cupti" , mode = "pcsampling" )
372372 with proton .scope ("init" ):
373- x = torch .ones ((1024 , ), device = "xpu" , dtype = torch .float32 )
373+ x = torch .ones ((1024 , ), device = device , dtype = torch .float32 )
374374 y = torch .zeros_like (x )
375375 with proton .scope ("test" ):
376376 foo [(1 , )](x , y , x .size ()[0 ], num_warps = 4 )
@@ -388,13 +388,13 @@ def foo(x, y, size: tl.constexpr):
388388 assert init_frame ["children" ][0 ]["metrics" ]["num_samples" ] > 0
389389
390390
391- def test_deactivate (tmp_path : pathlib .Path ):
391+ def test_deactivate (tmp_path : pathlib .Path , device : str ):
392392 temp_file = tmp_path / "test_deactivate.hatchet"
393393 session_id = proton .start (str (temp_file .with_suffix ("" )), hook = "triton" )
394394 proton .deactivate (session_id )
395- torch .randn ((10 , 10 ), device = "xpu" )
395+ torch .randn ((10 , 10 ), device = device )
396396 proton .activate (session_id )
397- torch .zeros ((10 , 10 ), device = "xpu" )
397+ torch .zeros ((10 , 10 ), device = device )
398398 proton .deactivate (session_id )
399399 proton .finalize ()
400400 with temp_file .open () as f :
@@ -405,18 +405,18 @@ def test_deactivate(tmp_path: pathlib.Path):
405405 assert "device_id" in data [0 ]["children" ][0 ]["metrics" ]
406406
407407
408- def test_multiple_sessions (tmp_path : pathlib .Path ):
408+ def test_multiple_sessions (tmp_path : pathlib .Path , device : str ):
409409 temp_file0 = tmp_path / "test_multiple_sessions0.hatchet"
410410 temp_file1 = tmp_path / "test_multiple_sessions1.hatchet"
411411 session_id0 = proton .start (str (temp_file0 .with_suffix ("" )))
412412 session_id1 = proton .start (str (temp_file1 .with_suffix ("" )))
413413 with proton .scope ("scope0" ):
414- torch .randn ((10 , 10 ), device = "xpu" )
415- torch .randn ((10 , 10 ), device = "xpu" )
414+ torch .randn ((10 , 10 ), device = device )
415+ torch .randn ((10 , 10 ), device = device )
416416 proton .deactivate (session_id0 )
417417 proton .finalize (session_id0 )
418418 with proton .scope ("scope1" ):
419- torch .randn ((10 , 10 ), device = "xpu" )
419+ torch .randn ((10 , 10 ), device = device )
420420 proton .finalize (session_id1 )
421421 # kernel has been invoked twice in session 0 and three times in session 1
422422 with temp_file0 .open () as f :
@@ -430,7 +430,7 @@ def test_multiple_sessions(tmp_path: pathlib.Path):
430430 assert scope0_count + scope1_count == 3
431431
432432
433- def test_trace (tmp_path : pathlib .Path ):
433+ def test_trace (tmp_path : pathlib .Path , device : str ):
434434 temp_file = tmp_path / "test_trace.chrome_trace"
435435 proton .start (str (temp_file .with_suffix ("" )), data = "trace" )
436436
@@ -440,7 +440,7 @@ def foo(x, y, size: tl.constexpr):
440440 tl .store (y + offs , tl .load (x + offs ))
441441
442442 with proton .scope ("init" ):
443- x = torch .ones ((1024 , ), device = "xpu" , dtype = torch .float32 )
443+ x = torch .ones ((1024 , ), device = device , dtype = torch .float32 )
444444 y = torch .zeros_like (x )
445445
446446 with proton .scope ("test" ):
@@ -456,7 +456,7 @@ def foo(x, y, size: tl.constexpr):
456456 assert trace_events [- 1 ]["args" ]["call_stack" ] == ["ROOT" , "test" , "foo" ]
457457
458458
459- def test_scope_multiple_threads (tmp_path : pathlib .Path ):
459+ def test_scope_multiple_threads (tmp_path : pathlib .Path , device : str ):
460460 temp_file = tmp_path / "test_scope_threads.hatchet"
461461 proton .start (str (temp_file .with_suffix ("" )))
462462
@@ -467,7 +467,7 @@ def worker(prefix: str):
467467 for i in range (N ):
468468 name = f"{ prefix } _{ i } "
469469 proton .enter_scope (name )
470- torch .ones ((1 , ), device = "xpu" )
470+ torch .ones ((1 , ), device = device )
471471 proton .exit_scope ()
472472
473473 threads = [threading .Thread (target = worker , args = (tname , )) for tname in thread_names ]
0 commit comments