
Commit 1607e09

[Gluon] Infer slice encoding for SplitOp result (#7247)
Using a slice results in the same underlying register layout, but means that split->join can round-trip to infer the original layout. However, this comes at the cost of breaking join->split round-tripping.
1 parent 056ad7f commit 1607e09
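
For context, the round-trip described above looks like this at the Gluon level. This is a minimal sketch adapted from the test_split_join_subtile test added below in python/test/gluon/test_frontend.py (imports, gluon.jit, and the ttgl names are as in that file; the kernel name is illustrative):

    @gluon.jit
    def roundtrip_kernel():
        layout: ttgl.constexpr = ttgl.BlockedLayout([1, 128], [32, 1], [4, 1], [0, 1])
        x = ttgl.full([128, 128], 1, ttgl.int32, layout=layout)

        # split now returns values encoded with a slice of the 3-D input layout...
        a, b = x.reshape([128, 2, 64]).permute([0, 2, 1]).split()
        # ...so join can infer the original 3-D layout back, and the permute/reshape
        # below recover `layout` without inserting a convert_layout.
        y = ttgl.join(a, b).permute([0, 2, 1]).reshape([128, 128])
        _ = x + y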

File tree

lib/Dialect/TritonGPU/IR/Dialect.cpp
lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
python/src/gluon_ir.cc
python/test/gluon/test_frontend.py
test/TritonGPU/combine.mlir

5 files changed: +131 −31 lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 22 additions & 5 deletions

@@ -20,7 +20,6 @@
 #include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
-#include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/MathExtras.h"
@@ -442,6 +441,15 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
   return encoding;
 }

+bool isSplitCompatible(MLIRContext *ctx, const LinearLayout &ll) {
+  auto lastDim = ll.getNumOutDims() - 1;
+  auto kReg = StringAttr::get(ctx, "register");
+  auto kLastDim = StringAttr::get(ctx, "dim" + std::to_string(lastDim));
+  auto sublayout =
+      ll.sublayout({kReg}, {kLastDim}).removeZeroBasesAlongDim(kReg);
+  return sublayout == LinearLayout::identity1D(2, kReg, kLastDim);
+}
+
 LogicalResult tryJoinOnAxis(MLIRContext *ctx, const LinearLayout &inLl,
                             LinearLayout &outLl, bool fwdInference, int axis,
                             std::optional<Location> loc) {
@@ -2626,7 +2634,19 @@ struct TritonGPUInferLayoutInterface
   inferDefaultJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                              ArrayRef<int64_t> shape,
                              std::optional<Location> loc) const override {
-    if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc)) {
+    auto ctx = getContext();
+    if (auto enc = mlir::dyn_cast<SliceEncodingAttr>(srcEnc);
+        enc && enc.getDim() == shape.size()) {
+      SmallVector<int64_t> joinedShape(shape);
+      joinedShape.push_back(2);
+      auto parent = enc.getParent();
+      auto parentLL = toLinearLayout(joinedShape, parent);
+
+      if (isSplitCompatible(ctx, parentLL)) {
+        dstEnc = parent;
+        return success();
+      }
+    } else if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc)) {
       // JoinOp takes two tensors of shape AxBxC and generates a tensor of shape
       // AxBxCx2. The encoding is the same as the input, but with 2 elems per
       // thread in the new dimension. The new dimension is the fastest running
@@ -2651,8 +2671,6 @@ struct TritonGPUInferLayoutInterface
       return success();
     }

-    auto ctx = getContext();
-
     // Append dim to shape
     auto ll = toLinearLayout(shape, srcEnc);
     SmallVector<int64_t> dstShape(shape.begin(), shape.end());
@@ -2729,7 +2747,6 @@ struct TritonGPUInferLayoutInterface
     if (!result.succeeded()) {
       return failure();
     }
-
     // Remove last dim from newLl (which should be 1)
     SmallVector<int64_t> dstShape(shape.begin(), shape.end());
     dstShape.pop_back();
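
Informally, the new isSplitCompatible helper accepts a layout for the joined ...x2 shape only when the size-2 last dimension is addressed by a single register basis, i.e. both halves of each pair live in the same thread's registers; otherwise inference falls through to the generic linear-layout path that follows the if/else-if chain. Below is a hedged sketch of layouts that do and do not satisfy this, written with Gluon layout types (the helper itself is C++-only and not exposed to Python; the first layout mirrors #blocked2 in the combine.mlir test in this commit, the second is a made-up counterexample, and the import path is assumed from the Gluon test suite):

    from triton.experimental.gluon import language as ttgl  # assumed import path

    # Last dim of size 2 held as 2 elements per thread -> a single register basis
    # addresses it, so joining two slices of this layout infers it directly.
    ok = ttgl.BlockedLayout([1, 32, 2], [32, 1, 1], [4, 1, 1], [2, 0, 1])

    # Last dim of size 2 spread across two threads -> no register basis maps to
    # it, so the slice fast path does not apply.
    not_ok = ttgl.BlockedLayout([1, 32, 1], [16, 1, 2], [4, 1, 1], [2, 0, 1])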

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 32 additions & 15 deletions

@@ -98,6 +98,8 @@ class LayoutPropagation {
   // Return the mapped value in the given encoding. This will insert a convert
   // if the encoding is different than the encoding decided at resolve time.
   Value getValueAs(Value value, Attribute encoding);
+  // Return the original value mapped to the new desired encoding.
+  Value getRewrittenValue(Value value);
   // Dump the current stage of layout information.
   void dump();

@@ -440,22 +442,25 @@ void LayoutPropagation::map(Value old, Value newV) {
       newV;
 }

+Value LayoutPropagation::getRewrittenValue(Value value) {
+  auto tensorType = dyn_cast<RankedTensorType>(value.getType());
+  if (!tensorType)
+    return value;
+  auto layoutIt = layouts.find(value);
+  if (layoutIt == layouts.end()) {
+    return value;
+  }
+  assert(layoutIt->second.encodings.size() == 1 &&
+         "we should have resolved to a single encoding");
+  Attribute encodingPicked = *(layoutIt->second.encodings.begin());
+  if (encodingPicked == tensorType.getEncoding())
+    return value;
+  return rewriteMapping.at({value, encodingPicked});
+}
+
 Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
   if (auto tensorType = dyn_cast<RankedTensorType>(value.getType())) {
-    Value rewrittenValue;
-    auto layoutIt = layouts.find(value);
-    if (layoutIt == layouts.end()) {
-      rewrittenValue = value;
-    } else {
-      assert(layoutIt->second.encodings.size() == 1 &&
-             "we should have resolved to a single encoding");
-      Attribute encodingPicked = *(layoutIt->second.encodings.begin());
-      if (encodingPicked == tensorType.getEncoding())
-        rewrittenValue = value;
-      else
-        rewrittenValue = rewriteMapping[{value, encodingPicked}];
-    }
-    assert(rewrittenValue);
+    Value rewrittenValue = getRewrittenValue(value);
     if (cast<RankedTensorType>(rewrittenValue.getType()).getEncoding() ==
         encoding)
       return rewrittenValue;
@@ -478,7 +483,19 @@ Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter,

   Attribute operandEnc;
   if (op->getNumOperands() > 0) {
-    operandEnc = inferSrcEncoding(op, encoding);
+    for (auto operand : op->getOperands()) {
+      auto ty =
+          dyn_cast<RankedTensorType>(getRewrittenValue(operand).getType());
+      if (!ty)
+        continue;
+      auto enc = ty.getEncoding();
+      if (inferDstEncoding(op, enc) == encoding) {
+        operandEnc = enc;
+        break;
+      }
+    }
+    if (!operandEnc)
+      operandEnc = inferSrcEncoding(op, encoding);
     assert(operandEnc);
   }

python/src/gluon_ir.cc

Lines changed: 13 additions & 1 deletion

@@ -346,7 +346,19 @@ void init_gluon_ir(py::module &&m) {
            [](GluonOpBuilder &self, Type resultType, Value src) -> Value {
              return self.create<ttg::MemDescReinterpretOp>(resultType, src);
            })
-
+      .def("create_split",
+           [](GluonOpBuilder &self, Value &a) -> py::tuple {
+             auto argTy = cast<RankedTensorType>(a.getType());
+             auto ctx = argTy.getContext();
+             auto enc = ttg::SliceEncodingAttr::get(
+                 ctx, argTy.getRank() - 1,
+                 cast<ttg::DistributedEncodingTrait>(argTy.getEncoding()));
+             auto resTy =
+                 RankedTensorType::get(ArrayRef(argTy.getShape()).drop_back(),
+                                       argTy.getElementType(), enc);
+             auto op = self.create<triton::SplitOp>(TypeRange{resTy, resTy}, a);
+             return py::make_tuple(op->getResult(0), op->getResult(1));
+           })
       .def("create_tmem_alloc",
            [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
              return self.create<ttng::TMEMAllocOp>(resultTy, value);
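
At the frontend level, this create_split binding is what makes ttgl.split return slice-encoded values. A small sketch of the resulting types, assuming the same imports and layout values as the updated test_split_join below (the kernel name is illustrative):

    @gluon.jit
    def split_types_kernel():
        layout: ttgl.constexpr = ttgl.BlockedLayout([2, 2], [32, 1], [4, 1], [1, 0])
        res = ttgl.full([128, 2], 0, ttgl.int32, layout=layout)
        # SplitOp results are now typed with a slice along the last dimension of
        # the input layout rather than a separately inferred blocked layout.
        c, d = ttgl.split(res)
        ttgl.static_assert(c.type.layout == ttgl.SliceLayout(1, layout))
        ttgl.static_assert(d.type.layout == ttgl.SliceLayout(1, layout))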

python/test/gluon/test_frontend.py

Lines changed: 39 additions & 3 deletions

@@ -881,10 +881,10 @@ def test_split_join():
     expect_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 2], [32, 1], [4, 1], [1, 0])
     ttgl.static_assert(res.type.layout == expect_layout)

-    # CHECK: tt.split {{.*}} : tensor<128x2xi32, [[BLOCKED1]]> -> tensor<128xi32, [[BLOCKED]]>
+    # CHECK: tt.split {{.*}} : tensor<128x2xi32, [[BLOCKED1]]> -> tensor<128xi32, #ttg.slice<{dim = 1, parent = [[BLOCKED1]]}>>
     c, d = ttgl.split(res)
-    ttgl.static_assert(c.type.layout == layout)
-    ttgl.static_assert(d.type.layout == layout)
+    ttgl.static_assert(c.type.layout == ttgl.SliceLayout(1, expect_layout))
+    ttgl.static_assert(d.type.layout == ttgl.SliceLayout(1, expect_layout))


 @filecheck_test
@@ -1022,3 +1022,39 @@ def test_async_copy(fresh_knobs):
 } loc(#loc)
 } loc(#loc)
 """)
+
+
+def test_split_join_subtile(fresh_knobs):
+
+    @gluon.jit
+    def kernel():
+        layout: ttgl.constexpr = ttgl.BlockedLayout([1, 128], [32, 1], [4, 1], [0, 1])
+        x = ttgl.full([128, 128], 1, ttgl.int32, layout=layout)
+
+        a, b = x.reshape([128, 2, 64]).permute([0, 2, 1]).split()
+        y = ttgl.join(a, b).permute([0, 2, 1]).reshape([128, 128])
+        _ = x + y
+
+    knobs.compilation.disable_line_info = True
+    h = kernel.warmup(grid=(1, ), sanitize_overflow=False)
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["source"]), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 2, 64], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 2, 1]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 64, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @kernel() attributes {noinline = false} {
+    %c1_i32 = arith.constant 1 : i32 loc(#loc)
+    %cst = arith.constant dense<1> : tensor<128x128xi32, #blocked> loc(#loc)
+    %0 = tt.reshape %cst : tensor<128x128xi32, #blocked> -> tensor<128x2x64xi32, #blocked1> loc(#loc)
+    %1 = tt.trans %0 {order = array<i32: 0, 2, 1>} : tensor<128x2x64xi32, #blocked1> -> tensor<128x64x2xi32, #blocked2> loc(#loc)
+    %outLHS, %outRHS = tt.split %1 : tensor<128x64x2xi32, #blocked2> -> tensor<128x64xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> loc(#loc)
+    %2 = tt.join %outLHS, %outRHS : tensor<128x64xi32, #ttg.slice<{dim = 2, parent = #blocked2}>> -> tensor<128x64x2xi32, #blocked2> loc(#loc)
+    %3 = tt.trans %2 {order = array<i32: 0, 2, 1>} : tensor<128x64x2xi32, #blocked2> -> tensor<128x2x64xi32, #blocked1> loc(#loc)
+    %4 = tt.reshape %3 : tensor<128x2x64xi32, #blocked1> -> tensor<128x128xi32, #blocked> loc(#loc)
+    %5 = arith.addi %cst, %4 : tensor<128x128xi32, #blocked> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")

test/TritonGPU/combine.mlir

Lines changed: 25 additions & 7 deletions

@@ -3827,15 +3827,17 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {

 // -----

+// CHECK-DAG: [[BLOCKED_OUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 2]
+// CHECK-DAG: [[BLOCKED_JOIN:#.*]] = #ttg.blocked<{sizePerThread = [1, 2, 2]
+// CHECK-DAG: [[BLOCKED_IN:#.*]] = #ttg.blocked<{sizePerThread = [1, 2]
 #blocked = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [2, 16, 1], warpsPerCTA = [1, 1, 1], order = [2, 1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 16], warpsPerCTA = [1, 1], order = [1, 0]}>
 module attributes {"ttg.num-warps" = 1 : i32, ttg.target = "cuda:80"} {
-// CHECK-LABEL: join_forward
 tt.func @join_forward(%arg0: tensor<2x16xf32, #blocked2>) -> tensor<2x16x2xf32, #blocked> {
-  // CHECK: tt.join
-  // CHECK: ttg.convert_layout
-  // CHECK: tt.return
+  // CHECK: [[JOIN:%.*]] = tt.join %arg0, %arg0 : tensor<2x16xf32, [[BLOCKED_IN]]> -> tensor<2x16x2xf32, [[BLOCKED_JOIN]]>
+  // CHECK: [[RES:%.*]] = ttg.convert_layout [[JOIN]] : tensor<2x16x2xf32, [[BLOCKED_JOIN]]> -> tensor<2x16x2xf32, [[BLOCKED_OUT]]
+  // CHECK: tt.return [[RES]]
   %0 = ttg.convert_layout %arg0 : tensor<2x16xf32, #blocked2> -> tensor<2x16xf32, #blocked1>
   %1 = tt.join %0, %0 : tensor<2x16xf32, #blocked1> -> tensor<2x16x2xf32, #blocked>
   tt.return %1 : tensor<2x16x2xf32, #blocked>
@@ -3848,15 +3850,31 @@ module attributes {"ttg.num-warps" = 1 : i32, ttg.target = "cuda:80"} {
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 32, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 32, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
 module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} {
-// CHECK-LABEL: join_backward
-tt.func @join_backward(%arg0: tensor<128x32xf16, #blocked>, %arg1: tensor<128x32xf16, #blocked>) -> tensor<128x32x2xf16, #blocked1> {
-  // CHECK: %[[JOIN:.*]] = tt.join
+// CHECK-LABEL: join_backward_blocked
+tt.func @join_backward_blocked(%arg0: tensor<128x32xf16, #blocked>, %arg1: tensor<128x32xf16, #blocked>) -> tensor<128x32x2xf16, #blocked1> {
+  // CHECK: %[[JOIN:.*]] = tt.join %arg0, %arg1
   // CHECK: tt.return %[[JOIN]]
   %0 = tt.join %arg0, %arg1 : tensor<128x32xf16, #blocked> -> tensor<128x32x2xf16, #blocked2>
   %1 = ttg.convert_layout %0 : tensor<128x32x2xf16, #blocked2> -> tensor<128x32x2xf16, #blocked1>
   tt.return %1 : tensor<128x32x2xf16, #blocked1>
 }
 }
+
+// -----
+
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 32, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 32, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
+module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} {
+// CHECK-LABEL: join_backward_slice
+tt.func @join_backward_slice(%arg0: tensor<128x32xf16, #ttg.slice<{dim=2, parent=#blocked1}>>, %arg1: tensor<128x32xf16, #ttg.slice<{dim=2, parent=#blocked1}>>) -> tensor<128x32x2xf16, #blocked1> {
+  // CHECK: %[[JOIN:.*]] = tt.join
+  // CHECK: tt.return %[[JOIN]]
+  %0 = tt.join %arg0, %arg1 : tensor<128x32xf16, #ttg.slice<{dim=2, parent=#blocked1}>> -> tensor<128x32x2xf16, #blocked2>
+  %1 = ttg.convert_layout %0 : tensor<128x32x2xf16, #blocked2> -> tensor<128x32x2xf16, #blocked1>
+  tt.return %1 : tensor<128x32x2xf16, #blocked1>
+}
+}
+
 // -----

 #linear = #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 0], [32, 0]], block = []}>
