
Commit 9fe51c3

Fix tritonintelgpu-remove-layout-conversions pass on block ptr example. (#3817)
Resolves issue #3816.

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent b400b8d commit 9fe51c3
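
Background: the tritonintelgpu-remove-layout-conversions pass crashed on the block-pointer reproducer because LayoutRematerialization::rewriteSlice unconditionally cast every rewritten scf.if result type to RankedTensorType, while block pointers created by tt.make_tensor_ptr make the scf.if yield values of type !tt.ptr<tensor<...>>. The patch below also rewrites the layout encoding inside the pointee tensor of such pointer types. As a rough sketch of the type-rewriting rule only (the helper name rewriteEncoding and the free-function form are illustrative, not part of the patch; assumes the usual MLIR and Triton dialect headers are available):

#include "mlir/IR/BuiltinTypes.h"             // RankedTensorType
#include "triton/Dialect/Triton/IR/Dialect.h" // triton::PointerType (assumed header path)

using namespace mlir;

// Given a result type and the new layout encoding chosen by the pass,
// return the same type with the encoding swapped in. Handles plain ranked
// tensors and block pointers to ranked tensors; any other type is returned
// unchanged in this sketch.
static Type rewriteEncoding(Type type, Attribute newEncoding) {
  if (auto tensorTy = dyn_cast<RankedTensorType>(type))
    return RankedTensorType::get(tensorTy.getShape(),
                                 tensorTy.getElementType(), newEncoding);
  if (auto ptrTy = dyn_cast<triton::PointerType>(type)) {
    auto pointee = cast<RankedTensorType>(ptrTy.getPointeeType());
    return triton::PointerType::get(
        RankedTensorType::get(pointee.getShape(), pointee.getElementType(),
                              newEncoding),
        ptrTy.getAddressSpace());
  }
  return type;
}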

File tree: 2 files changed (+82, -6 lines)


test/TritonIntelGPU/combine.mlir

Lines changed: 69 additions & 2 deletions
@@ -2276,7 +2276,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
   }
 }
 
-
 // -----
 #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}>
@@ -2300,7 +2299,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
   }
 }
 
-
 // -----
 
 // COM: Check that dpas layout can be propagated from dot op to atomic_rmw op
@@ -2406,3 +2404,72 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
     tt.return
   }
 }
+
+// -----
+
+// COM: Reproducer for issue #3817 (to ensure that the compiler doesn't crash).
+
+// CHECK: #[[$BLOCKED1:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
+#smem = #ttg.shared_memory
+module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
+  tt.func public @matmul_kernel_descriptor_persistent(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32} , %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32} , %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("/home/jovyan/intel-xpu-backend-for-triton/python/tutorials/09-persistent-matmul.py":568:0
+  ), %arg3: i32 {tt.divisibility = 16 : i32} , %arg4: i32 {tt.divisibility = 16 : i32} , %arg5: i32 {tt.divisibility = 16 : i32} ) {
+    // CHECK-LABEL: @matmul_kernel_descriptor_persistent
+    %0 = ub.poison : !tt.ptr<tensor<128x64xf16, #blocked1>>
+    %c448_i32 = arith.constant 448 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %9 = arith.extsi %arg5 : i32 to i64
+    %10 = arith.extsi %arg4 : i32 to i64
+    %13 = arith.extsi %arg3 : i32 to i64
+    // CHECK: scf.for
+    %19:11 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg7 = %c1_i32, %arg8 = %c448_i32, %arg9 = %c448_i32, %arg10 = %c0_i32, %arg11 = %cst_0, %arg12 = %0, %arg13 = %0, %arg14 = %c0_i32, %arg15 = %c0_i32, %arg16 = %0, %arg17 = %0) -> (i32, i32, i32, i32, tensor<128x128xf32, #mma>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, i32, i32, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>) : i32 {
+      %20 = arith.addi %arg7, %c1_i32 : i32
+      %21 = arith.subi %c64_i32, %c1_i32 : i32
+      %22 = arith.cmpi eq, %arg7, %21 : i32
+      %23 = arith.select %22, %c0_i32, %20 : i32
+      // CHECK: scf.if
+      %26:7 = scf.if %22 -> (i32, i32, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, i32) {
+        %41 = arith.addi %arg8, %c448_i32 : i32
+        %42 = arith.divsi %41, %c8_i32 : i32
+        %43 = arith.muli %42, %c8_i32 : i32
+        %44 = arith.subi %c128_i32, %43 : i32
+        %45 = arith.minsi %44, %c8_i32 : i32
+        %46 = arith.remsi %41, %45 : i32
+        %47 = arith.addi %43, %46 : i32
+        %48 = arith.remsi %41, %c8_i32 : i32
+        %49 = arith.divsi %48, %45 : i32
+        %50 = arith.muli %47, %c128_i32 : i32
+        %51 = arith.muli %49, %c128_i32 : i32
+        %52 = tt.make_tensor_ptr %arg0, [%13, %9], [%9, %c1_i64], [%50, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked1>>
+        triton_intel_gpu.prefetch %52 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #blocked1>>
+        %53 = tt.make_tensor_ptr %arg1, [%10, %9], [%9, %c1_i64], [%51, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked1>>
+        triton_intel_gpu.prefetch %53 {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #blocked1>>
+        scf.yield %50, %51, %52, %53, %52, %53, %41 : i32, i32, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, i32
+      } else {
+        scf.yield %arg14, %arg15, %arg16, %arg17, %arg12, %arg13, %arg8 : i32, i32, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, i32
+      }
+      %29 = tt.make_tensor_ptr %arg0, [%13, %9], [%9, %c1_i64], [%26#0, %c64_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked1>>
+      %30 = tt.make_tensor_ptr %arg1, [%10, %9], [%9, %c1_i64], [%26#1, %c64_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked1>>
+      %31 = tt.load %26#4 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #blocked1>>
+      %32 = tt.load %26#5 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #blocked1>>
+      %33 = ttg.local_alloc %32 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
+      %34 = ttg.memdesc_trans %33 {order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared1, #smem>
+      %35 = ttg.local_load %34 : !ttg.memdesc<64x128xf16, #shared1, #smem> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %36 = ttg.convert_layout %31 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+      %37 = tt.dot %36, %35, %arg11, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
+      scf.yield %23, %26#6, %c0_i32, %c1_i32, %37, %29, %30, %26#0, %26#1, %26#2, %26#3 : i32, i32, i32, i32, tensor<128x128xf32, #mma>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, i32, i32, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>
+    }
+    tt.return
+  }
+}

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 13 additions & 4 deletions
@@ -966,10 +966,19 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
           auto it = layout.find(res);
           assert(it != layout.end());
 
-          auto oldType = cast<RankedTensorType>(res.getType());
-          auto newType = RankedTensorType::get(
-              oldType.getShape(), oldType.getElementType(), it->second);
-          newTypes.push_back(newType);
+          Type resType = res.getType();
+          if (auto oldType = dyn_cast<RankedTensorType>(resType)) {
+            auto newType = RankedTensorType::get(
+                oldType.getShape(), oldType.getElementType(), it->second);
+            newTypes.push_back(newType);
+          } else if (auto ptrType = dyn_cast<PointerType>(resType)) {
+            auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+            auto newType = triton::PointerType::get(
+                RankedTensorType::get(tensorType.getShape(),
+                                      tensorType.getElementType(), it->second),
+                ptrType.getAddressSpace());
+            newTypes.push_back(newType);
+          }
         }
       }
       scf::IfOp newIfOp =
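
With this change, when rewriteSlice rematerializes an scf.if whose results include block pointers (as in the added combine.mlir test, where the !tt.ptr<tensor<128x64xf16, #blocked1>> values produced by tt.make_tensor_ptr are yielded from the scf.if), the new encoding is applied to the pointer's pointee tensor type instead of tripping the previous unconditional cast to RankedTensorType. Result types that match neither case fall through without adding an entry to newTypes.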
