
Commit 659470b

Propagate mma layout to atomic_rmw op (#2312)
This change helps resolve issue #1716. Propagating the mma layout from the dot op to the atomic_rmw op eliminates the `convert_layout` ops to and from the large mma layout, which would otherwise require oversized shared memory.
1 parent 59edf2c commit 659470b
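
To illustrate the effect, here is a condensed before/after sketch of the pattern exercised by the new combine.mlir test below (value names are illustrative; types, masks, and the inputPrecision attribute are abbreviated with "..."; only the layout annotations matter):

// Before: the accumulator leaves the DPAS (mma) layout before the atomic,
// and the convert_layout round-trips through shared local memory.
%acc     = tt.dot %a, %b, %c ... -> tensor<256x256xf32, #mma>
%acc_blk = triton_gpu.convert_layout %acc : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
%res     = tt.atomic_rmw fadd, acq_rel, gpu, %ptrs, %acc_blk, %mask : (...) -> tensor<256x256xf32, #blocked>

// After: the mma layout is propagated onto the atomic_rmw op, so the
// convert_layout (and its shared-memory buffer) is no longer needed.
%acc = tt.dot %a, %b, %c ... -> tensor<256x256xf32, #mma>
%res = tt.atomic_rmw fadd, acq_rel, gpu, %ptrs, %acc, %mask : (...) -> tensor<256x256xf32, #mma>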

File tree

2 files changed: +142 -1 lines

test/TritonIntelGPU/combine.mlir
third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

test/TritonIntelGPU/combine.mlir

Lines changed: 109 additions & 0 deletions
@@ -2297,3 +2297,112 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     tt.return %3 : tensor<128x256xf32, #blocked>
   }
 }
+
+
+// -----
+
+// COM: Check that dpas layout can be propagated from dot op to atomic_rmw op
+// CHECK-NOT: #triton_gpu.blocked<{.*}>
+// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [32], order = [0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [32, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 16], order = [1, 0]}>
+#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 32], order = [0, 1]}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: tt.func public @propagate_mma_to_atomic_rmw
+  tt.func public @propagate_mma_to_atomic_rmw(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) attributes {noinline = false} {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c4096_i32 = arith.constant 4096 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %cst = arith.constant dense<4096> : tensor<256xi32, #blocked>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked2>
+    %0 = tt.get_program_id x : i32
+    %1 = tt.get_program_id y : i32
+    // CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
+    // CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+    %12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #blocked3>>
+    %14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<32x256xbf16, #blocked2>>
+    // CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>) : i32 {
+    %15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst_1, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>) : i32 {
+      %47 = tt.load %arg5 : !tt.ptr<tensor<256x32xbf16, #blocked3>>
+      %48 = tt.load %arg6 : !tt.ptr<tensor<32x256xbf16, #blocked2>>
+      // CHECK-NOT: triton_gpu.convert_layout
+      %49 = triton_gpu.convert_layout %arg4 : tensor<256x256xf32, #blocked2> -> tensor<256x256xf32, #mma>
+      %50 = triton_gpu.convert_layout %47 : tensor<256x32xbf16, #blocked3> -> tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %51 = triton_gpu.convert_layout %48 : tensor<32x256xbf16, #blocked2> -> tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %52 = tt.dot %50, %51, %49, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      %53 = triton_gpu.convert_layout %52 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked2>
+      // CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
+      // CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+      // CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+      %54 = tt.advance %arg5, [%c0_i32, %c128_i32] : <tensor<256x32xbf16, #blocked3>>
+      %55 = tt.advance %arg6, [%c128_i32, %c0_i32] : <tensor<32x256xbf16, #blocked2>>
+      scf.yield %53, %54, %55 : tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>
+    }
+    %16 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
+    %32 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #blocked2>
+    %38 = arith.cmpi slt, %16, %cst : tensor<256xi32, #blocked>
+    // CHECK-NOT: triton_gpu.convert_layout
+    %39 = triton_gpu.convert_layout %38 : tensor<256xi1, #blocked> -> tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>
+    %40 = tt.expand_dims %39 {axis = 0 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x256xi1, #blocked4>
+    %41 = triton_gpu.convert_layout %40 : tensor<1x256xi1, #blocked4> -> tensor<1x256xi1, #blocked2>
+    %42 = tt.broadcast %41 : tensor<1x256xi1, #blocked2> -> tensor<256x256xi1, #blocked2>
+    // CHECK: %[[VAL_5:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, {{.*}} : (tensor<256x256x!tt.ptr<f32>, #[[$DPAS]]>, tensor<256x256xf32, #[[$DPAS]]>, tensor<256x256xi1, #[[$DPAS]]>) -> tensor<256x256xf32, #[[$DPAS]]>
+    %46 = tt.atomic_rmw fadd, acq_rel, gpu, %32, %15#0, %42 : (tensor<256x256x!tt.ptr<f32>, #blocked2>, tensor<256x256xf32, #blocked2>, tensor<256x256xi1, #blocked2>) -> tensor<256x256xf32, #blocked2>
+    tt.return
+  }
+}
+
+
+// -----
+
+// COM: Check that bare atomic_rmw op with blocked layout can still be propagated to dpas layout
+// COM: Blocked layout will not backpropagate to overwrite dpas layout
+// CHECK-NOT: #triton_gpu.blocked<{.*}>
+// CHECK: #[[$DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: tt.func public @bare_atomic_with_blocked_layout
+  tt.func public @bare_atomic_with_blocked_layout(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
+    %cst_0 = arith.constant dense<3072> : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %c4096_i32 = arith.constant 4096 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = tt.get_program_id y : i32
+    %12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>
+    %14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+    %15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #mma>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>) : i32 {
+      %41 = tt.advance %arg5, [%c0_i32, %c128_i32] : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>
+      %42 = tt.advance %arg6, [%c128_i32, %c0_i32] : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+      %43 = tt.load %arg5 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>
+      %44 = tt.load %arg6 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+      %45 = tt.dot %43, %44, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      scf.yield %45, %41, %42 : tensor<256x256xf32, #mma>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+    }
+    %18 = tt.splat %0 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %28 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #mma>
+    %30 = arith.cmpi slt, %18, %cst_0 : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+    %31 = tt.expand_dims %30 {axis = 1 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<256x1xi1, #mma>
+    %34 = tt.broadcast %31 : tensor<256x1xi1, #mma> -> tensor<256x256xi1, #mma>
+    // CHECK-NOT: triton_gpu.convert_layout
+    %37 = triton_gpu.convert_layout %28 : tensor<256x256x!tt.ptr<f32>, #mma> -> tensor<256x256x!tt.ptr<f32>, #blocked>
+    %38 = triton_gpu.convert_layout %15#0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+    %39 = triton_gpu.convert_layout %34 : tensor<256x256xi1, #mma> -> tensor<256x256xi1, #blocked>
+    // CHECK: %[[VAL_0:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, {{.*}} : (tensor<256x256x!tt.ptr<f32>, #[[$DPAS]]>, tensor<256x256xf32, #[[$DPAS]]>, tensor<256x256xi1, #[[$DPAS]]>) -> tensor<256x256xf32, #[[$DPAS]]>
+    %40 = tt.atomic_rmw fadd, acq_rel, gpu, %37, %38, %39 : (tensor<256x256x!tt.ptr<f32>, #blocked>, tensor<256x256xf32, #blocked>, tensor<256x256xi1, #blocked>) -> tensor<256x256xf32, #blocked>
+    // CHECK-NOT: triton_gpu.convert_layout
+    %41 = triton_gpu.convert_layout %40 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
+    tt.return
+  }
+}

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 33 additions & 1 deletion
@@ -13,6 +13,7 @@
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
 
 #include "triton/Analysis/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 
 namespace mlir::triton::gpu::intel {
@@ -252,6 +253,19 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
           }
         }
       }
+
+      // HACK: we want to propagate mma layout to the atomic_rmw op, so we do
+      // not need an extra ConvertLayout Op to convert layout from mma to other
+      // layouts, which may consume excessive shared local memory.
+      // TODO: we need to investigate the performance impact of atomic_rmw op
+      // with mma layout, compared with ConvertLayout Op + atomic_rmw op with
+      // blocked layout.
+      if (auto atomicOp = dyn_cast<AtomicRMWOp>(op)) {
+        auto tensorType =
+            dyn_cast<RankedTensorType>(atomicOp.getResult().getType());
+        if (tensorType && isa<MmaEncodingTrait>(tensorType.getEncoding()))
+          return true;
+      }
       bool isMMAV3 =
           isa<NvidiaMmaEncodingAttr>(encoding) &&
          cast<NvidiaMmaEncodingAttr>(encoding).getVersionMajor() == 3;
@@ -272,6 +286,11 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
       auto forOp = dyn_cast<scf::ForOp>(yield.getOperation()->getParentOp());
       if (!forOp)
         continue;
+      for (OpOperand &operand : forOp.getResult(0).getUses()) {
+        Operation *def = operand.get().getDefiningOp();
+        if (def && (seen.insert(operand.get()).second == true))
+          queue.push_back(operand.get());
+      }
       for (OpOperand &operand : yield->getOpOperands()) {
         Operation *def = operand.get().getDefiningOp();
         if (def && (forwardSlice.count(def) || operand.get() == currentValue) &&
@@ -288,8 +307,12 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
 bool isLayoutAnchor(Operation *op) {
   if (isa<LoadOp, StoreOp>(op))
     return ttgi::isExpensiveLoadOrStore(op);
-  if (isa<DotOp, AtomicRMWOp, AtomicCASOp>(op))
+  if (isa<DotOp, AtomicCASOp>(op))
     return true;
+  if (isa<AtomicRMWOp>(op))
+    if (auto tensorType =
+            dyn_cast<RankedTensorType>(op->getResult(0).getType()))
+      return isa<MmaEncodingTrait>(tensorType.getEncoding());
 
   // Heuristic: Mark permuting reshape as a layout anchor. Its dst can be
   // anything, so it stops forward-propagation of layouts. We rely on the
@@ -402,6 +425,15 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
       setEncoding({afterArg, result}, info, changed, user);
       continue;
     }
+    if (auto atomicRMWOp = dyn_cast<AtomicRMWOp>(user)) {
+      bool isBlockedOrMma = std::all_of(
+          info.encodings.begin(), info.encodings.end(), [](Attribute encoding) {
+            return isa<BlockedEncodingAttr, MmaEncodingTrait>(encoding);
+          });
+      if (isBlockedOrMma)
+        setEncoding(user->getResults(), info, changed, user);
+      continue;
+    }
     if (user->hasTrait<OpTrait::SameOperandsAndResultEncoding>() ||
         user->hasTrait<OpTrait::Elementwise>() ||
         isa<ReduceOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp,
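
For context on the scf.for handling added above: in the targeted kernels the accumulator that reaches the atomic_rmw op is a loop result rather than a direct dot result, so the transitive-use walk has to follow values out of the loop. A rough sketch of that shape, condensed from the test above (value names illustrative, types abbreviated with "..."):

%out:3 = scf.for ... iter_args(%acc = %cst, ...) -> (tensor<256x256xf32, #mma>, ...) {
  %d = tt.dot %a, %b, %acc ... -> tensor<256x256xf32, #mma>
  scf.yield %d, ... : tensor<256x256xf32, #mma>, ...
}
// The use of %out#0 outside the loop is what the updated
// hasConvertToMMATransisitiveUse now enqueues and walks.
%res = tt.atomic_rmw fadd, acq_rel, gpu, %ptrs, %out#0, %mask : (...) -> tensor<256x256xf32, #mma>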
