Commit febe2a1
[ScaledDot] Remove SinkTranspose from ScaledDot (triton-lang#5653)
We remove the SinkTranspose transformation. It was initially put in place to circumvent the issue of not being able to propagate the MMA layout past a transpose. Support for that propagation landed in triton-lang#5403, so this pass is no longer necessary. The next step will be to get rid of the `transposeDot` part of the pass and instead integrate it into a different, more generic pass that checks whether a dot op's inputs should be transposed to take advantage of the reg x shmem MMAv3 op.
1 parent 41ecd1c commit febe2a1
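Editor's note: for illustration, a minimal sketch of the rewrite that the remaining `transposeDot` step performs. The IR below is schematic, not verbatim compiler output: value names are invented, most types are elided as `...`, and the exact `tt.dot_scaled` assembly for the lhs-scale form is an assumption modeled on the rhs-scale form in the tests of this diff.

// Before: the scale sits on the rhs operand (shapes/layouts elided).
%d = tt.dot_scaled %a, %b scale %s, %acc lhs = e4m3 rhs = e2m1 {fastMath = false} : ...

// After: using (A x B)^T = B^T x A^T, the operands are transposed and
// swapped so the scaled operand becomes the lhs, and a trailing transpose
// restores the original orientation of the result.
%bT   = tt.trans %b   {order = array<i32: 1, 0>} : ...
%aT   = tt.trans %a   {order = array<i32: 1, 0>} : ...
%accT = tt.trans %acc {order = array<i32: 1, 0>} : ...
%dT   = tt.dot_scaled %bT scale %s, %aT, %accT lhs = e2m1 rhs = e4m3 {fastMath = false} : ...
%d    = tt.trans %dT {order = array<i32: 1, 0>} : ...

Previously, `sinkTransposeOp` pushed that trailing `tt.trans` past elementwise ops, `ttg.convert_layout`, and `scf.for` yields by hand; with triton-lang#5403 the generic layout propagation removes the resulting redundant converts instead, as the new combine.mlir test below checks.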

File tree

3 files changed: +49, -146 lines
lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 2 additions & 111 deletions
@@ -656,110 +656,8 @@ class DecomposeScaledBlocked
   }
 };
 
-static void updateValueType(Value v, Attribute encoding,
-                            ArrayRef<int64_t> shape) {
-  auto tensorType = cast<RankedTensorType>(v.getType());
-  auto newType =
-      RankedTensorType::get(shape, tensorType.getElementType(), encoding);
-  v.setType(newType);
-}
-
-static TransOp updateUsers(Value result, const SetVector<Operation *> &slice) {
-  TransOp transOp;
-  if (llvm::any_of(result.getUsers(),
-                   [&](Operation *user) { return slice.count(user) == 0; })) {
-    OpBuilder builder(result.getContext());
-    builder.setInsertionPointAfterValue(result);
-    transOp =
-        builder.create<TransOp>(result.getLoc(), result, ArrayRef({1, 0}));
-    result.replaceUsesWithIf(transOp.getResult(), [&](OpOperand &operand) {
-      return operand.getOwner() != transOp.getOperation() &&
-             slice.count(operand.getOwner()) == 0;
-    });
-  }
-  return transOp;
-}
-
-// Sync the transpose in the IR, this is done to avoid generating convert layout
-// when we have a transpose right after a dot as mma layout cannot be propagated
-// through transpose op. Once we have layouts that can represent transposed MMA
-// we can remove this transformation.
-static void sinkTransposeOp(TransOp input) {
-  SmallVector<TransOp> queue = {input};
-  while (!queue.empty()) {
-    TransOp transOp = queue.back();
-    Value currentValue = transOp.getResult();
-    queue.pop_back();
-    mlir::ForwardSliceOptions options;
-    options.filter = [](Operation *op) {
-      if (op->hasTrait<OpTrait::Elementwise>() && op->getNumOperands() == 1)
-        return true;
-      if (isa<scf::YieldOp>(op))
-        return isa<scf::ForOp>(op->getParentOp());
-      if (isa<ConvertLayoutOp>(op))
-        return true;
-      return false;
-    };
-    SetVector<Operation *> slice;
-    mlir::getForwardSlice(currentValue, &slice, options);
-    for (Operation *op : slice) {
-      if (op->hasTrait<OpTrait::Elementwise>()) {
-        // Update users of transpose op.
-        if (op->getOperand(0) == transOp.getResult())
-          op->setOperand(0, transOp.getOperand());
-        // Update the type of the result.
-        for (Value result : op->getResults()) {
-          auto srcType = cast<RankedTensorType>(op->getOperand(0).getType());
-          updateValueType(result, srcType.getEncoding(), srcType.getShape());
-          updateUsers(result, slice);
-        }
-        continue;
-      }
-      if (auto cvtOp = dyn_cast<ConvertLayoutOp>(op)) {
-        // Update users of transpose op.
-        if (op->getOperand(0) == transOp.getResult())
-          op->setOperand(0, transOp.getOperand());
-        auto resultEncoding = cvtOp.getType().getEncoding();
-        auto newDstEncoding = inferSrcEncoding(transOp, resultEncoding);
-        assert(newDstEncoding);
-        auto srcType = cast<RankedTensorType>(cvtOp.getOperand().getType());
-        updateValueType(cvtOp.getResult(), newDstEncoding, srcType.getShape());
-        updateUsers(cvtOp.getResult(), slice);
-        continue;
-      }
-      assert(isa<scf::YieldOp>(op));
-      auto forOp = dyn_cast<scf::ForOp>(op->getParentOp());
-      assert(forOp);
-      for (OpOperand &operand : op->getOpOperands()) {
-        Operation *def = operand.get().getDefiningOp();
-        if (def && (slice.count(def)) || def == transOp.getOperation()) {
-          if (def == transOp.getOperation())
-            operand.set(transOp.getOperand());
-          Type newType = operand.get().getType();
-          forOp.getResult(operand.getOperandNumber()).setType(newType);
-          TransOp retTrans =
-              updateUsers(forOp.getResult(operand.getOperandNumber()), slice);
-          // Recursively try to propagate the new transpose inserted.
-          if (retTrans)
-            queue.push_back(retTrans);
-          forOp.getRegionIterArg(operand.getOperandNumber()).setType(newType);
-          TransOp argTrans = updateUsers(
-              forOp.getRegionIterArg(operand.getOperandNumber()), slice);
-          if (argTrans)
-            queue.push_back(argTrans);
-          OpBuilder builder(forOp);
-          OpOperand &init = forOp.getInitsMutable()[operand.getOperandNumber()];
-          Value initTranspose = builder.create<TransOp>(
-              forOp.getLoc(), init.get(), ArrayRef({1, 0}));
-          init.set(initTranspose);
-        }
-      }
-    }
-  }
-}
-
 // Transpose scaled_dot ops that have a scale on lhs.
-static Operation *transposeDotOp(DotScaledOp dotOp) {
+static void transposeDotOp(DotScaledOp dotOp) {
   OpBuilder builder(dotOp);
   Value lhs = dotOp.getLhs();
   std::array<int, 2> transOrder = {1, 0};

@@ -776,7 +674,6 @@ static Operation *transposeDotOp(DotScaledOp dotOp) {
       builder.create<TransOp>(result.getLoc(), result, transOrder);
   dotOp.replaceAllUsesWith(transposedResult);
   dotOp.erase();
-  return transposedResult;
 }
 
 static void transposeDots(ModuleOp m) {

@@ -787,14 +684,8 @@ static void transposeDots(ModuleOp m) {
     if (dotOp.getLhsScale() == nullptr && dotOp.getRhsScale() != nullptr)
       toTranspose.push_back(dotOp);
   });
-  SmallVector<Operation *> transposes;
   for (DotScaledOp dotOp : toTranspose) {
-    Operation *transpose = transposeDotOp(dotOp);
-    transposes.push_back(transpose);
-  }
-
-  for (Operation *transpose : transposes) {
-    sinkTransposeOp(cast<TransOp>(transpose));
+    transposeDotOp(dotOp);
   }
 }
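Editor's note: taken together, the three hunks above leave a much smaller driver. For reference, here is `transposeDots` as it can be reconstructed from the context and added lines of this diff; the first two statements of the body fall outside the shown hunks and are inferred, so treat this as a sketch rather than the verbatim post-commit source.

static void transposeDots(ModuleOp m) {
  // Collect scaled dots whose scale is on the rhs only (the walk itself is
  // not shown in the hunks above and is inferred from their context).
  SmallVector<DotScaledOp> toTranspose;
  m.walk([&](DotScaledOp dotOp) {
    if (dotOp.getLhsScale() == nullptr && dotOp.getRhsScale() != nullptr)
      toTranspose.push_back(dotOp);
  });
  // Rewrite each dot in place; there is no longer a transpose to sink.
  for (DotScaledOp dotOp : toTranspose) {
    transposeDotOp(dotOp);
  }
}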

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 0 additions & 35 deletions
@@ -226,38 +226,3 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
     tt.return %result : tensor<128x128xf32, #blocked>
   }
 }
-
-// -----
-
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-  // CHECK-LABEL: dot_scale_transpose
-  tt.func public @dot_scale_transpose(%arg0: tensor<128x64xf8E4M3FN, #blocked>, %arg1: tensor<32x32xi8, #blocked1>, %arg2: tensor<32x2xi8, #blocked2>, %arg3: tensor<128x32x!tt.ptr<bf16>, #blocked3>) {
-    %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked1>
-    %c1_i32 = arith.constant 1 : i32
-    %c100_i32 = arith.constant 100 : i32
-    %c0_i32 = arith.constant 0 : i32
-    %cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked3>
-    %cst_1 = arith.constant dense<2> : tensor<32x1xi32, #blocked2>
-    // CHECK: scf.for
-    %0 = scf.for %arg4 = %c0_i32 to %c100_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<128x32xf32, #blocked1>) : i32 {
-      // CHECK-DAG: tt.trans %{{.*}} {order = array<i32: 1, 0>} : tensor<128x64xf8E4M3FN, #{{.*}}> -> tensor<64x128xf8E4M3FN, #{{.*}}>
-      // CHECK-DAG: tt.trans %a{{.*}} {order = array<i32: 1, 0>} : tensor<32x32xi8, #{{.*}}> -> tensor<32x32xi8, #{{.*}}>
-      %3 = tt.dot_scaled %arg0, %arg1 scale %arg2, %arg5 lhs = e4m3 rhs = e2m1 {fastMath = false}: tensor<128x64xf8E4M3FN, #blocked> * tensor<32x32xi8, #blocked1>, tensor<32x2xi8, #blocked2> -> tensor<128x32xf32, #blocked1>
-      // CHECK: tt.dot
-      // CHECK-NOT: tt.trans
-      // CHECK: scf.yield
-      scf.yield %3 : tensor<128x32xf32, #blocked1>
-    }
-    // CHECK: arith.truncf
-    // CHECK: ttg.convert_layout
-    // CHECK: tt.trans
-    %1 = arith.truncf %0 : tensor<128x32xf32, #blocked1> to tensor<128x32xbf16, #blocked1>
-    %2 = ttg.convert_layout %1 : tensor<128x32xbf16, #blocked1> -> tensor<128x32xbf16, #blocked3>
-    tt.store %arg3, %2 : tensor<128x32x!tt.ptr<bf16>, #blocked3>
-    tt.return
-  }
-}

test/TritonGPU/combine.mlir

Lines changed: 47 additions & 0 deletions
@@ -2848,3 +2848,50 @@ tt.func @reduce_linear_layouts(%arg0: tensor<32x32xi32, #linear>) -> tensor<32xi
 }
 
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
+#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+#linear = #ttg.linear<{register = [[16, 0]], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+
+  // Test that after dot_scaled with rhs scales is decomposed, we are able to get rid of the redundant convert_layout
+  // CHECK-LABEL: dot_scale_transpose
+  tt.func public @dot_scale_transpose(%arg0: tensor<128x64xf8E4M3FN, #blocked>, %arg1: tensor<32x32xi8, #blocked1>, %arg2: tensor<32x2xi8, #blocked2>, %arg3: tensor<128x32x!tt.ptr<bf16>, #blocked3>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked1>
+    %c1_i32 = arith.constant 1 : i32
+    %c100_i32 = arith.constant 100 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %0 = scf.for %arg4 = %c0_i32 to %c100_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<128x32xf32, #blocked1>) : i32 {
+      %3 = tt.trans %arg0 {order = array<i32: 1, 0>} : tensor<128x64xf8E4M3FN, #blocked> -> tensor<64x128xf8E4M3FN, #blocked4>
+      %4 = tt.trans %arg1 {order = array<i32: 1, 0>} : tensor<32x32xi8, #blocked1> -> tensor<32x32xi8, #blocked5>
+      %5 = tt.trans %arg5 {order = array<i32: 1, 0>} : tensor<128x32xf32, #blocked1> -> tensor<32x128xf32, #blocked5>
+      %6 = ttg.convert_layout %5 : tensor<32x128xf32, #blocked5> -> tensor<32x128xf32, #mma>
+      %7 = ttg.convert_layout %4 : tensor<32x32xi8, #blocked5> -> tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %8 = ttg.convert_layout %arg2 : tensor<32x2xi8, #blocked2> -> tensor<32x2xi8, #linear>
+      %9 = ttg.upcast_mxfp %7, %8 fp_type = e2m1 {fastMath = false} : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+      %10 = ttg.convert_layout %3 : tensor<64x128xf8E4M3FN, #blocked4> -> tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+      %11 = tt.fp_to_fp %10 : tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+      %12 = tt.dot %9, %11, %6 : tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<32x128xf32, #mma>
+      // CHECK: tt.dot
+      // CHECK-NOT: ttg.convert_layout
+      // CHECK: scf.yield
+      %13 = ttg.convert_layout %12 : tensor<32x128xf32, #mma> -> tensor<32x128xf32, #blocked5>
+      %14 = tt.trans %13 {order = array<i32: 1, 0>} : tensor<32x128xf32, #blocked5> -> tensor<128x32xf32, #blocked1>
+      scf.yield %14 : tensor<128x32xf32, #blocked1>
+    }
+    // CHECK: arith.truncf
+    // CHECK-NEXT: ttg.convert_layout
+    // CHECK-NEXT: tt.store
+    %1 = arith.truncf %0 : tensor<128x32xf32, #blocked1> to tensor<128x32xbf16, #blocked1>
+    %2 = ttg.convert_layout %1 : tensor<128x32xbf16, #blocked1> -> tensor<128x32xbf16, #blocked3>
+    tt.store %arg3, %2 : tensor<128x32x!tt.ptr<bf16>, #blocked3>
+    tt.return
+  }
+}
