Commit 5ef8fc8

Add device pytest fixture for proton tests and use it
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent fe21682 · commit 5ef8fc8

3 files changed: +45 −36 lines


.github/workflows/build-test-reusable.yml

1 addition, 1 deletion

@@ -302,7 +302,7 @@ jobs:
   export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
   export TRITON_XPUPTI_LIB_PATH=$PTI_LIBS_DIR
   cd third_party/proton/test
-  pytest test_api.py test_lib.py test_profile.py test_viewer.py -s -v
+  pytest test_api.py test_lib.py test_profile.py test_viewer.py --device xpu -s -v
   cd ..

 - name: Run minicore tests
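A note on the flag: this reusable workflow runs the proton tests on Intel XPU runners (the PTI environment variables above set up XPU profiling), so the step passes --device xpu explicitly. The option defaults to "cuda" in the new conftest.py below, so CUDA runs need no extra flag.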

third_party/proton/test/conftest.py

9 additions, 0 deletions

@@ -1,6 +1,15 @@
 import pytest


+def pytest_addoption(parser):
+    parser.addoption("--device", action="store", default="cuda")
+
+
+@pytest.fixture
+def device(request):
+    return request.config.getoption("--device")
+
+
 @pytest.fixture
 def fresh_knobs():
     from triton._internal_testing import _fresh_knobs_impl
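Here pytest_addoption registers a --device command-line option defaulting to "cuda", and the device fixture simply returns whatever was passed, so any test that declares a device parameter receives the selected backend. A minimal sketch of a test consuming the fixture (hypothetical file and test names, not part of this commit):

    # test_device_example.py -- hypothetical illustration of the fixture
    import torch

    def test_allocates_on_selected_device(device):
        # pytest injects the value of --device here ("cuda" unless overridden)
        t = torch.ones((2, 2), device=device)
        assert t.device.type == device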

third_party/proton/test/test_profile.py

35 additions, 35 deletions

@@ -23,12 +23,12 @@ def is_xpu():


 @pytest.mark.parametrize("context", ["shadow", "python"])
-def test_torch(context, tmp_path: pathlib.Path):
+def test_torch(context, tmp_path: pathlib.Path, device: str):
     temp_file = tmp_path / "test_torch.hatchet"
     proton.start(str(temp_file.with_suffix("")), context=context)
     proton.enter_scope("test")
     # F841 Local variable `temp` is assigned to but never used
-    temp = torch.ones((2, 2), device="xpu")  # noqa: F841
+    temp = torch.ones((2, 2), device=device)  # noqa: F841
     proton.exit_scope()
     proton.finalize()
     with temp_file.open() as f:
@@ -55,13 +55,13 @@ def test_torch(context, tmp_path: pathlib.Path):
             queue.append(child)


-def test_triton(tmp_path: pathlib.Path):
+def test_triton(tmp_path: pathlib.Path, device: str):

     @triton.jit
     def foo(x, y):
         tl.store(y, tl.load(x))

-    x = torch.tensor([2], device="xpu")
+    x = torch.tensor([2], device=device)
     y = torch.zeros_like(x)
     temp_file = tmp_path / "test_triton.hatchet"
     proton.start(str(temp_file.with_suffix("")))
@@ -80,7 +80,7 @@ def foo(x, y):
     assert data[0]["children"][1]["frame"]["name"] == "test2"


-def test_cudagraph(tmp_path: pathlib.Path):
+def test_cudagraph(tmp_path: pathlib.Path, device: str):
     if is_xpu():
         pytest.skip("xpu doesn't support cudagraph; FIXME: double check")
     stream = torch.cuda.Stream()
@@ -91,8 +91,8 @@ def foo(x, y, z):
         tl.store(z, tl.load(y) + tl.load(x))

     def fn():
-        a = torch.ones((2, 2), device="xpu")
-        b = torch.ones((2, 2), device="xpu")
+        a = torch.ones((2, 2), device=device)
+        b = torch.ones((2, 2), device=device)
         c = a + b
         foo[(1, )](a, b, c)

@@ -136,13 +136,13 @@ def fn():
     assert test_frame["children"][0]["metrics"]["time (ns)"] > 0


-def test_metrics(tmp_path: pathlib.Path):
+def test_metrics(tmp_path: pathlib.Path, device: str):

     @triton.jit
     def foo(x, y):
         tl.store(y, tl.load(x))

-    x = torch.tensor([2], device="xpu")
+    x = torch.tensor([2], device=device)
     y = torch.zeros_like(x)
     temp_file = tmp_path / "test_metrics.hatchet"
     proton.start(str(temp_file.with_suffix("")))
@@ -156,11 +156,11 @@ def foo(x, y):
     assert data[0]["children"][0]["metrics"]["foo"] == 1.0


-def test_scope_backward(tmp_path: pathlib.Path):
+def test_scope_backward(tmp_path: pathlib.Path, device: str):
     temp_file = tmp_path / "test_scope_backward.hatchet"
     proton.start(str(temp_file.with_suffix("")))
     with proton.scope("ones1"):
-        a = torch.ones((100, 100), device="xpu", requires_grad=True)
+        a = torch.ones((100, 100), device=device, requires_grad=True)
     with proton.scope("plus"):
         a2 = a * a * a
     with proton.scope("ones2"):
@@ -175,12 +175,12 @@ def test_scope_backward(tmp_path: pathlib.Path):
     assert len(data[0]["children"]) == 4


-def test_cpu_timed_scope(tmp_path: pathlib.Path):
+def test_cpu_timed_scope(tmp_path: pathlib.Path, device: str):
     temp_file = tmp_path / "test_cpu_timed_scope.hatchet"
     proton.start(str(temp_file.with_suffix("")))
     with proton.cpu_timed_scope("test0"):
         with proton.cpu_timed_scope("test1"):
-            torch.ones((100, 100), device="xpu")
+            torch.ones((100, 100), device=device)
     proton.finalize()
     with temp_file.open() as f:
         data = json.load(f)
@@ -193,7 +193,7 @@ def test_cpu_timed_scope(tmp_path: pathlib.Path):
     assert kernel_frame["metrics"]["time (ns)"] > 0


-def test_hook_launch(tmp_path: pathlib.Path):
+def test_hook_launch(tmp_path: pathlib.Path, device: str):

     def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict):
         # get arg's element size
@@ -208,7 +208,7 @@ def foo(x, size: tl.constexpr, y):
         offs = tl.arange(0, size)
         tl.store(y + offs, tl.load(x + offs))

-    x = torch.tensor([2], device="xpu", dtype=torch.float32)
+    x = torch.tensor([2], device=device, dtype=torch.float32)
     y = torch.zeros_like(x)
     temp_file = tmp_path / "test_hook_triton.hatchet"
     proton.start(str(temp_file.with_suffix("")), hook="triton")
@@ -225,7 +225,7 @@ def foo(x, size: tl.constexpr, y):


 @pytest.mark.parametrize("context", ["shadow", "python"])
-def test_hook_launch_context(tmp_path: pathlib.Path, context: str):
+def test_hook_launch_context(tmp_path: pathlib.Path, context: str, device: str):

     def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict):
         x = args["x"]
@@ -237,7 +237,7 @@ def foo(x, size: tl.constexpr, y):
         offs = tl.arange(0, size)
         tl.store(y + offs, tl.load(x + offs))

-    x = torch.tensor([2], device="xpu", dtype=torch.float32)
+    x = torch.tensor([2], device=device, dtype=torch.float32)
     y = torch.zeros_like(x)
     temp_file = tmp_path / "test_hook.hatchet"
     proton.start(str(temp_file.with_suffix("")), hook="triton", context=context)
@@ -257,7 +257,7 @@ def foo(x, size: tl.constexpr, y):
             queue.append(child)


-def test_hook_with_third_party(tmp_path: pathlib.Path):
+def test_hook_with_third_party(tmp_path: pathlib.Path, device: str):
     third_party_hook_invoked = False

     def third_party_hook(metadata) -> None:
@@ -278,7 +278,7 @@ def foo(x, size: tl.constexpr, y):
         offs = tl.arange(0, size)
         tl.store(y + offs, tl.load(x + offs))

-    x = torch.tensor([2], device="xpu", dtype=torch.float32)
+    x = torch.tensor([2], device=device, dtype=torch.float32)
     y = torch.zeros_like(x)
     temp_file = tmp_path / "test_hook_with_third_party.hatchet"
     proton.start(str(temp_file.with_suffix("")), hook="triton")
@@ -292,7 +292,7 @@ def foo(x, size: tl.constexpr, y):
     assert data[0]["children"][0]["metrics"]["time (ns)"] > 0


-def test_hook_multiple_threads(tmp_path: pathlib.Path):
+def test_hook_multiple_threads(tmp_path: pathlib.Path, device: str):

     def metadata_fn_foo(grid: tuple, metadata: NamedTuple, args: dict):
         return {"name": "foo_test"}
@@ -310,9 +310,9 @@ def bar(x, size: tl.constexpr, y):
         offs = tl.arange(0, size)
         tl.store(y + offs, tl.load(x + offs))

-    x_foo = torch.tensor([2], device="xpu", dtype=torch.float32)
+    x_foo = torch.tensor([2], device=device, dtype=torch.float32)
     y_foo = torch.zeros_like(x_foo)
-    x_bar = torch.tensor([2], device="xpu", dtype=torch.float32)
+    x_bar = torch.tensor([2], device=device, dtype=torch.float32)
     y_bar = torch.zeros_like(x_bar)

     temp_file = tmp_path / "test_hook.hatchet"
@@ -350,7 +350,7 @@ def invoke_bar():
     assert root[1]["metrics"]["count"] == 100


-def test_pcsampling(tmp_path: pathlib.Path):
+def test_pcsampling(tmp_path: pathlib.Path, device: str):
     if is_hip():
         pytest.skip("HIP backend does not support pc sampling")
     if is_xpu():
@@ -370,7 +370,7 @@ def foo(x, y, size: tl.constexpr):
     temp_file = tmp_path / "test_pcsampling.hatchet"
     proton.start(str(temp_file.with_suffix("")), hook="triton", backend="cupti", mode="pcsampling")
     with proton.scope("init"):
-        x = torch.ones((1024, ), device="xpu", dtype=torch.float32)
+        x = torch.ones((1024, ), device=device, dtype=torch.float32)
         y = torch.zeros_like(x)
     with proton.scope("test"):
         foo[(1, )](x, y, x.size()[0], num_warps=4)
@@ -388,13 +388,13 @@ def foo(x, y, size: tl.constexpr):
     assert init_frame["children"][0]["metrics"]["num_samples"] > 0


-def test_deactivate(tmp_path: pathlib.Path):
+def test_deactivate(tmp_path: pathlib.Path, device: str):
     temp_file = tmp_path / "test_deactivate.hatchet"
     session_id = proton.start(str(temp_file.with_suffix("")), hook="triton")
     proton.deactivate(session_id)
-    torch.randn((10, 10), device="xpu")
+    torch.randn((10, 10), device=device)
     proton.activate(session_id)
-    torch.zeros((10, 10), device="xpu")
+    torch.zeros((10, 10), device=device)
     proton.deactivate(session_id)
     proton.finalize()
     with temp_file.open() as f:
@@ -405,18 +405,18 @@ def test_deactivate(tmp_path: pathlib.Path):
     assert "device_id" in data[0]["children"][0]["metrics"]


-def test_multiple_sessions(tmp_path: pathlib.Path):
+def test_multiple_sessions(tmp_path: pathlib.Path, device: str):
     temp_file0 = tmp_path / "test_multiple_sessions0.hatchet"
     temp_file1 = tmp_path / "test_multiple_sessions1.hatchet"
     session_id0 = proton.start(str(temp_file0.with_suffix("")))
     session_id1 = proton.start(str(temp_file1.with_suffix("")))
     with proton.scope("scope0"):
-        torch.randn((10, 10), device="xpu")
-        torch.randn((10, 10), device="xpu")
+        torch.randn((10, 10), device=device)
+        torch.randn((10, 10), device=device)
     proton.deactivate(session_id0)
     proton.finalize(session_id0)
     with proton.scope("scope1"):
-        torch.randn((10, 10), device="xpu")
+        torch.randn((10, 10), device=device)
     proton.finalize(session_id1)
     # kernel has been invoked twice in session 0 and three times in session 1
     with temp_file0.open() as f:
@@ -430,7 +430,7 @@ def test_multiple_sessions(tmp_path: pathlib.Path):
     assert scope0_count + scope1_count == 3


-def test_trace(tmp_path: pathlib.Path):
+def test_trace(tmp_path: pathlib.Path, device: str):
     temp_file = tmp_path / "test_trace.chrome_trace"
     proton.start(str(temp_file.with_suffix("")), data="trace")

@@ -440,7 +440,7 @@ def foo(x, y, size: tl.constexpr):
         tl.store(y + offs, tl.load(x + offs))

     with proton.scope("init"):
-        x = torch.ones((1024, ), device="xpu", dtype=torch.float32)
+        x = torch.ones((1024, ), device=device, dtype=torch.float32)
         y = torch.zeros_like(x)

     with proton.scope("test"):
@@ -456,7 +456,7 @@ def foo(x, y, size: tl.constexpr):
     assert trace_events[-1]["args"]["call_stack"] == ["ROOT", "test", "foo"]


-def test_scope_multiple_threads(tmp_path: pathlib.Path):
+def test_scope_multiple_threads(tmp_path: pathlib.Path, device: str):
     temp_file = tmp_path / "test_scope_threads.hatchet"
     proton.start(str(temp_file.with_suffix("")))

@@ -467,7 +467,7 @@ def worker(prefix: str):
         for i in range(N):
             name = f"{prefix}_{i}"
             proton.enter_scope(name)
-            torch.ones((1, ), device="xpu")
+            torch.ones((1, ), device=device)
             proton.exit_scope()

     threads = [threading.Thread(target=worker, args=(tname, )) for tname in thread_names]
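With these changes, the hard-coded device="xpu" strings are gone: every test accepts device: str and forwards it to torch, so the same suite targets either backend from the command line, e.g. pytest test_profile.py --device xpu -s -v on XPU, or plain pytest test_profile.py on CUDA.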
