Skip to content

Commit 40cde44

Browse files
authored
Add regression test for df19f1d using problematic kernel from one of crashed models in #4410 (#4529)
When testing e2e models on the PyTorch side, [an issue was found](#4410) that caused a crash. This issue was not detected by our testing, so we decided to isolate the problematic kernel from one of the affected models and add it to our test suite, so that such failures can be detected sooner and regressions prevented. The crash: ```bash # L0 build module failed. Log: IGC: Internal Compiler Error: Segmentation violation # Error during Intel loadBinary: Triton Error [ZE]: 0x70000004 # RuntimeError: Triton Error [ZE]: 0x70000004 ``` The commit that caused this issue: df19f1d --------- Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 4bc28e2 commit 40cde44

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pathlib
2+
3+
import triton
4+
5+
6+
def test_regression_4441(device, tmp_path: pathlib.Path):
    """Regression test for issue #4410 (crash introduced by commit df19f1d).

    Compiles a TTGIR kernel isolated from one of the crashing PyTorch e2e
    models and then loads the resulting binary.  Before the fix, the load
    step failed inside IGC with:

        L0 build module failed. Log: IGC: Internal Compiler Error: Segmentation violation
        Error during Intel loadBinary: Triton Error [ZE]: 0x70000004
        RuntimeError: Triton Error [ZE]: 0x70000004

    Args:
        device: pytest fixture selecting the target device (its side effect of
            activating the device is what matters here; the actual device
            handle passed to ``load_binary`` is re-queried from the driver).
        tmp_path: pytest-provided temporary directory used to materialize the
            TTGIR source file for ``triton.compile``.
    """
    from triton.runtime.driver import driver

    ir = """
    #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
    module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} {
      tt.func public @triton_red_fused__softmax_backward_data_div_masked_fill_native_dropout_backward_threshold_backward_10(%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg4: f32) {
        %cst_1 = arith.constant dense<0> : tensor<64x4xi8, #blocked>
        %c4_i32 = arith.constant 4 : i32
        %c204_i32 = arith.constant 204 : i32
        %c0_i32 = arith.constant 0 : i32
        %cst_2 = arith.constant dense<1.11111116> : tensor<64x4xf32, #blocked>
        %cst_5 = arith.constant dense<204> : tensor<64x1xi32, #blocked>
        %0 = tt.get_program_id x : i32
        %1 = arith.muli %0, %c4_i32 : i32
        %4 = tt.splat %1 : i32 -> tensor<64x1xi32, #blocked>
        %6 = arith.cmpi slt, %4, %cst_5 : tensor<64x1xi32, #blocked>
        %13 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x4x!tt.ptr<f32>, #blocked>
        %14 = tt.broadcast %6 : tensor<64x1xi1, #blocked> -> tensor<64x4xi1, #blocked>
        %16 = tt.broadcast %4 : tensor<64x1xi32, #blocked> -> tensor<64x4xi32, #blocked>

        %26 = tt.splat %arg3 : !tt.ptr<i8> -> tensor<64x4x!tt.ptr<i8>, #blocked>
        %29 = tt.splat %arg4 : f32 -> tensor<64x4xf32, #blocked>
        scf.for %arg7 = %c0_i32 to %c204_i32 step %c4_i32 : i32 {
          %40 = tt.load %26, %14, %cst_1 : tensor<64x4x!tt.ptr<i8>, #blocked>
          %41 = arith.cmpi ne, %40, %cst_1 : tensor<64x4xi8, #blocked>
          %43 = tt.addptr %13, %16 : tensor<64x4x!tt.ptr<f32>, #blocked>, tensor<64x4xi32, #blocked>
          %44 = tt.load %43, %14, %cst_2 : tensor<64x4x!tt.ptr<f32>, #blocked>
          %57 = tt.extern_elementwise %44, %cst_2, %29 {libname = "", libpath = "", pure = true, symbol = "__imf_fmaf"} : (tensor<64x4xf32, #blocked>, tensor<64x4xf32, #blocked>, tensor<64x4xf32, #blocked>) -> tensor<64x4xf32, #blocked>
          %58 = arith.select %41, %cst_2, %57 : tensor<64x4xi1, #blocked>, tensor<64x4xf32, #blocked>
          %59 = arith.divf %58, %29 : tensor<64x4xf32, #blocked>
          tt.store %43, %59, %14 : tensor<64x4x!tt.ptr<f32>, #blocked>
        }
        tt.return
      }
    }
    """

    # triton.compile consumes a path to a .ttgir file, so the IR must be
    # written to disk first.
    temp_file = tmp_path / "test_regression_4441.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

    # Query the driver's current device instead of shadowing the `device`
    # fixture parameter (the fixture value is a test parameter, not a driver
    # device handle).
    dev = driver.active.get_current_device()

    # try to catch:
    # L0 build module failed. Log: IGC: Internal Compiler Error: Segmentation violation
    # Error during Intel loadBinary: Triton Error [ZE]: 0x70000004
    # RuntimeError: Triton Error [ZE]: 0x70000004
    module, function, n_regs, n_spills, n_max_threads = driver.active.utils.load_binary(
        kernel.name, kernel.kernel, kernel.metadata.shared, kernel.metadata.build_flags,
        not kernel.metadata.generate_native_code, dev)

0 commit comments

Comments
 (0)