Skip to content

Commit 70cc8ee

Browse files
Merge commit 'e4c7fe8f88f0b4392ab1c60751ae687ca0cec2f5'
2 parents b6485a7 + e4c7fe8 commit 70cc8ee

File tree

20 files changed

+194
-152
lines changed

20 files changed

+194
-152
lines changed

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -35,24 +35,25 @@ class DialectInferLayoutInterface
3535

3636
virtual LogicalResult
3737
inferTransOpEncoding(Attribute operandEncoding, ArrayRef<int64_t> shape,
38-
ArrayRef<int32_t> order,
39-
Attribute &resultEncoding) const = 0;
38+
ArrayRef<int32_t> order, Attribute &resultEncoding,
39+
std::optional<Location> loc) const = 0;
4040

4141
virtual LogicalResult
4242
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
43-
Attribute &resultEncoding) const = 0;
43+
Attribute &resultEncoding,
44+
std::optional<Location> loc) const = 0;
4445

4546
virtual LogicalResult
4647
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
4748
Attribute &resultEncoding,
48-
std::optional<Location> location) const = 0;
49+
std::optional<Location> loc) const = 0;
4950

5051
// Note: This function only verifies the operand encoding. It doesn't infer
5152
// the result encoding.
5253
virtual LogicalResult
5354
inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
5455
Attribute retEncoding,
55-
std::optional<Location> location) const = 0;
56+
std::optional<Location> loc) const = 0;
5657

5758
// Tries to compute the encoding for the result of a reshape operation that
5859
// makes the reshape a "nop", i.e. the same GPU threads contain the same

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -460,7 +460,8 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure,
460460
The compiler is still free to change it for better performance.
461461
}];
462462
let builders = [
463-
OpBuilder<(ins "ArrayRef<int64_t>":$shape, "TypedValue<RankedTensorType>":$src)>
463+
OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$src,
464+
CArg<"bool", "false">:$allowReorder)>
464465
];
465466

466467
let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout);
@@ -728,9 +729,6 @@ def TT_ReduceOp: TT_Op<"reduce",
728729
let arguments = (ins Variadic<TT_Tensor>:$srcs, I32Attr:$axis);
729730
let results = (outs Variadic<TT_Type>:$result);
730731
let regions = (region SizedRegion<1>:$combineOp);
731-
let builders = [
732-
OpBuilder<(ins "ValueRange":$srcs, "int":$axis)>,
733-
];
734732
let hasVerifier = 1;
735733
let hasRegionVerifier = 1;
736734
let extraClassDeclaration = [{

include/triton/Dialect/TritonGPU/IR/LayoutUtilities.h

Lines changed: 0 additions & 8 deletions
This file was deleted.

include/triton/Dialect/TritonGPU/IR/LayoutUtility.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33

44
namespace mlir::triton::gpu {
55

6-
llvm::FailureOr<CTALayoutAttr>
7-
permuteCTALayout(MLIRContext *ctx, CTALayoutAttr layout, ArrayRef<int> order);
6+
CTALayoutAttr permuteCTALayout(MLIRContext *ctx, CTALayoutAttr layout,
7+
ArrayRef<int> order);
88

9-
}
9+
} // namespace mlir::triton::gpu

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -381,12 +381,14 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
381381
];
382382

383383
let extraClassDeclaration = extraBaseClassDeclaration # [{
384+
unsigned getRank() const { return getCTAOrder().size(); }
384385
int32_t getAlignment() const;
385386
SmallVector<unsigned> getCTAsPerCGA() const;
386387
SmallVector<unsigned> getCTAOrder() const;
387388
SmallVector<unsigned> getCTASplitNum() const;
388389
}];
389390
let hasCustomAssemblyFormat = 1;
391+
let genVerifyDecl = 1;
390392
}
391393

392394
def NVMMASharedEncodingAttr :
@@ -450,6 +452,7 @@ def NVMMASharedEncodingAttr :
450452
];
451453

452454
let extraClassDeclaration = extraBaseClassDeclaration # [{
455+
unsigned getRank() const { return getCTAOrder().size(); }
453456
int32_t getAlignment() const;
454457
SmallVector<unsigned> getCTAsPerCGA() const;
455458
SmallVector<unsigned> getCTAOrder() const;
@@ -556,6 +559,7 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1):
556559
);
557560

558561
let extraClassDeclaration = extraBaseClassDeclaration # [{
562+
unsigned getRank() const { return getCTAOrder().size(); }
559563
int32_t getAlignment() const;
560564
SmallVector<unsigned> getCTAsPerCGA() const;
561565
SmallVector<unsigned> getCTAOrder() const;

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 38 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -231,9 +231,10 @@ LogicalResult TransOp::verify() {
231231
return success();
232232
}
233233

234-
LogicalResult TransOp::inferReturnTypes(
235-
MLIRContext *context, std::optional<Location> location,
236-
TransOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
234+
LogicalResult
235+
TransOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
236+
TransOp::Adaptor adaptor,
237+
SmallVectorImpl<Type> &inferredReturnTypes) {
237238

238239
// type is the same as the input
239240
auto argTy = cast<RankedTensorType>(adaptor.getSrc().getType());
@@ -247,9 +248,8 @@ LogicalResult TransOp::inferReturnTypes(
247248
if (argEncoding) {
248249
Dialect &dialect = argEncoding.getDialect();
249250
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
250-
if (inferLayoutInterface
251-
->inferTransOpEncoding(argEncoding, shape, order, retEncoding)
252-
.failed()) {
251+
if (failed(inferLayoutInterface->inferTransOpEncoding(
252+
argEncoding, shape, order, retEncoding, loc))) {
253253
return failure();
254254
}
255255
}
@@ -389,7 +389,8 @@ LogicalResult MakeRangeOp::verify() {
389389

390390
//-- ReduceOp --
391391
static LogicalResult
392-
inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
392+
inferReduceReturnShape(std::optional<Location> loc, RankedTensorType argTy,
393+
Type retEltTy, int axis,
393394
SmallVectorImpl<Type> &inferredReturnTypes) {
394395
auto retShape = argTy.getShape().vec();
395396
retShape.erase(retShape.begin() + axis);
@@ -404,10 +405,8 @@ inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
404405
if (argEncoding) {
405406
Dialect &dialect = argEncoding.getDialect();
406407
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
407-
if (inferLayoutInterface
408-
->inferReduceOpEncoding(argEncoding, axis, retEncoding)
409-
.failed()) {
410-
llvm::report_fatal_error("failed to infer layout for ReduceOp");
408+
if (failed(inferLayoutInterface->inferReduceOpEncoding(
409+
argEncoding, axis, retEncoding, loc))) {
411410
return failure();
412411
}
413412
}
@@ -418,29 +417,18 @@ inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
418417
return success();
419418
}
420419

421-
void ReduceOp::build(OpBuilder &builder, OperationState &state,
422-
ValueRange operands, int axis) {
423-
SmallVector<Type> inferredReturnTypes;
424-
for (unsigned i = 0; i < operands.size(); ++i) {
425-
auto argTy = cast<RankedTensorType>(operands[i].getType());
426-
auto retEltTy = argTy.getElementType();
427-
(void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes);
428-
}
429-
430-
ReduceOp::build(builder, state, inferredReturnTypes, operands, axis);
431-
}
432-
433-
LogicalResult ReduceOp::inferReturnTypes(
434-
MLIRContext *context, std::optional<Location> location, ValueRange operands,
435-
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
436-
SmallVectorImpl<Type> &inferredReturnTypes) {
420+
LogicalResult
421+
ReduceOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
422+
ValueRange operands, DictionaryAttr attributes,
423+
OpaqueProperties properties, RegionRange regions,
424+
SmallVectorImpl<Type> &inferredReturnTypes) {
437425
Properties *prop = properties.as<Properties *>();
438426
int axis = prop->axis.getInt();
439427
for (auto arg : operands) {
440428
auto argTy = cast<RankedTensorType>(arg.getType());
441429
auto retEltTy = argTy.getElementType();
442-
if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes)
443-
.failed()) {
430+
if (failed(inferReduceReturnShape(loc, argTy, retEltTy, axis,
431+
inferredReturnTypes))) {
444432
return failure();
445433
}
446434
}
@@ -636,9 +624,8 @@ LogicalResult ExpandDimsOp::inferReturnTypes(
636624
if (argEncoding) {
637625
Dialect &dialect = argEncoding.getDialect();
638626
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
639-
if (inferLayoutInterface
640-
->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc)
641-
.failed())
627+
if (failed(inferLayoutInterface->inferExpandDimsOpEncoding(
628+
argEncoding, axis, retEncoding, loc)))
642629
return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp");
643630
}
644631
// create type
@@ -674,10 +661,10 @@ LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op,
674661
// Infer the encoding of the new expand op, if encodings are present.
675662
Attribute newExpandEnc;
676663
if (auto srcEnc = srcTy.getEncoding()) {
677-
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
678-
->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc,
679-
op.getLoc())
680-
.failed()) {
664+
Dialect &dialect = srcEnc.getDialect();
665+
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
666+
if (failed(inferLayoutInterface->inferExpandDimsOpEncoding(
667+
srcEnc, op.getAxis(), newExpandEnc, op.getLoc()))) {
681668
return emitOptionalError(op.getLoc(),
682669
"failed to infer layout for ExpandDimsOp");
683670
}
@@ -719,9 +706,8 @@ OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) {
719706
//-- ReshapeOp --
720707

721708
void ReshapeOp::build(OpBuilder &builder, OperationState &state,
722-
ArrayRef<int64_t> shape,
723-
TypedValue<RankedTensorType> src) {
724-
auto srcTy = src.getType();
709+
ArrayRef<int64_t> shape, Value src, bool allowReorder) {
710+
auto srcTy = cast<RankedTensorType>(src.getType());
725711
auto srcEnc = srcTy.getEncoding();
726712
Attribute dstEnc;
727713
if (srcEnc) {
@@ -731,7 +717,7 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &state,
731717
assert(succeeded(result));
732718
}
733719
auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc);
734-
build(builder, state, dstTy, src);
720+
build(builder, state, dstTy, src, allowReorder);
735721
}
736722

737723
LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
@@ -794,14 +780,14 @@ LogicalResult ReshapeOp::verify() {
794780
// Check that we can infer the dst encoding from the src encoding
795781
// and that the inferred dst encoding is the same as the given dst encoding
796782
Attribute inferredDstEnc;
797-
auto result =
798-
cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
799-
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(),
800-
inferredDstEnc, getLoc());
801-
assert(succeeded(result));
802-
return cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
803-
->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc,
804-
getLoc());
783+
auto layoutInterface =
784+
cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
785+
auto result = layoutInterface->inferReshapeOpEncoding(
786+
srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc, getLoc());
787+
if (failed(result))
788+
return failure();
789+
return layoutInterface->verifyLayoutsAreEqual(
790+
dstTy.getShape(), inferredDstEnc, dstEnc, getLoc());
805791
}
806792

807793
//-- FpToFpOp --
@@ -1092,11 +1078,10 @@ void JoinOp::build(OpBuilder &builder, OperationState &state, Value lhs,
10921078
Attribute srcEnc = lhsTy.getEncoding();
10931079
Attribute retEnc;
10941080
if (srcEnc) {
1095-
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
1096-
->inferDefaultJoinOpEncoding(srcEnc, retEnc, lhsTy.getShape(),
1097-
/*loc=*/std::nullopt)
1098-
.failed()) {
1099-
assert(false && "failed to infer join encoding");
1081+
if (failed(cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
1082+
->inferDefaultJoinOpEncoding(
1083+
srcEnc, retEnc, lhsTy.getShape(), state.location))) {
1084+
llvm_unreachable("failed to infer join encoding");
11001085
}
11011086
}
11021087
auto retTy = RankedTensorType::get(retShape, lhsTy.getElementType(), retEnc);

0 commit comments

Comments
 (0)