
Commit c4c8bac

[BACKEND] Allow mixing of linear and legacy layouts in split/join (#6028)
Relax the verifier of split/join to allow linear and legacy layouts to be mixed: the operand and result encodings no longer need to be attribute-identical, only structurally equivalent once both are converted to linear layouts.
Parent: e24d693

6 files changed (+32 −25 lines)

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 2 deletions

```diff
@@ -505,7 +505,7 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,

 def TT_JoinOp : TT_Op<"join", [
     NoMemoryEffect, SameTypeOperands,
-    DeclareOpInterfaceMethods<InferTypeOpInterface>,
+    InferTypeOpWithLayoutEquivalence,
 ]> {
   let summary = "join two tensors along a new, minor dimension";
   let description = [{
@@ -523,7 +523,7 @@ def TT_JoinOp : TT_Op<"join", [

 def TT_SplitOp : TT_Op<"split", [
     NoMemoryEffect,
-    DeclareOpInterfaceMethods<InferTypeOpInterface>,
+    InferTypeOpWithLayoutEquivalence,
     TypesMatchWith<"outLHS and outRHS types match",
                    "outLHS", "outRHS", "$_self">,
 ]> {
```

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 5 additions & 19 deletions

```diff
@@ -1027,17 +1027,9 @@ LogicalResult ReturnOp::verify() {
 // -- JoinOp --
 LogicalResult
 JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
-                         ValueRange operands, DictionaryAttr attributes,
-                         OpaqueProperties properties, RegionRange regions,
+                         JoinOp::Adaptor adaptor,
                          SmallVectorImpl<Type> &inferredReturnTypes) {
-  // These should have been checked by tablegen-generated code.
-  assert(operands.size() == 2);
-  assert(operands[0].getType() == operands[1].getType());
-  assert(isa<RankedTensorType>(operands[0].getType()));
-  assert(isa<RankedTensorType>(operands[1].getType()));
-
-  Value lhs = operands[0];
-  auto srcTy = cast<RankedTensorType>(lhs.getType());
+  auto srcTy = cast<RankedTensorType>(adaptor.getLhs().getType());

   SmallVector<int64_t> retShape(srcTy.getShape());
   retShape.push_back(2);
@@ -1058,15 +1050,9 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,

 // -- SplitOp --
 LogicalResult SplitOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  // These should have been checked by tablegen-generated code.
-  assert(operands.size() == 1);
-  assert(isa<RankedTensorType>(operands[0].getType()));
-
-  Value src = operands[0];
-  auto srcTy = cast<RankedTensorType>(src.getType());
+    MLIRContext *context, std::optional<Location> location,
+    SplitOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto srcTy = cast<RankedTensorType>(adaptor.getSrc().getType());
   auto srcShape = srcTy.getShape();

   if (srcShape.empty() || srcShape.back() != 2) {
```
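
Both hooks now take the tablegen-generated Adaptor instead of the raw `ValueRange`/`DictionaryAttr`/`OpaqueProperties`/`RegionRange` bundle. The adaptor exposes typed, named operand accessors (`getLhs()`, `getSrc()`), and the operand-count and operand-type invariants are enforced by generated code before the hook runs, which is why the manual asserts could be dropped. As a reading aid, here is the JoinOp hook reassembled from the hunk above; this is a sketch, and everything past the visible lines (the result-encoding inference) is elided, not reproduced:

```cpp
// Sketch reassembled from the diff above -- not verbatim source.
LogicalResult
JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
                         JoinOp::Adaptor adaptor,
                         SmallVectorImpl<Type> &inferredReturnTypes) {
  // getLhs() replaces operands[0]; operand count and type equality are
  // already enforced by tablegen-generated code before this runs.
  auto srcTy = cast<RankedTensorType>(adaptor.getLhs().getType());

  // The result keeps the operand shape and gains a minor dimension of size 2.
  SmallVector<int64_t> retShape(srcTy.getShape());
  retShape.push_back(2);

  // ... result-encoding inference and the push into inferredReturnTypes are
  // outside this hunk and unchanged by this commit ...
  return success();
}
```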

lib/Dialect/Triton/IR/Traits.cpp

Lines changed: 2 additions & 1 deletion

```diff
@@ -21,7 +21,8 @@ LogicalResult OpTrait::impl::verifyEquivalentType(Type typeA, Type typeB) {
   auto shapeB = tensorTypeB.getShape();
   if (shapeA != shapeB)
     return failure();
-
+  if (tensorTypeA.getElementType() != tensorTypeB.getElementType())
+    return failure();
   // If there's no encoding or the encodings are the same
   if (encodingA == encodingB)
     return success();
```
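
The new element-type check closes a gap: two ranked tensor types with identical shapes but different element types (say `f32` vs `f16`) previously fell straight through to the encoding comparison and could be reported as equivalent. Since split/join now accept any layout that is merely equivalent to the inferred one, the helper has to reject element-type mismatches explicitly. A condensed sketch of the check order after this commit, using the names from the hunk (the casts producing `tensorTypeA`/`tensorTypeB` sit above the visible context):

```cpp
// Condensed sketch of verifyEquivalentType's check order -- not verbatim.
if (shapeA != shapeB)
  return failure();   // shapes must match
if (tensorTypeA.getElementType() != tensorTypeB.getElementType())
  return failure();   // element types must match (added by this commit)
if (encodingA == encodingB)
  return success();   // identical (or both absent) encodings
// ... otherwise fall through to the structural encoding comparison ...
```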

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 0 deletions

```diff
@@ -2802,6 +2802,8 @@ struct TritonGPUInferLayoutInterface
     if (expected == got) {
       return success();
     }
+    if (!expected || !got)
+      return failure();
     // Check whether the encodings are structurally the same.
     auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
     auto gotLL = triton::gpu::toLinearLayout(shape, got);
```
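
This two-line guard is what keeps the relaxed check well-defined: after the early-equality test, exactly one of `expected`/`got` may still be a null attribute (an encoding-less tensor compared against an encoded one), and `toLinearLayout` should not be handed a null encoding. Once both are known to be present, legacy encodings (blocked, slice, mma, ...) and `#ttg.linear` encodings are lowered to the same `LinearLayout` representation, so a legacy/linear mix can be compared structurally. A condensed sketch of the resulting flow; the final comparison line is my assumption about the unshown tail of the function, not verbatim source:

```cpp
if (expected == got)
  return success();  // attribute-identical (or both null) encodings
if (!expected || !got)
  return failure();  // only one side carries an encoding: not equivalent,
                     // and converting a null encoding would be invalid
// Both present: lower each encoding (legacy or linear) to LinearLayout
// and compare the resulting index mappings structurally.
auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
auto gotLL = triton::gpu::toLinearLayout(shape, got);
return expectedLL == gotLL ? success() : failure();
```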

test/Triton/invalid.mlir

Lines changed: 3 additions & 3 deletions

```diff
@@ -170,11 +170,11 @@ tt.func public @fn(%arg0: tensor<32xf32, #blocked>) {
 // -----

 // Bad order; should be [1,0]
-#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
+#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1,2], threadsPerWarp = [32,1], warpsPerCTA = [1,1], order = [0,1]}>
 module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
 tt.func public @fn(%arg0: tensor<32xf32, #blocked>) {
-  // expected-error @+2 {{order}}
+  // expected-error @+2 {{incompatible with return type(s) of operation}}
   // expected-error @+1 {{op failed to infer returned types}}
   %a = tt.join %arg0, %arg0 : tensor<32xf32, #blocked> -> tensor<32x2xf32, #blocked1>
   tt.return
@@ -215,7 +215,7 @@ tt.func public @fn(%arg0: tensor<2xf32>) {

 // -----

-#blocked = #ttg.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}>
+#blocked = #ttg.blocked<{sizePerThread = [1,2,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}>
 // Bad order, should be [1,0].
 #blocked1 = #ttg.blocked<{sizePerThread = [1,1], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [1,0]}>

```

test/TritonGPU/ops.mlir

Lines changed: 18 additions & 0 deletions

```diff
@@ -189,3 +189,21 @@ tt.func @function_no_scope() {
 }

 }
+
+// -----
+
+// CHECK-DAG: [[$BLOCKED:#.*]] = #ttg.blocked
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK-DAG: [[$LINEAR:#.*]] = #ttg.linear
+#linear = #ttg.linear<{register = [[0, 1], [16, 0], [32, 0], [64, 0]], lane = [[0, 0], [0, 0], [0, 0], [1, 0], [2, 0]], warp = [[4, 0], [8, 0]], block = []}>
+
+module attributes {"ttg.num-warps" = 4 : i32} {
+// CHECK-LABEL: @split_join_linear_mix
+tt.func @split_join_linear_mix(%arg: tensor<128x2xf32, #linear>) attributes {"ttg.num-warps" = 8 : i32} {
+  // CHECK-NEXT: tt.split %{{.*}} : tensor<128x2xf32, [[$LINEAR]]> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = [[$BLOCKED]]}>>
+  %lhs, %rhs = tt.split %arg : tensor<128x2xf32, #linear> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+  // CHECK-NEXT: tt.join %{{.*}}, %{{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = [[$BLOCKED]]}>> -> tensor<128x2xf32, [[$LINEAR]]>
+  %j = tt.join %lhs, %rhs : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x2xf32, #linear>
+  tt.return
+}
+}
```
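
How to read the `#ttg.linear` encoding in this test (my reading of the basis vectors, not something stated by the commit): each basis gives the tensor coordinate contributed by one bit of the corresponding hardware index, and contributions combine by XOR.

- register = [[0,1],[16,0],[32,0],[64,0]]: register bit 0 selects column 1; bits 1-3 select rows 16, 32, 64.
- lane = [[0,0],[0,0],[0,0],[1,0],[2,0]]: lane bits 0-2 contribute nothing (the value is replicated across those lanes); bits 3-4 select rows 1 and 2.
- warp = [[4,0],[8,0]]: warp bits select rows 4 and 8.

Together these span rows 0-127 and columns 0-1, i.e. exactly `tensor<128x2xf32>`, while the split results use a plain `#ttg.slice` of a blocked layout; the test round-trips split and join across that legacy/linear boundary, which the old verifier would have rejected.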
