
Commit d7b915c

pchen7e2 authored and meta-codesync[bot] committed
[TLX] Exclude AMD tutorial kernel test if not on AMD GPU (#635)
Summary:
- Let pytest just grab and test all things under a folder directly, for dense output
- Skip the AMD test if not on an AMD GPU

`third_party/tlx/run_all.sh` now skips `third_party/tlx/tutorials/amd-gemm-pipelined.py` on an NV GPU, as tested locally:

```
% third_party/tlx/run_all.sh
Hello! (Facebook-only)
Need to build triton in this script? {y|n}n
Run all LITs? {y|n}n
Run core Triton python unit tests? {y|n}n
Run all TLX unit tests? {y|n}n
Run TLX tutorial kernels (correctness|performance|no)? {c|p|n} c
Verifying correctness of TLX tutorial kernels
============================= test session starts ==============================
platform linux -- Python 3.11.13, pytest-8.3.4, pluggy-1.5.0
rootdir: /data/users/pchen7e4/triton
configfile: pyproject.toml
plugins: xdist-3.7.0, forked-1.6.0, typeguard-4.3.0
collected 17 items

third_party/tlx/tutorials/amd-gemm-pipelined.py s                        [  5%]
third_party/tlx/tutorials/blackwell-fa-ws-persistent_test.py .           [ 11%]
third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py . [ 17%]
third_party/tlx/tutorials/blackwell-fa-ws-pipelined_test.py .            [ 23%]
third_party/tlx/tutorials/blackwell-fa-ws_test.py .                      [ 29%]
third_party/tlx/tutorials/blackwell-gemm-clc.py .                        [ 35%]
third_party/tlx/tutorials/blackwell-gemm-pipelined.py .                  [ 41%]
third_party/tlx/tutorials/blackwell-gemm-ws.py .                         [ 47%]
third_party/tlx/tutorials/blackwell-grouped-gemm.py .                    [ 52%]
third_party/tlx/tutorials/hopper-fa-ws-pipelined-pingpong_test.py s      [ 58%]
third_party/tlx/tutorials/hopper-fa-ws-pipelined_test.py s               [ 64%]
third_party/tlx/tutorials/hopper-fa-ws_test.py s                         [ 70%]
third_party/tlx/tutorials/hopper-gemm-pipelined_test.py s                [ 76%]
third_party/tlx/tutorials/hopper-gemm-ws_test.py s                       [ 82%]
third_party/tlx/tutorials/hopper-persistent-gemm-ws-cooperative.py s     [ 88%]
third_party/tlx/tutorials/hopper-persistent-gemm-ws-pingpong.py s        [ 94%]
third_party/tlx/tutorials/vector-add2.py .                               [100%]

=============================== warnings summary ===============================
python/triton/runtime/autotuner.py:99
python/triton/runtime/autotuner.py:99
python/triton/runtime/autotuner.py:99
  /data/users/pchen7e4/triton/python/triton/runtime/autotuner.py:99: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See triton-lang/triton#4496 for details.
    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py::test_op[triton-fp16-bwd-128-1024-16-8]
  /data/users/pchen7e4/miniconda3/lib/python3.11/site-packages/torch/autograd/graph.py:824: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at /pytorch/aten/src/ATen/cuda/CublasHandlePool.cpp:181.)
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================= 9 passed, 8 skipped, 4 warnings in 8.85s =====================
```

Pull Request resolved: #635

Reviewed By: htyu

Differential Revision: D86236535

Pulled By: pchen7e2

fbshipit-source-id: d17e708c39172e01351ec599cb927738236fbf87
1 parent c48a183 commit d7b915c

File tree

2 files changed: +8 −6 lines


third_party/tlx/run_all.sh

Lines changed: 2 additions & 5 deletions
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-echo "Hello, $USER! (Facebook-only)"
+echo "Hello! (Facebook-only)"
 
 # Build
 ask() {
@@ -90,10 +90,7 @@ read user_choice
 case $user_choice in
   c)
     echo "Verifying correctness of TLX tutorial kernels"
-    for k in third_party/tlx/tutorials/*.py; do
-      echo "Running $k"
-      pytest $k
-    done
+    pytest third_party/tlx/tutorials/*.py
     ;;
   p)
     echo "Measuring performance of TLX tutorial kernels"
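The run_all.sh change swaps the per-file loop for a single glob invocation. A minimal sketch (scratch directory and file names are hypothetical) of why that yields one dense pytest session instead of many small ones: the shell expands the glob before pytest starts, so every file lands in one argument list and one collection pass.

```shell
# Create a scratch directory with two stand-in test files (hypothetical names).
tmp=$(mktemp -d)
touch "$tmp/a_test.py" "$tmp/b_test.py"

# Old form: one command (and hence one pytest session) per file.
for k in "$tmp"/*_test.py; do
    echo "would run: pytest $k"
done

# New form: a single command receives every file at once, so pytest prints
# one "collected N items" header and one combined pass/skip summary.
echo "would run: pytest" "$tmp"/*_test.py
```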

third_party/tlx/tutorials/amd-gemm-pipelined.py

Lines changed: 6 additions & 1 deletion
@@ -1,9 +1,10 @@
+import pytest
 import torch
 
 import triton
 import triton.language as tl
 import triton.language.extra.tlx as tlx
-from triton._internal_testing import is_cuda, is_hip_cdna2
+from triton._internal_testing import is_cuda, is_hip_cdna2, is_hip
 
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
@@ -238,6 +239,10 @@ def matmul(a, b):
     return c
 
 
+@pytest.mark.skipif(
+    not is_hip(),
+    reason="Requires AMD GPU",
+)
 def test_op():
     torch.manual_seed(0)
     a = torch.randn((8192, 8192), device=DEVICE, dtype=torch.float16)
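The new guard is the standard `pytest.mark.skipif` pattern: evaluate the hardware check at collection time and report the test as skipped (`s`) rather than failing at kernel-compile time. A self-contained sketch — `is_hip` here is a stand-in for `triton._internal_testing.is_hip`, and the test name is hypothetical:

```python
import pytest


def is_hip():
    """Stand-in for triton._internal_testing.is_hip (assumed behavior):
    True only for an AMD ROCm/HIP build of torch, False otherwise."""
    try:
        import torch
        return torch.version.hip is not None
    except ImportError:
        return False


# Same guard shape as the diff above: the condition is evaluated once at
# collection time, so on non-AMD machines the test shows up as skipped.
@pytest.mark.skipif(not is_hip(), reason="Requires AMD GPU")
def test_amd_only_kernel():
    pass
```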
