
Commit 428a24f

Merge commit 'c54f9882aff504bf2ab62d0ba037fb042204dc90'
2 parents: 14a7cf4 + c54f988

17 files changed: +250 / -271 lines

.github/workflows/integration-tests.yml

Lines changed: 7 additions & 3 deletions
@@ -226,7 +226,7 @@ jobs:
       - name: Install pip dependencies
         run: |
           python3 -m pip install --upgrade pip
-          python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit
+          python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit
       - name: Install Triton
         env:
           TRITON_BUILD_WITH_CCACHE: "true"
@@ -250,8 +250,9 @@ jobs:
             echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
-          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
+          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
           python3 -m pytest -s -n 8 language/test_subprocess.py
+          python3 -m pytest -s -n 8 test_debug.py --forked
           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
           # Run hopper/test_flashattention.py separately to avoid out of gpu memory
@@ -407,7 +408,10 @@ jobs:
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
           pytest --capture=tee-sys -rfs -n 16 language runtime \
-            --ignore=language/test_line_info.py
+            --ignore=language/test_line_info.py \
+            --ignore=test_debug.py
+          # TODO: uncomment
+          # pytest --capture=tee-sys -rfs test_debug.py
           TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
             pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

.github/workflows/integration-tests.yml.in

Lines changed: 7 additions & 3 deletions
@@ -256,7 +256,7 @@ jobs:
       - name: Install pip dependencies
         run: |
           python3 -m pip install --upgrade pip
-          python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit
+          python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit

       - name: Install Triton
         env:
@@ -284,8 +284,9 @@ jobs:
             echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
-          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
+          python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
           python3 -m pytest -s -n 8 language/test_subprocess.py
+          python3 -m pytest -s -n 8 test_debug.py --forked
           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
           # Run hopper/test_flashattention.py separately to avoid out of gpu memory
@@ -403,7 +404,10 @@ jobs:
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
           pytest --capture=tee-sys -rfs -n 16 language runtime \
-            --ignore=language/test_line_info.py
+            --ignore=language/test_line_info.py \
+            --ignore=test_debug.py
+          # TODO: uncomment
+          # pytest --capture=tee-sys -rfs test_debug.py
           TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
             pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py
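Context (not part of the diff): pytest-forked, added to the pip dependencies above, runs each test in its own forked child process. That isolation matters for test_debug.py because a failing tl.device_assert aborts the CUDA context, and forking keeps one aborted context from taking down the rest of the pytest session. A minimal hypothetical sketch of such a test follows; the kernel, tensor, and test name are illustrative, not taken from this commit.

# Hypothetical device-assert test that benefits from forked isolation.
import pytest
import torch
import triton
import triton.language as tl


@triton.jit
def _check_zero(x_ptr, N: tl.constexpr):
    idx = tl.arange(0, N)
    # Fires for every non-zero element when debug is enabled.
    tl.device_assert(tl.load(x_ptr + idx) == 0, "x != 0")


@pytest.mark.forked  # same isolation the workflow gets from `--forked`
def test_device_assert_passes():
    x = torch.zeros(128, dtype=torch.int32, device="cuda")
    _check_zero[(1, )](x, N=128, debug=True)
    torch.cuda.synchronize()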

python/test/unit/language/assert_helper.py

Lines changed: 0 additions & 154 deletions
This file was deleted.

python/test/unit/language/test_subprocess.py

Lines changed: 0 additions & 61 deletions
@@ -10,11 +10,6 @@

 dir_path = os.path.dirname(os.path.realpath(__file__))
 print_path = os.path.join(dir_path, "print_helper.py")
-assert_path = os.path.join(dir_path, "assert_helper.py")
-
-# TODO: bfloat16 after LLVM-15
-assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"]
-nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]]
 torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]


@@ -124,59 +119,3 @@ def test_print(func_type: str, data_type: str, device: str):
             continue
         print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)')
     assert all(delta == 0 for delta in diff.values())
-
-
-@pytest.mark.parametrize("func_type", assert_types)
-def test_assert(func_type: str, device: str):
-    # The total number of elements in the 1-D tensor to assert on.
-    N = 128
-
-    proc = subprocess.run(
-        [sys.executable, assert_path, "test_assert", func_type, device],
-        capture_output=True,
-        env={**os.environ, "TRITON_DEBUG": "1"},
-    )
-    errs = proc.stderr.splitlines()
-    num_errs = 0
-    for err in errs:
-        if "x != 0" in err.decode("utf-8", errors="ignore"):
-            num_errs += 1
-
-    # Check for segfaults.
-    assert all("segmentation fault" not in line.decode("utf-8", errors="ignore").lower() for line in errs)
-
-    if func_type == "static_assert" or func_type == "device_assert_passes":
-        assert num_errs == 0
-    else:
-        assert num_errs == N - 1
-
-
-@pytest.mark.parametrize("caller_type, callee_type", nested_types)
-def test_assert_nested(caller_type, callee_type, device):
-    # The total number of elements in the 1-D tensor to assert on.
-    N = 128
-
-    proc = subprocess.run(
-        [sys.executable, assert_path, "test_assert_nested", caller_type, callee_type, device],
-        capture_output=True,
-    )
-    errs = proc.stderr.splitlines()
-    num_errs = 0
-    for err in errs:
-        if "x != 0" in err.decode("utf-8", errors="ignore"):
-            num_errs += 1
-    if caller_type == "none":
-        if callee_type == "true":
-            assert num_errs == N - 1
-        else:
-            assert num_errs == 0
-    elif caller_type == "true":
-        if callee_type == "false":
-            assert num_errs == 0
-        else:
-            assert num_errs == N - 1
-    elif caller_type == "false":
-        if callee_type == "true":
-            assert num_errs == N - 1
-        else:
-            assert num_errs == 0

python/test/unit/runtime/test_cache.py

Lines changed: 10 additions & 15 deletions
@@ -427,23 +427,18 @@ def kernel_add(a, b, o, N: tl.constexpr):
 def test_jit_debug(device) -> None:

     @triton.jit
-    def kernel_add(a, b, o, N: tl.constexpr):
-        idx = tl.arange(0, N)
-        tl.device_assert(idx < 32, "idx < 32")
-        tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))
+    def kernel(tmp):
+        tl.device_assert(tl.load(tmp) == 1, "tmp == 1")

     device = getattr(torch, device).current_device()
-    assert len(kernel_add.cache[device]) == 0
-    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
-    kernel_add.debug = False
-    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 2
-    kernel_add.debug = True
-    kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 3
-    bins = list(kernel_add.cache[device].values())
-    assert bins[2].asm['ttir'] != bins[1].asm['ttir']
+    tmp = torch.tensor([1], dtype=torch.int32, device=device)
+    assert len(kernel.cache[device]) == 0
+    kernel[(1, )](tmp, debug=False)
+    assert len(kernel.cache[device]) == 1
+    kernel[(1, )](tmp, debug=True)
+    assert len(kernel.cache[device]) == 2
+    bins = list(kernel.cache[device].values())
+    assert bins[0].asm['ttir'] != bins[1].asm['ttir']


 @triton.jit
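Usage note (not part of the diff): the rewritten test_jit_debug toggles debug per launch instead of mutating kernel.debug between warmup calls, and checks that each setting gets its own cache entry with different TTIR. A minimal sketch of that pattern, with an illustrative kernel name:

# Per-launch debug toggle, as exercised by the new test_jit_debug.
import torch
import triton
import triton.language as tl


@triton.jit
def check_one(ptr):
    tl.device_assert(tl.load(ptr) == 1, "expected 1")


x = torch.tensor([1], dtype=torch.int32, device="cuda")
check_one[(1, )](x, debug=False)  # device assert compiled out
check_one[(1, )](x, debug=True)   # device assert kept in the generated IR
# Each debug setting now has its own entry in check_one.cache, with different 'ttir'.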
