Commit 3c0eae6
[FlexAttention] Propagate MMA layout to tl.store to remove layout conversion (#4354)
This change propagates the DPAS layout to the store op when the matmul works with tensor pointers.

Co-authored-by: Lu,Chengjun <[email protected]>
1 parent: 0cc2c18
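In effect, the pass now lets the DPAS (MMA) accumulator layout produced by tt.dot flow all the way into tt.store instead of being converted back to a blocked layout first. A minimal before/after sketch of the rewrite, distilled from the test below (the SSA names %acc, %res, %cvt, %ptrs, and %mask are illustrative, not taken from the source):

    // Before: the accumulator is converted out of the DPAS layout before the store.
    %res = arith.truncf %acc : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
    %cvt = ttg.convert_layout %res : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked1>
    tt.store %ptrs, %cvt, %mask : tensor<256x256x!tt.ptr<f16>, #blocked1>

    // After: the store consumes the DPAS layout directly, so the conversion disappears.
    %res = arith.truncf %acc : tensor<256x256xf32, #mma> to tensor<256x256xf16, #mma>
    tt.store %ptrs, %res, %mask : tensor<256x256x!tt.ptr<f16>, #mma>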

File tree: 2 files changed (+129, -3 lines)


test/TritonIntelGPU/combine.mlir

Lines changed: 70 additions & 0 deletions
@@ -2601,3 +2601,73 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+// COM: Test that the DPAS layout is propagated to the store operation with tensor pointers.
+// CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
+  // CHECK-LABEL: matmul_kernel_with_tensor_pointer
+  tt.func public @matmul_kernel_with_tensor_pointer(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked>
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x256xf16, #blocked1>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x32xf16, #blocked2>
+    %cst_2 = arith.constant dense<32> : tensor<256x32xi32, #blocked2>
+    %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %18 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
+    %19 = tt.expand_dims %18 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x32xi32, #blocked2>
+    %23 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x32x!tt.ptr<f16>, #blocked2>
+    %26 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %28 = tt.expand_dims %26 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1>
+    %35 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x256x!tt.ptr<f16>, #blocked1>
+    %40 = tt.splat %arg5 : i32 -> tensor<32x256xi32, #blocked1>
+    %41:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %23, %arg12 = %35) -> (tensor<256x256xf32, #blocked>, tensor<256x32x!tt.ptr<f16>, #blocked2>, tensor<32x256x!tt.ptr<f16>, #blocked1>) : i32 {
+      %62 = tt.splat %arg9 : i32 -> tensor<1x32xi32, #blocked2>
+      %63 = arith.cmpi slt, %19, %62 : tensor<1x32xi32, #blocked2>
+      %64 = tt.broadcast %63 : tensor<1x32xi1, #blocked2> -> tensor<256x32xi1, #blocked2>
+      %65 = tt.load %arg11, %64, %cst_1 : tensor<256x32x!tt.ptr<f16>, #blocked2>
+      %66 = tt.splat %arg5 : i32 -> tensor<32x1xi32, #blocked1>
+      %67 = arith.cmpi slt, %28, %66 : tensor<32x1xi32, #blocked1>
+      %68 = tt.broadcast %67 : tensor<32x1xi1, #blocked1> -> tensor<32x256xi1, #blocked1>
+      %69 = tt.load %arg12, %68, %cst_0 : tensor<32x256x!tt.ptr<f16>, #blocked1>
+      %70 = ttg.convert_layout %arg10 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
+      %71 = ttg.convert_layout %65 : tensor<256x32xf16, #blocked2> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+      %72 = ttg.convert_layout %69 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %73 = tt.dot %71, %72, %70, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      %74 = ttg.convert_layout %73 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+      %75 = tt.addptr %arg11, %cst_2 : tensor<256x32x!tt.ptr<f16>, #blocked2>, tensor<256x32xi32, #blocked2>
+      %76 = tt.addptr %arg12, %40 : tensor<32x256x!tt.ptr<f16>, #blocked1>, tensor<32x256xi32, #blocked1>
+      scf.yield %74, %75, %76 : tensor<256x256xf32, #blocked>, tensor<256x32x!tt.ptr<f16>, #blocked2>, tensor<32x256x!tt.ptr<f16>, #blocked1>
+    }
+    // CHECK: [[SCF:%.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, {{.*}}) : i32 {
+    // CHECK: tt.expand_dims {{.*}} -> tensor<256x1xi32, #[[$DPAS]]>
+    // CHECK-NOT: ttg.convert_layout
+    // CHECK: [[RES:%.*]] = arith.truncf [[SCF]]#0 : tensor<256x256xf32, #[[$DPAS]]> to tensor<256x256xf16, #[[$DPAS]]>
+    // CHECK: tt.store {{.*}}, [[RES]], {{.*}} : tensor<256x256x!tt.ptr<f16>, #[[$DPAS]]>
+    %42 = tt.expand_dims %3 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
+    %45 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
+    %46 = tt.addptr %45, %42 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
+    %47 = tt.expand_dims %4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1>
+    %48 = tt.broadcast %46 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x256x!tt.ptr<f16>, #blocked1>
+    %49 = tt.broadcast %47 : tensor<1x256xi32, #blocked1> -> tensor<256x256xi32, #blocked1>
+    %50 = tt.addptr %48, %49 : tensor<256x256x!tt.ptr<f16>, #blocked1>, tensor<256x256xi32, #blocked1>
+    %51 = tt.splat %arg3 : i32 -> tensor<256x1xi32, #blocked1>
+    %52 = arith.cmpi slt, %42, %51 : tensor<256x1xi32, #blocked1>
+    %53 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #blocked1>
+    %54 = arith.cmpi slt, %47, %53 : tensor<1x256xi32, #blocked1>
+    %55 = tt.broadcast %52 : tensor<256x1xi1, #blocked1> -> tensor<256x256xi1, #blocked1>
+    %56 = tt.broadcast %54 : tensor<1x256xi1, #blocked1> -> tensor<256x256xi1, #blocked1>
+    %57 = arith.andi %55, %56 : tensor<256x256xi1, #blocked1>
+    %58 = arith.truncf %41#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
+    %59 = ttg.convert_layout %58 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked1>
+    tt.store %50, %59, %57 : tensor<256x256x!tt.ptr<f16>, #blocked1>
+    tt.return
+  }
+}
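The CHECK lines pin down the rewritten epilogue: the scf.for results stay in the DPAS layout, the truncation happens in that layout, and no ttg.convert_layout survives between the loop and the masked store. After the pass, the tail of this kernel should look roughly like the following (a hand-written approximation of what FileCheck matches, reusing the test's SSA names; #mma stands for the #ttig.dpas encoding bound to #[[$DPAS]]):

    %58 = arith.truncf %41#0 : tensor<256x256xf32, #mma> to tensor<256x256xf16, #mma>
    tt.store %50, %58, %57 : tensor<256x256x!tt.ptr<f16>, #mma>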

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 59 additions & 3 deletions
@@ -97,6 +97,7 @@ class LayoutPropagation {
   void rewriteAssertOp(AssertOp assertOp);
   // Rewrite a StoreOp with the forwarded DPAS layout if applicable.
   // return true if the StoreOp has been rewritten.
+  bool rewriteTensorPtrStoreOp(StoreOp storeOp);
   bool rewriteStoreOp(StoreOp storeOp);
   Operation *cloneElementwise(OpBuilder &rewriter, Operation *op,
                               Attribute encoding);
@@ -239,7 +240,7 @@ void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info,
   bool hasChanged = false;
   for (auto encoding : info.encodings) {
     Attribute dstEncoding;
-    if (isa<ConvertLayoutOp>(op)) {
+    if (isa<StoreOp, ConvertLayoutOp>(op)) {
       // Try to remove the convert by making the dst encoding match the source
       // encoding.
       dstEncoding = encoding;
@@ -303,6 +304,28 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
       setEncoding(user->getResults(), info, changed, user);
       continue;
     }
+    if (auto storeOp = dyn_cast<StoreOp>(user)) {
+      auto checkMMAorMMADerived = [](Attribute encoding) {
+        bool isMMAorMMADerived = isa<MmaEncodingTrait>(encoding);
+        if (isa<SliceEncodingAttr>(encoding)) {
+          isMMAorMMADerived |= isa<MmaEncodingTrait>(
+              cast<SliceEncodingAttr>(encoding).getParent());
+        } else if (isa<DotOperandEncodingAttr>(encoding)) {
+          isMMAorMMADerived |= isa<MmaEncodingTrait>(
+              cast<DotOperandEncodingAttr>(encoding).getParent());
+        }
+        return isMMAorMMADerived;
+      };
+      if (llvm::all_of(info.encodings, checkMMAorMMADerived)) {
+        if (storeOp.getMask())
+          setEncoding({storeOp.getPtr(), storeOp.getValue(), storeOp.getMask()},
+                      info, changed, user);
+        else
+          setEncoding({storeOp.getPtr(), storeOp.getValue()}, info, changed,
+                      user);
+      }
+      continue;
+    }
     if (user->hasTrait<OpTrait::SameOperandsAndResultEncoding>() ||
         user->hasTrait<OpTrait::Elementwise>() ||
         isa<ReduceOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp,
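The checkMMAorMMADerived lambda in the hunk above accepts an encoding if it is itself an MMA encoding (Intel's DPAS encoding implements MmaEncodingTrait, which is what lets this propagation fire), or if it is a slice or dot-operand encoding whose parent is MMA. Sketched with encodings shaped like those in the test above (the #names and elided parameters are illustrative):

    #mma     = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, ...}>  // MMA encoding: accepted
    #sliced  = #ttg.slice<{dim = 0, parent = #mma}>                                       // slice of MMA: accepted
    #dot_a   = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>                        // dot operand of MMA: accepted
    #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], ...}>      // blocked: rejected

Only when every candidate encoding passes this test does the pass forward the layout to the store's pointer, value, and (if present) mask operands; otherwise the store is left alone and any needed conversion stays explicit.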
@@ -421,7 +444,6 @@ void LayoutPropagation::rewriteRegion(Region &region) {
         if (rewriteStoreOp(storeOp))
          continue;
       }
-
      // If we don't need to rewrite the op we still need to remap the
      // operands.
      for (OpOperand &operand : op.getOpOperands()) {
@@ -719,7 +741,7 @@ static void updateAdvanceOpChain(AdvanceOp advanceOp, StoreOp storeOp,
   }
 }
 
-bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
+bool LayoutPropagation::rewriteTensorPtrStoreOp(StoreOp storeOp) {
   // Disable 2D block store on LTS.
   if (!storeOp->getParentOfType<ModuleOp>()->hasAttr(
           ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName()))
@@ -831,6 +853,40 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
   return true;
 }
 
+bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
+  if (rewriteTensorPtrStoreOp(storeOp))
+    return true;
+
+  Operation *op = storeOp.getOperation();
+  llvm::MutableArrayRef<OpOperand> operands = op->getOpOperands();
+  // Check if all store op operands should use new encoding.
+  bool usesNewEncoding = llvm::all_of(operands, [&](OpOperand &operand) {
+    auto it = layouts.find(operand.get());
+    if (it == layouts.end())
+      return false;
+    LayoutInfo &info = it->second;
+    assert(info.encodings.size() == 1 &&
+           "we should have resolved to a single encoding");
+    auto encoding =
+        cast<RankedTensorType>(operand.get().getType()).getEncoding();
+    return encoding != *info.encodings.begin();
+  });
+  if (usesNewEncoding) {
+    for (OpOperand &operand : op->getOpOperands()) {
+      auto it = layouts.find(operand.get());
+      Attribute encoding =
+          cast<RankedTensorType>(operand.get().getType()).getEncoding();
+      LayoutInfo &info = it->second;
+      encoding = info.encodings[0];
+      Value newOperand = getValueAs(operand.get(), encoding);
+      op->setOperand(operand.getOperandNumber(), newOperand);
+    }
+    return true;
+  }
+
+  return false;
+}
+
 Operation *LayoutPropagation::rewriteOp(Operation *op) {
   opToDelete.insert(op);
   if (auto forOp = dyn_cast<scf::ForOp>(op))
