Commit 5c35a83

Merge commit '37817d7773e419e89a955cdee17296d685df79b0'
2 parents: 614173b + 37817d7

19 files changed: +300 / -44 lines

.github/workflows/integration-tests.yml

Lines changed: 15 additions & 2 deletions
@@ -262,7 +262,7 @@ jobs:
            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
          fi
          cd python/test/unit
-         python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
+         python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py --ignore=test_address_sanitizer.py
          python3 -m pytest -s -n 8 language/test_subprocess.py
          python3 -m pytest -s -n 8 test_debug.py --forked
          # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
@@ -429,14 +429,27 @@ jobs:
          cd python/test/unit
          pytest --capture=tee-sys -rfs -n 12 language runtime \
                 --ignore=language/test_line_info.py \
-                --ignore=test_debug.py
+                --ignore=test_debug.py \
+                --ignore=test_address_sanitizer.py
          # TODO: uncomment
          # pytest --capture=tee-sys -rfs test_debug.py
          TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
          pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

          # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
          TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
+     - name: Run asan tests on HIP
+       run: |
+         cd python/test/unit
+         ulimit -s 1024
+         export PATH=$(find ~/.triton/llvm -name llvm-symbolizer -printf '%h\n'):$PATH
+         export LD_LIBRARY_PATH=$(find /opt -name libclang_rt.asan-x86_64.so -printf '%h\n'):$LD_LIBRARY_PATH
+         export LD_LIBRARY_PATH=$(find /opt -type d -wholename *lib/llvm/lib/asan):$LD_LIBRARY_PATH
+         export LD_LIBRARY_PATH=$(find /usr -name libcaffe2_nvrtc.so -printf '%h\n'):$LD_LIBRARY_PATH
+         export CLANG_ASAN_LIB=$(find /opt -name libclang_rt.asan-x86_64.so)
+         export HIP_ASAN_LIB=$(find /opt -wholename *lib/asan/libamdhip64.so)
+         ASAN_OPTIONS=detect_leaks=0,alloc_dealloc_mismatch=0 \
+         LD_PRELOAD=$CLANG_ASAN_LIB:$HIP_ASAN_LIB python3 -m pytest -s test_address_sanitizer.py
      - name: Run regression tests
        run: |
          # Reenable test_functional_regression.py once it's fixed

.github/workflows/integration-tests.yml.in

Lines changed: 15 additions & 4 deletions
@@ -300,7 +300,7 @@ jobs:
            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
          fi
          cd python/test/unit
-         python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
+         python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py --ignore=test_address_sanitizer.py
          python3 -m pytest -s -n 8 language/test_subprocess.py
          python3 -m pytest -s -n 8 test_debug.py --forked
          # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
@@ -309,7 +309,6 @@ jobs:
          python3 -m pytest -s hopper/test_flashattention.py
          TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
          python3 -m pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py
-
      - name: Run interpreter tests
        if: ${{ matrix.runner[0] == 'h100-runner-set' }}
        env:
@@ -416,15 +415,27 @@ jobs:
          cd python/test/unit
          pytest --capture=tee-sys -rfs -n 12 language runtime \
                 --ignore=language/test_line_info.py \
-                --ignore=test_debug.py
+                --ignore=test_debug.py \
+                --ignore=test_address_sanitizer.py
          # TODO: uncomment
          # pytest --capture=tee-sys -rfs test_debug.py
          TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
          pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

          # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
          TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
-
+     - name: Run asan tests on HIP
+       run: |
+         cd python/test/unit
+         ulimit -s 1024
+         export PATH=$(find ~/.triton/llvm -name llvm-symbolizer -printf '%h\n'):$PATH
+         export LD_LIBRARY_PATH=$(find /opt -name libclang_rt.asan-x86_64.so -printf '%h\n'):$LD_LIBRARY_PATH
+         export LD_LIBRARY_PATH=$(find /opt -type d -wholename *lib/llvm/lib/asan):$LD_LIBRARY_PATH
+         export LD_LIBRARY_PATH=$(find /usr -name libcaffe2_nvrtc.so -printf '%h\n'):$LD_LIBRARY_PATH
+         export CLANG_ASAN_LIB=$(find /opt -name libclang_rt.asan-x86_64.so)
+         export HIP_ASAN_LIB=$(find /opt -wholename *lib/asan/libamdhip64.so)
+         ASAN_OPTIONS=detect_leaks=0,alloc_dealloc_mismatch=0 \
+         LD_PRELOAD=$CLANG_ASAN_LIB:$HIP_ASAN_LIB python3 -m pytest -s test_address_sanitizer.py
      - name: Run regression tests
        run: |
          # Reenable test_functional_regression.py once it's fixed

README.md

Lines changed: 9 additions & 0 deletions
@@ -172,6 +172,15 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
   separated values to be specified (eg
   `TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions` or
   `TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions,regalloc"`).
+- `TRITON_ENABLE_ASAN=1` invokes the LLVM address sanitizer for
+  memory-leak and out-of-bounds access detection. Currently only supported on the AMD
+  backend. This must be run using the ASAN libraries documented [here](https://rocm.docs.amd.com/projects/llvm-project/en/latest/conceptual/using-gpu-sanitizer.html).
+
+  When enabling the address sanitizer, it is recommended to disable memory caching strategies
+  within both the ROCm stack and PyTorch. This gives the address sanitizer the best chance of finding the
+  memory fault where it originates. This can be done through the HSA_DISABLE_FRAGMENT_ALLOCATOR, AMD_PYTORCH_NO_CUDA_MEMORY_CACHING,
+  and PYTORCH_NO_HIP_MEMORY_CACHING environment variables.
+
 - `USE_IR_LOC={ttir,ttgir}` reparses the IR such that the location information
   will be the line number of the IR file with that particular extension,
   instead of line number of the python file. This can provide a direct mapping
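
To make the README guidance above concrete, here is a minimal sketch (not part of this commit) of driving a Triton workload under ASAN from Python. The script name my_triton_program.py is a placeholder, and depending on the installation the ASAN runtime libraries may also need to be LD_PRELOADed, as the CI step above does.

import os
import subprocess

# Environment mirrors the README guidance; values are set in a copy of the
# environment and the workload runs in a fresh process so that both Triton
# compilation and the HIP runtime see them from the start.
env = dict(os.environ)
env["TRITON_ENABLE_ASAN"] = "1"                  # instrument generated code with ASAN
env["HSA_XNACK"] = "1"                           # device ASAN on AMD GPUs requires xnack+
env["HSA_DISABLE_FRAGMENT_ALLOCATOR"] = "1"      # disable ROCm sub-allocation caching
env["AMD_PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"  # disable PyTorch caching allocator
env["PYTORCH_NO_HIP_MEMORY_CACHING"] = "1"

result = subprocess.run(["python", "my_triton_program.py"], env=env,
                        capture_output=True, text=True)
print(result.stdout)
print(result.stderr)  # ASAN reports (e.g. heap-buffer-overflow) are printed to stderr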

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_HIP_STREAM_PREFETCH",
     "TRITON_HIP_USE_BLOCK_PINGPONG",
     "TRITON_LLVM_DEBUG_ONLY",
+    "TRITON_ENABLE_ASAN",
     "USE_IR_LOC",
     "NVPTX_ENABLE_DUMP",
     "TRITON_INTEL_ADVANCED_PATH",

python/src/llvm.cc

Lines changed: 16 additions & 1 deletion
@@ -23,6 +23,8 @@
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Transforms/IPO/AlwaysInliner.h"
 #include "llvm/Transforms/InstCombine/InstCombine.h"
+#include "llvm/Transforms/Instrumentation/AddressSanitizer.h"
+#include "llvm/Transforms/Instrumentation/AddressSanitizerOptions.h"
 #include <csignal>
 #include <memory>
 #include <pybind11/pybind11.h>
@@ -217,7 +219,14 @@ void init_triton_llvm(py::module &&m) {
       .def("set_calling_conv", &llvm::Function::setCallingConv)
       .def("add_fn_attr", [](llvm::Function *fn, std::string &name,
                              std::string &val) { fn->addFnAttr(name, val); })
-
+      .def("add_fn_asan_attr",
+           [](llvm::Function *fn) {
+             fn->addFnAttr(llvm::Attribute::SanitizeAddress);
+           })
+      .def("add_fn_target_feature",
+           [](llvm::Function *fn, std::string &val) {
+             fn->addFnAttr("target-features", val);
+           })
       // Sets the nvvm.maxreg property on the given function.
       .def("set_nvvm_maxnreg",
            [](llvm::Function *fn, int maxnreg) {
@@ -377,6 +386,12 @@ void init_triton_llvm(py::module &&m) {
           fpm.addPass(BreakStructPhiNodesPass());
           fpm.addPass(InstCombinePass());
         });
+        bool enableAddressSanitizer =
+            mlir::triton::tools::getBoolEnv("TRITON_ENABLE_ASAN");
+        if (enableAddressSanitizer) {
+          AddressSanitizerOptions Opts;
+          mpm.addPass(AddressSanitizerPass(Opts));
+        }
         mpm.addPass(pb.buildPerModuleDefaultPipeline(opt));
         mpm.run(*mod, mam);
       },
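
The two new bindings are what let a backend opt individual functions into instrumentation before the AddressSanitizerPass added above runs. A rough sketch of how they might be called from the Python side follows; how the llvm.Function handle is obtained and the "+xnack" feature string are assumptions for illustration, not taken from this commit.

# Hypothetical usage sketch: `fn` is an llvm.Function handle exposed by the
# pybind bindings in python/src/llvm.cc (how it is obtained depends on the
# backend's compilation pipeline).
def enable_asan_on(fn):
    # Mark the function with the SanitizeAddress attribute so that the
    # AddressSanitizerPass registered under TRITON_ENABLE_ASAN instruments it.
    fn.add_fn_asan_attr()
    # Device-side ASAN on AMD GPUs needs xnack-capable code objects
    # ("+xnack" here is an assumption based on LLVM's AMDGPU target features).
    fn.add_fn_target_feature("+xnack")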

address_sanitizer_helper.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import torch
+import triton
+import triton.language as tl
+
+size = 4096
+x = torch.rand(size, device='cuda')
+y = torch.rand(size, device='cuda')
+output = torch.empty_like(x)
+n_elements = output.numel()
+grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
+
+
+@triton.jit
+def add_kernel(
+    x_ptr,
+    y_ptr,
+    output_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    #Set access to go out of bounds for ASAN test
+    offsets = block_start + tl.arange(0, BLOCK_SIZE) + 1
+    x = tl.load(x_ptr + offsets)
+    y = tl.load(y_ptr + offsets)
+    output = x + y
+    tl.store(output_ptr + offsets, output)
+
+
+pgm = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
+amdgcn = pgm.asm['amdgcn']
+print(amdgcn)

python/test/unit/language/test_core.py

Lines changed: 59 additions & 6 deletions
@@ -266,10 +266,6 @@ def filter_layouts(layouts):
     return [l for l in layouts if is_layout_applicable(l)]


-def filter_layout_pairs(pairs):
-    return [p for p in pairs if is_layout_applicable(p[0]) and is_layout_applicable(p[1])]
-
-
 @pytest.mark.interpreter
 @pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
 def test_empty_kernel(dtype_x, device):
@@ -5733,6 +5729,10 @@ def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path:
     assert "stmatrix" in kernel.asm["ptx"]


+def filter_layout_pairs(layout_pairs):
+    return [pair for pair in layout_pairs if is_layout_applicable(pair[0]) and is_layout_applicable(pair[1])]
+
+
 mma_pairs = [
     [
         MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
@@ -5774,6 +5774,54 @@ def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path:
         MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 16]),
         MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]),
     ],
+    [
+        WmmaLayout(1, [4, 4]),
+        WmmaLayout(1, [16, 1]),
+    ],
+    [
+        WmmaLayout(1, [16, 1]),
+        WmmaLayout(1, [4, 4]),
+    ],
+    [
+        WmmaLayout(2, [4, 4]),
+        WmmaLayout(2, [16, 1]),
+    ],
+    [
+        WmmaLayout(2, [16, 1]),
+        WmmaLayout(2, [4, 4]),
+    ],
+    [
+        MfmaLayout([2, 0], [2, 2], [32, 32], False),
+        MfmaLayout([2, 0], [4, 1], [32, 32], False),
+    ],
+    [
+        MfmaLayout([2, 0], [4, 1], [32, 32], False),
+        MfmaLayout([2, 0], [2, 2], [32, 32], False),
+    ],
+    [
+        MfmaLayout([2, 0], [2, 2], [32, 32], False),
+        MfmaLayout([2, 0], [4, 1], [32, 32], True),
+    ],
+    [
+        MfmaLayout([2, 0], [4, 1], [32, 32], False),
+        MfmaLayout([2, 0], [2, 2], [32, 32], True),
+    ],
+    [
+        MfmaLayout([2, 0], [4, 4], [16, 16], False),
+        MfmaLayout([2, 0], [16, 1], [16, 16], False),
+    ],
+    [
+        MfmaLayout([2, 0], [16, 1], [16, 16], False),
+        MfmaLayout([2, 0], [4, 4], [16, 16], False),
+    ],
+    [
+        MfmaLayout([2, 0], [4, 4], [16, 16], False),
+        MfmaLayout([2, 0], [16, 1], [16, 16], True),
+    ],
+    [
+        MfmaLayout([2, 0], [16, 1], [16, 16], False),
+        MfmaLayout([2, 0], [4, 4], [16, 16], True),
+    ],
     [
         DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
                    warps_per_cta=[4, 1], rep_cluster=[1, 1]),
@@ -5783,12 +5831,17 @@ def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path:
 ]


-@pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
+@pytest.mark.parametrize("M, N", [[16, 16], [64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
 @pytest.mark.parametrize("dtype", ['float16'])
 @pytest.mark.parametrize("mma_pair", filter_layout_pairs(mma_pairs))
 def test_convert_mma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
+    if is_hip():
+        if isinstance(mma_pair[1], MfmaLayout) and (mma_pair[1].instr_shape[1] > M or mma_pair[1].instr_shape[1] > N):
+            pytest.skip("HIP do not fully support skinny tensor store")
+
     src_layout, _ = mma_pair
     num_warps = np.prod(src_layout.warps_per_cta)
+    warp_size = THREADS_PER_WARP

     def do_test(src_layout, dst_layout):
         layouts = f"""
@@ -5797,7 +5850,7 @@ def do_test(src_layout, dst_layout):
         """

         ir = layouts + f"""
-        module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32}} {{
+        module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {warp_size} : i32}} {{
         tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
         %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
         %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #src}}>>

python/test/unit/language/test_tuple.py

Lines changed: 48 additions & 0 deletions
@@ -1,6 +1,7 @@
 import pytest
 import triton
 import triton.language as tl
+from typing import NamedTuple
 import torch


@@ -99,3 +100,50 @@ def test_serialize(device="xpu"):
     _tuple_serialize[(1, )](z, None, (x0, (1, None, x1, tl.constexpr(4))), 20, 1, (y0, ))
     ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device)
     assert torch.equal(z, ref)
+
+
+class Function(NamedTuple):
+    fn: tl.constexpr
+    captured: tuple
+
+
+class Tensor(NamedTuple):
+    ptr: any
+    shape: tuple
+    stride: tuple
+
+
+@triton.jit
+def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+    offs_m = tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    mask = (offs_m[:, None] < Tensor.shape[0]) & (offs_n[None, :] < Tensor.shape[1])
+    return mask
+
+
+@triton.jit
+def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+    offs_m = tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    X = Tensor(shape=_X.shape, ptr=_X.ptr, stride=_X.stride)
+    Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1]
+    Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1]
+    x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0)
+    y = closure.fn(x, *closure.captured)
+    tl.store(Ys, y, mask=_namedtuple_mask_func(Y, BLOCK_M, BLOCK_N))
+
+
+def test_namedtuple(device="cuda"):
+    x = torch.randn((32, 32), dtype=torch.float32, device=device)
+    y = torch.empty((16, 16), dtype=torch.float32, device=device)
+    a = torch.tensor([5.2], dtype=torch.float32, device=device)
+
+    @triton.jit
+    def mul(x, a):
+        return x * tl.load(a)
+
+    function = Function(mul, (a, ))
+    tx = Tensor(x, x.shape, x.stride())
+    ty = Tensor(y, y.shape, y.stride())
+    _namedtuple_kernel[(1, )](function, tx, ty, 64, 64)
+    assert torch.allclose(y, x[:16, :16] * a)

test_address_sanitizer.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import os
+import subprocess
+
+import triton
+
+
+def is_hip():
+    return triton.runtime.driver.active.get_current_target().backend == "hip"
+
+
+def test_address_sanitizer():
+    if not is_hip():
+        return #not supported on NV backend
+
+    # It is recommended to disable various memory caching strategies both within the ROCm stack and PyTorch
+    # This will give the address sanitizer the best chance at finding the memory fault where it originates,
+    # otherwise it could be masked by writing past the end of a cached block within a larger allocation.
+    os.environ["HSA_DISABLE_FRAGMENT_ALLOCATOR"] = "1"
+    os.environ["AMD_PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
+    os.environ["PYTORCH_NO_HIP_MEMORY_CACHING"] = "1"
+    os.environ["TRITON_ENABLE_ASAN"] = "1"
+
+    # HSA_XNACK here is required to set the xnack+ setting for the GPU at runtime.
+    # If it is not set and the default xnack setting of the system is xnack-
+    # a runtime error something like "No kernel image found" will occur. The system
+    # xnack setting can be found through rocminfo. xnack+ is required for ASAN.
+    # More information about xnack in general can be found here:
+    # https://llvm.org/docs/AMDGPUUsage.html#target-features
+    # https://rocm.docs.amd.com/en/docs-6.1.0/conceptual/gpu-memory.html
+    os.environ["HSA_XNACK"] = "1"
+
+    out = subprocess.Popen(["python", "address_sanitizer_helper.py"], stderr=subprocess.PIPE, stdout=subprocess.PIPE)
+    assert "Begin function __asan_report" in out.stdout.read().decode()
+    assert "heap-buffer-overflow" in out.stderr.read().decode()
