def test_kernel_from_09_tutorial(device, tmp_path: pathlib.Path):
    """Compile a reduced TTGIR matmul kernel and load the resulting binary.

    The IR below is written to a ``.ttgir`` file under ``tmp_path``, compiled
    with ``triton.compile`` and then handed to the active driver's
    ``load_binary``. The test has no explicit assertions: it passes when
    neither compilation nor binary loading raises.

    Args:
        device: Device fixture supplied by the test suite. It is not read
            here; the device passed to ``load_binary`` is queried from the
            active driver instead.
        tmp_path: Per-test temporary directory used to hold the IR file.
    """
    # although the kernel is taken from the arl-h machine, the problem with it is also reproduced on pvc
    # NOTE: every pointer type carries its pointee type (`!tt.ptr<f32>`); the
    # loads below produce f32 tensors, so the pointee must be f32.
    ir = """
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32, ttig.min_sg_size = 8 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.target_arch = "spir64"} {
  tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1>
    %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %42 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x64x!tt.ptr<f32>, #blocked1>
    %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %45 = tt.expand_dims %44 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1>
    %46 = tt.expand_dims %18 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1>
    %50 = tt.broadcast %46 : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1>
    %52 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x128x!tt.ptr<f32>, #blocked1>
    %53 = tt.addptr %52, %50 : tensor<64x128x!tt.ptr<f32>, #blocked1>, tensor<64x128xi32, #blocked1>

    %85 = tt.load %42: tensor<128x64x!tt.ptr<f32>, #blocked1>
    %86 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked1>
    %87 = arith.cmpi slt, %45, %86 : tensor<64x1xi32, #blocked1>
    %88 = tt.broadcast %87 : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1>
    %89 = tt.load %53, %88, %cst_0 : tensor<64x128x!tt.ptr<f32>, #blocked1>
    %91 = ttg.local_alloc %85 : (tensor<128x64xf32, #blocked1>) -> !ttg.memdesc<128x64xf32, #shared, #smem>
    %92 = ttg.local_load %91 : !ttg.memdesc<128x64xf32, #shared, #smem> -> tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
    %94 = ttg.local_alloc %89 : (tensor<64x128xf32, #blocked1>) -> !ttg.memdesc<64x128xf32, #shared1, #smem>
    %cst_test2 = arith.constant dense<1.11111116> : tensor<64x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
    %96 = tt.dot %92, %cst_test2, %cst, inputPrecision = tf32 : tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked>

    %78 = ttg.convert_layout %96 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #blocked1>
    tt.return
  }
}
    """

    temp_file = tmp_path / "test_kernel_from_09_tutorial.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

    from triton.runtime.driver import driver

    # Use a distinct local name so the `device` fixture argument is not
    # silently rebound; the value passed to load_binary is unchanged.
    current_device = driver.active.get_current_device()

    # try to catch:
    # L0 build module failed. Log: IGC: Internal Compiler Error: Segmentation violation
    # Error during Intel loadBinary: Triton Error [ZE]: 0x78000011
    # RuntimeError: Triton Error [ZE]: 0x78000011
    #
    # The results are not inspected; the five-way unpacking is kept so the
    # call still fails if load_binary stops returning five values.
    module, function, n_regs, n_spills, n_max_threads = driver.active.utils.load_binary(
        kernel.name, kernel.kernel, kernel.metadata.shared, kernel.metadata.build_flags,
        not kernel.metadata.generate_native_code, current_device)