
Commit c54f988

[frontend] added overflow checks in debug mode (#4589)
1 parent 2c498ee commit c54f988
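
Note: the compiler-side change is not part of the excerpt below, which only covers CI and test updates. As an illustration of what a debug-mode overflow check amounts to at the kernel level, here is a minimal sketch using tl.device_assert and the per-launch debug flag exercised by the updated tests; the kernel, the tensors, and the exact overflow condition are assumptions for the example, not code added by this commit.

# Illustration only (assumed usage, not this commit's implementation):
# with debug enabled, a wrapping int32 addition can be caught on-device.
import torch
import triton
import triton.language as tl


@triton.jit
def add_checked(x_ptr, y_ptr, out_ptr, N: tl.constexpr):
    idx = tl.arange(0, N)
    x = tl.load(x_ptr + idx)
    y = tl.load(y_ptr + idx)
    z = x + y
    # Signed-overflow condition for z = x + y: adding a positive x must
    # increase the result, adding a negative x must decrease it.
    tl.device_assert(((x <= 0) | (z > y)) & ((x >= 0) | (z < y)), "int32 overflow")
    tl.store(out_ptr + idx, z)


x = torch.full((16,), 2**30, dtype=torch.int32, device="cuda")
y = torch.full((16,), 2**30, dtype=torch.int32, device="cuda")
out = torch.empty_like(x)
# debug=True (the per-launch flag used in test_cache.py below) compiles the
# assert in; with debug=False it is stripped out.
add_checked[(1,)](x, y, out, N=16, debug=True)
torch.cuda.synchronize()  # the failed assert surfaces here as a CUDA error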

File tree: 16 files changed, +248 / −270 lines


.github/workflows/integration-tests.yml

Lines changed: 7 additions & 3 deletions
@@ -226,7 +226,7 @@ jobs:
       - name: Install pip dependencies
         run: |
           python3 -m pip install --upgrade pip
-          python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit
+          python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit
       - name: Install Triton
         env:
           TRITON_BUILD_WITH_CCACHE: "true"
@@ -250,8 +250,9 @@
             echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
-          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
+          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
           python3 -m pytest -s -n 8 language/test_subprocess.py
+          python3 -m pytest -s -n 8 test_debug.py --forked
           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
           # Run hopper/test_flashattention.py separately to avoid out of gpu memory
@@ -407,7 +408,10 @@ jobs:
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
           pytest --capture=tee-sys -rfs -n 16 language runtime \
-            --ignore=language/test_line_info.py
+            --ignore=language/test_line_info.py \
+            --ignore=test_debug.py
+          # TODO: uncomment
+          # pytest --capture=tee-sys -rfs test_debug.py
           TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
           pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py


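The new CI step isolates test_debug.py with pytest-forked so that a kernel aborted by a failed device-side assert cannot take the rest of the pytest session down with it. A sketch of reproducing that step locally (assuming a built Triton, pytest-forked installed, and the command run from the repository root), in the same subprocess style the deleted test harness used:

import os
import subprocess
import sys

# Sketch: mirror the new CI invocation. --forked (from pytest-forked) runs
# each test in its own process, so an aborted kernel only kills that worker.
cmd = [sys.executable, "-m", "pytest", "-s", "-n", "8", "test_debug.py", "--forked"]
proc = subprocess.run(cmd, cwd="python/test/unit", env=os.environ.copy())
sys.exit(proc.returncode)
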
.github/workflows/integration-tests.yml.in

Lines changed: 7 additions & 3 deletions
@@ -256,7 +256,7 @@ jobs:
       - name: Install pip dependencies
         run: |
           python3 -m pip install --upgrade pip
-          python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit
+          python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit

       - name: Install Triton
         env:
@@ -284,8 +284,9 @@
             echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
-          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
+          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
           python3 -m pytest -s -n 8 language/test_subprocess.py
+          python3 -m pytest -s -n 8 test_debug.py --forked
           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
           # Run hopper/test_flashattention.py separately to avoid out of gpu memory
@@ -403,7 +404,10 @@ jobs:
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
           pytest --capture=tee-sys -rfs -n 16 language runtime \
-            --ignore=language/test_line_info.py
+            --ignore=language/test_line_info.py \
+            --ignore=test_debug.py
+          # TODO: uncomment
+          # pytest --capture=tee-sys -rfs test_debug.py
           TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
           pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py


python/test/unit/language/assert_helper.py

Lines changed: 0 additions & 154 deletions
This file was deleted.

python/test/unit/language/test_subprocess.py

Lines changed: 0 additions & 61 deletions
@@ -8,11 +8,6 @@

 dir_path = os.path.dirname(os.path.realpath(__file__))
 print_path = os.path.join(dir_path, "print_helper.py")
-assert_path = os.path.join(dir_path, "assert_helper.py")
-
-# TODO: bfloat16 after LLVM-15
-assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"]
-nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]]
 torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]


@@ -120,59 +115,3 @@ def test_print(func_type: str, data_type: str, device: str):
             continue
         print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)')
     assert all(delta == 0 for delta in diff.values())
-
-
-@pytest.mark.parametrize("func_type", assert_types)
-def test_assert(func_type: str, device: str):
-    # The total number of elements in the 1-D tensor to assert on.
-    N = 128
-
-    proc = subprocess.run(
-        [sys.executable, assert_path, "test_assert", func_type, device],
-        capture_output=True,
-        env={**os.environ, "TRITON_DEBUG": "1"},
-    )
-    errs = proc.stderr.splitlines()
-    num_errs = 0
-    for err in errs:
-        if "x != 0" in err.decode("utf-8", errors="ignore"):
-            num_errs += 1
-
-    # Check for segfaults.
-    assert all("segmentation fault" not in line.decode("utf-8", errors="ignore").lower() for line in errs)
-
-    if func_type == "static_assert" or func_type == "device_assert_passes":
-        assert num_errs == 0
-    else:
-        assert num_errs == N - 1
-
-
-@pytest.mark.parametrize("caller_type, callee_type", nested_types)
-def test_assert_nested(caller_type, callee_type, device):
-    # The total number of elements in the 1-D tensor to assert on.
-    N = 128
-
-    proc = subprocess.run(
-        [sys.executable, assert_path, "test_assert_nested", caller_type, callee_type, device],
-        capture_output=True,
-    )
-    errs = proc.stderr.splitlines()
-    num_errs = 0
-    for err in errs:
-        if "x != 0" in err.decode("utf-8", errors="ignore"):
-            num_errs += 1
-    if caller_type == "none":
-        if callee_type == "true":
-            assert num_errs == N - 1
-        else:
-            assert num_errs == 0
-    elif caller_type == "true":
-        if callee_type == "false":
-            assert num_errs == 0
-        else:
-            assert num_errs == N - 1
-    elif caller_type == "false":
-        if callee_type == "true":
-            assert num_errs == N - 1
-        else:
-            assert num_errs == 0

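The subprocess-based assert harness above (assert_helper.py plus test_assert and test_assert_nested) is retired in favour of the forked test_debug.py run wired into CI. The sketch below shows the general pattern that replaces it; it assumes pytest-forked's forked marker and PyTorch surfacing a failed device-side assert as a RuntimeError on synchronization, and the test name and kernel are illustrative rather than the actual contents of test_debug.py.

import pytest
import torch
import triton
import triton.language as tl


@triton.jit
def _assert_kernel(x_ptr, N: tl.constexpr):
    idx = tl.arange(0, N)
    tl.device_assert(tl.load(x_ptr + idx) == 0, "x != 0")


@pytest.mark.forked  # isolate: a tripped device assert poisons the CUDA context
def test_device_assert_fires():
    x = torch.arange(128, dtype=torch.int32, device="cuda")
    with pytest.raises(RuntimeError):
        _assert_kernel[(1,)](x, N=128, debug=True)
        torch.cuda.synchronize()
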
python/test/unit/runtime/test_cache.py

Lines changed: 10 additions & 15 deletions
@@ -427,23 +427,18 @@ def kernel_add(a, b, o, N: tl.constexpr):
 def test_jit_debug(device) -> None:

     @triton.jit
-    def kernel_add(a, b, o, N: tl.constexpr):
-        idx = tl.arange(0, N)
-        tl.device_assert(idx < 32, "idx < 32")
-        tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))
+    def kernel(tmp):
+        tl.device_assert(tl.load(tmp) == 1, "tmp == 1")

     device = getattr(torch, device).current_device()
-    assert len(kernel_add.cache[device]) == 0
-    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
-    kernel_add.debug = False
-    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 2
-    kernel_add.debug = True
-    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 3
-    bins = list(kernel_add.cache[device].values())
-    assert bins[2].asm['ttir'] != bins[1].asm['ttir']
+    tmp = torch.tensor([1], dtype=torch.int32, device="cuda")
+    assert len(kernel.cache[device]) == 0
+    kernel[(1, )](tmp, debug=False)
+    assert len(kernel.cache[device]) == 1
+    kernel[(1, )](tmp, debug=True)
+    assert len(kernel.cache[device]) == 2
+    bins = list(kernel.cache[device].values())
+    assert bins[0].asm['ttir'] != bins[1].asm['ttir']


 @triton.jit

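The rewritten test_jit_debug pins down that debug is a per-launch option and part of the kernel cache key: launching the same kernel with debug=False and then debug=True yields two cache entries whose TTIR differs. A compact sketch of the same behaviour outside the test harness (the kernel and values are illustrative):

import torch
import triton
import triton.language as tl


@triton.jit
def check(tmp_ptr):
    # Compiled into a real check only when debug is on; a no-op otherwise.
    tl.device_assert(tl.load(tmp_ptr) == 1, "tmp == 1")


tmp = torch.tensor([0], dtype=torch.int32, device="cuda")
check[(1,)](tmp, debug=False)  # assert stripped, launch succeeds
check[(1,)](tmp, debug=True)   # second cache entry; this time the assert fires
torch.cuda.synchronize()       # the failed assert surfaces here as a CUDA error
# TRITON_DEBUG=1 in the environment (used by the deleted subprocess tests) is
# another way to turn the checks on.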