Skip to content

Commit ca7b655

Browse files
Merge OpenAI Triton commit 6af4919 (#4435)
This PR changes the Triton base from 9f88c7f to 6af4919 (Jun 3). Pass rate: 97.23%
2 parents 712dec1 + fb68276 commit ca7b655

File tree

72 files changed

+1784
-423
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+1784
-423
lines changed

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,24 +35,25 @@ class DialectInferLayoutInterface
3535

3636
virtual LogicalResult
3737
inferTransOpEncoding(Attribute operandEncoding, ArrayRef<int64_t> shape,
38-
ArrayRef<int32_t> order,
39-
Attribute &resultEncoding) const = 0;
38+
ArrayRef<int32_t> order, Attribute &resultEncoding,
39+
std::optional<Location> loc) const = 0;
4040

4141
virtual LogicalResult
4242
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
43-
Attribute &resultEncoding) const = 0;
43+
Attribute &resultEncoding,
44+
std::optional<Location> loc) const = 0;
4445

4546
virtual LogicalResult
4647
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
4748
Attribute &resultEncoding,
48-
std::optional<Location> location) const = 0;
49+
std::optional<Location> loc) const = 0;
4950

5051
// Note: This function only verifies the operand encoding. It doesn't infer
5152
// the result encoding.
5253
virtual LogicalResult
5354
inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
5455
Attribute retEncoding,
55-
std::optional<Location> location) const = 0;
56+
std::optional<Location> loc) const = 0;
5657

5758
// Tries to compute the encoding for the result of a reshape operation that
5859
// makes the reshape a "nop", i.e. the same GPU threads contain the same
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,45 @@
11
#ifndef TRITON_IR_INTERFACES_H_
22
#define TRITON_IR_INTERFACES_H_
33

4+
#include "mlir/IR/DialectImplementation.h"
45
#include "mlir/IR/OpDefinition.h"
6+
#include "mlir/Transforms/InliningUtils.h"
57

68
#define GET_TYPEDEF_CLASSES
79
#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc"
810

11+
namespace mlir::triton {
12+
13+
//===----------------------------------------------------------------------===//
14+
// TritonDialect Dialect Interfaces
15+
//===----------------------------------------------------------------------===//
16+
17+
struct TritonInlinerInterface : public DialectInlinerInterface {
18+
using DialectInlinerInterface::DialectInlinerInterface;
19+
20+
bool isLegalToInline(Operation *call, Operation *callable,
21+
bool wouldBeCloned) const final;
22+
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
23+
IRMapping &valueMapping) const final {
24+
return true;
25+
}
26+
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
27+
IRMapping &) const final {
28+
return true;
29+
}
30+
31+
//===--------------------------------------------------------------------===//
32+
// Transformation Hooks
33+
//===--------------------------------------------------------------------===//
34+
35+
/// Handle the given inlined terminator by replacing it with a new operation
36+
/// as necessary.
37+
void handleTerminator(Operation *op, Block *newDest) const final;
38+
/// Handle the given inlined terminator by replacing it with a new operation
39+
/// as necessary.
40+
void handleTerminator(Operation *op, ValueRange valuesToRepl) const final;
41+
};
42+
43+
} // namespace mlir::triton
44+
945
#endif // TRITON_IR_INTERFACES_H_

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,8 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure,
460460
The compiler is still free to change it for better performance.
461461
}];
462462
let builders = [
463-
OpBuilder<(ins "ArrayRef<int64_t>":$shape, "TypedValue<RankedTensorType>":$src)>
463+
OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$src,
464+
CArg<"bool", "false">:$allowReorder)>
464465
];
465466

466467
let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout);
@@ -728,9 +729,6 @@ def TT_ReduceOp: TT_Op<"reduce",
728729
let arguments = (ins Variadic<TT_Tensor>:$srcs, I32Attr:$axis);
729730
let results = (outs Variadic<TT_Type>:$result);
730731
let regions = (region SizedRegion<1>:$combineOp);
731-
let builders = [
732-
OpBuilder<(ins "ValueRange":$srcs, "int":$axis)>,
733-
];
734732
let hasVerifier = 1;
735733
let hasRegionVerifier = 1;
736734
let extraClassDeclaration = [{

include/triton/Dialect/TritonGPU/IR/LayoutUtilities.h

Lines changed: 0 additions & 8 deletions
This file was deleted.

include/triton/Dialect/TritonGPU/IR/LayoutUtility.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
namespace mlir::triton::gpu {
55

6-
llvm::FailureOr<CTALayoutAttr>
7-
permuteCTALayout(MLIRContext *ctx, CTALayoutAttr layout, ArrayRef<int> order);
6+
CTALayoutAttr permuteCTALayout(MLIRContext *ctx, CTALayoutAttr layout,
7+
ArrayRef<int> order);
88

9-
}
9+
} // namespace mlir::triton::gpu

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,14 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
381381
];
382382

383383
let extraClassDeclaration = extraBaseClassDeclaration # [{
384+
unsigned getRank() const { return getCTAOrder().size(); }
384385
int32_t getAlignment() const;
385386
SmallVector<unsigned> getCTAsPerCGA() const;
386387
SmallVector<unsigned> getCTAOrder() const;
387388
SmallVector<unsigned> getCTASplitNum() const;
388389
}];
389390
let hasCustomAssemblyFormat = 1;
391+
let genVerifyDecl = 1;
390392
}
391393

392394
def NVMMASharedEncodingAttr :
@@ -450,6 +452,7 @@ def NVMMASharedEncodingAttr :
450452
];
451453

452454
let extraClassDeclaration = extraBaseClassDeclaration # [{
455+
unsigned getRank() const { return getCTAOrder().size(); }
453456
int32_t getAlignment() const;
454457
SmallVector<unsigned> getCTAsPerCGA() const;
455458
SmallVector<unsigned> getCTAOrder() const;
@@ -556,6 +559,7 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1):
556559
);
557560

558561
let extraClassDeclaration = extraBaseClassDeclaration # [{
562+
unsigned getRank() const { return getCTAOrder().size(); }
559563
int32_t getAlignment() const;
560564
SmallVector<unsigned> getCTAsPerCGA() const;
561565
SmallVector<unsigned> getCTAOrder() const;

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,6 @@ def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
252252
representing a transposed view of the buffer.
253253
}];
254254

255-
let arguments = (ins TTG_MemDescType:$src, Variadic<I32>:$order);
256-
257255
let arguments = (
258256
ins TTG_MemDescType:$src,
259257
DenseI32ArrayAttr:$order
@@ -284,6 +282,26 @@ def TTG_MemDescReshapeOp : TTG_Op<"memdesc_reshape", [Pure,
284282
let hasVerifier = 1;
285283
}
286284

285+
def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewTrait]> {
286+
let summary = "reinterpret a memory descriptor as a different type and shape";
287+
288+
let description = [{
289+
The `ttg.memdesc_reinterpret` operation reinterprets a memory descriptor
290+
as one with a different shape and element type. Because memory descriptors
291+
lack strides, this operation is only valid if the original memory descriptor
292+
is contiguous.
293+
}];
294+
295+
let arguments = (ins TTG_MemDescType:$src);
296+
let results = (outs TTG_MemDescType:$result);
297+
298+
let assemblyFormat = [{
299+
$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))
300+
}];
301+
302+
let hasVerifier = 1;
303+
}
304+
287305
def TTG_LocalLoadOp : TTG_Op<"local_load"> {
288306
let summary = "Load a buffer from local memory into a distributed tensor";
289307

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,4 +360,18 @@ def TritonGPUCoalesceAsyncCopy: Pass<"tritongpu-coalesce-async-copy", "mlir::Mod
360360
"mlir::triton::TritonDialect"];
361361
}
362362

363+
def TritonGPUCanonicalize: Pass<"tritongpu-canonicalize"> {
364+
let summary = "reduced set of simplifications for TTGIR";
365+
366+
let description = [{
367+
The `tritongpu-canonicalize` pass applies a reduced set of simplification
368+
and canonicalization patterns to the module.
369+
}];
370+
let dependentDialects = [
371+
"mlir::arith::ArithDialect",
372+
"mlir::cf::ControlFlowDialect",
373+
"mlir::scf::SCFDialect",
374+
];
375+
}
376+
363377
#endif

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ scf::ForOp replaceForOpWithNewSignature(
141141
SmallVectorImpl<std::tuple<Value, Value>> &replacements);
142142
scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
143143
ValueRange newIterOperands);
144-
Block::BlockArgListType addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp &loop,
145-
ValueRange newIterOperands);
144+
[[nodiscard]] scf::ForOp addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp loop,
145+
ValueRange newIterOperands);
146146

147147
// Replace WhileOp with a new WhileOp with extra operands. The YieldOp is not
148148
// updated and needs to be updated separately for the loop to be correct.

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,27 @@ struct MemDescSubviewOpConversion
480480
return success();
481481
}
482482
};
483+
484+
struct MemDescReinterpretOpConversion
485+
: public ConvertOpToLLVMPattern<MemDescReinterpretOp> {
486+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
487+
488+
LogicalResult matchAndRewrite(MemDescReinterpretOp op, OpAdaptor adaptor,
489+
ConversionPatternRewriter &b) const override {
490+
Location loc = op.getLoc();
491+
MemDescType srcTy = op.getSrc().getType();
492+
MemDescType dstTy = op.getType();
493+
Type srcElemTy = getTypeConverter()->convertType(srcTy.getElementType());
494+
Type dstElemTy = getTypeConverter()->convertType(dstTy.getElementType());
495+
496+
auto smemObj =
497+
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), srcElemTy, b);
498+
SharedMemoryObject newObj(smemObj.getBase(), dstElemTy, dstTy.getRank(),
499+
loc, b);
500+
b.replaceOp(op, getStructFromSharedMemoryObject(loc, newObj, b));
501+
return success();
502+
}
503+
};
483504
} // namespace
484505

485506
void mlir::triton::populateViewOpToLLVMPatterns(
@@ -497,4 +518,5 @@ void mlir::triton::populateViewOpToLLVMPatterns(
497518
patterns.add<TransOpConversion>(typeConverter, benefit);
498519
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
499520
patterns.add<MemDescSubviewOpConversion>(typeConverter, benefit);
521+
patterns.add<MemDescReinterpretOpConversion>(typeConverter, benefit);
500522
}

0 commit comments

Comments
 (0)