
Commit ee9dc3a

Merge commit '6c72dbe84d384cebb476bf18755d374ff928ff06'
2 parents 6e6cff9 + 6c72dbe

File tree

37 files changed: 1026 additions & 318 deletions

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
@@ -88,6 +88,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUStreamPipelineV2();
   mlir::registerTritonAMDGPUCanonicalizePointers();
   mlir::registerTritonAMDGPUConvertToBufferOps();
+  mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
+  mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
 
   // TODO: register Triton & TritonGPU passes
   registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,

docs/python-api/triton-semantics.rst

Lines changed: 1 addition & 3 deletions
@@ -14,9 +14,7 @@ The algorithm is as follows:
 
 2. **Width** If both tensors are of dtypes of the same kind, and one of them is of a higher width, the other one is promoted to this dtype: ``(float32, float16) -> float32``
 
-3. **Supremum** If both tensors are of the same width and signedness but different dtypes, they are both promoted to the next larger dtype. ``(float16, bfloat16) -> float32``
-
-3.1 If both tensors are of different ``fp8`` dtypes, they are both cast to ``float16``.
+3. **Prefer float16** If both tensors are of the same width and signedness but different dtypes (``float16`` and ``bfloat16`` or different ``fp8`` types), they are both promoted to ``float16``. ``(float16, bfloat16) -> float16``
 
 4. **Prefer unsigned** Otherwise (same width, different signedness), they are promoted to the unsigned dtype: ``(int32, uint32) -> uint32``

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 22 additions & 4 deletions
@@ -27,15 +27,33 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
 constexpr int patternBenefitClampOptimizedPattern = 20;
 constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;
 
+struct BackendCallbacks {
+  /**
+   * A backend-specific callback for appending auxiliary data during
+   * `LocalStoreOp` conversion.
+   *
+   * @param[in] op The reference to the re-written `LocalStoreOp`.
+   * @param[in] count The number of issued LLVM instructions.
+   * @param[in] type The input type of issued LLVM instructions.
+   */
+  std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
+                     Type llvmOpType)>
+      localStoreOpConversion = nullptr;
+};
+
 void populateElementwiseOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
     PatternBenefit benefit);
 
-void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
-                                   const TargetInfoBase &targetInfo,
-                                   RewritePatternSet &patterns,
-                                   PatternBenefit benefit);
+// The given callback is invoked at the end of a successful rewrite. The
+// callback receives 1) the current source op, 2) the number of issued LLVM
+// instructions and 3) their input types. Each MLIR backend can provide a
+// callback and, thus, handle backend-specific behaviors.
+void populateMemoryOpToLLVMPattern(
+    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
+    RewritePatternSet &patterns, PatternBenefit benefit,
+    std::optional<BackendCallbacks> backendCallbacks = std::nullopt);
 
 void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                    RewritePatternSet &patterns,
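
Note: the sketch below is not part of this commit; it only illustrates how a backend might opt into the new hook. BackendCallbacks is assumed to live in the mlir::triton namespace (as the declarations above suggest), and typeConverter, targetInfo, patterns, benefit, and storeStats are placeholder names from the surrounding conversion setup.

    // Hypothetical backend-side wiring (illustration only): record how many
    // LLVM store instructions each LocalStoreOp lowers to, and their type.
    llvm::DenseMap<mlir::Operation *, std::pair<size_t, mlir::Type>> storeStats;

    mlir::triton::BackendCallbacks callbacks;
    callbacks.localStoreOpConversion = [&](mlir::triton::gpu::LocalStoreOp op,
                                           size_t llvmOpCount,
                                           mlir::Type llvmOpType) {
      storeStats[op.getOperation()] = {llvmOpCount, llvmOpType};
    };

    // The new argument defaults to std::nullopt, so existing callers keep
    // working; a backend opts in by passing its callbacks.
    mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo,
                                                patterns, benefit, callbacks);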

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 5 deletions
@@ -1366,11 +1366,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
                                            Location loc, RewriterBase &rewriter,
                                            const TargetInfoBase &target);
 
-void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
-                              Type elemLlvmTy, ArrayRef<Value> srcVals,
-                              Value smemBase, ArrayRef<Value> dstStrides,
-                              Location loc, RewriterBase &rewriter,
-                              const TargetInfoBase &target);
+void storeDistributedToShared(
+    MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
+    ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
+    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+    std::pair<size_t, Type> *const llvmOpCount = nullptr);
 
 inline Value getStructFromSharedMemoryObject(Location loc,
                                              const SharedMemoryObject &smemObj,

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ class CoarseSchedule {
   std::vector<std::pair<Operation *, unsigned>>
   createFinalSchedule(scf::ForOp forOp);
   void dump();
+  bool empty() { return opToStageAndCluster.size() == 0; }
 };
 
 } // namespace triton

lib/Analysis/Allocation.cpp

Lines changed: 25 additions & 61 deletions
@@ -4,9 +4,7 @@
 #include <limits>
 #include <numeric>
 
-#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Analysis/Liveness.h"
-#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Alias.h"
@@ -15,19 +13,6 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/ADT/SmallVector.h"
 
-using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
-using ::mlir::triton::gpu::BlockedEncodingAttr;
-using ::mlir::triton::gpu::DotOperandEncodingAttr;
-using ::mlir::triton::gpu::getContigPerThread;
-using ::mlir::triton::gpu::getOrder;
-using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::getShapePerCTATile;
-using ::mlir::triton::gpu::getSizePerThread;
-using ::mlir::triton::gpu::getUniqueContigPerThread;
-using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
-using ::mlir::triton::gpu::SharedEncodingAttr;
-using ::mlir::triton::gpu::SliceEncodingAttr;
-
 namespace mlir {
 
 //===----------------------------------------------------------------------===//
@@ -38,27 +23,6 @@ namespace triton {
 // Bitwidth of pointers
 constexpr int kPtrBitWidth = 64;
 
-static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
-getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
-  auto srcMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout);
-  auto srcDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(srcLayout);
-  auto dstMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(dstLayout);
-  auto dstDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(dstLayout);
-
-  assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere() &&
-           !srcMmaLayout.isHopper()) &&
-         "mma -> mma layout conversion is only supported on Ampere");
-
-  // mma or dot layout does not have an order, so the order depends on the
-  // layout of the other operand.
-  auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout)
-                                              : getOrder(srcLayout);
-  auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout)
-                                               : getOrder(dstLayout);
-
-  return {inOrd, outOrd};
-}
-
 static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
                                                RankedTensorType dstTy) {
   Attribute srcLayout = srcTy.getEncoding();
@@ -70,15 +34,17 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
 
   if (shouldUseDistSmem(srcLayout, dstLayout)) {
     // TODO: padding to avoid bank conflicts
-    return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
+    return convertType<unsigned, int64_t>(gpu::getShapePerCTA(srcTy));
   }
 
   assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()");
 
-  auto srcShapePerCTA = getShapePerCTA(srcTy);
-  auto dstShapePerCTA = getShapePerCTA(dstTy);
-  auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
-  auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
+  auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
+  auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
+  auto srcShapePerCTATile =
+      gpu::getShapePerCTATile(srcLayout, srcTy.getShape());
+  auto dstShapePerCTATile =
+      gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
 
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> repShape(rank);
@@ -124,9 +90,9 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
   scratchConfig.order = outOrd;
 
   unsigned srcContigPerThread =
-      getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
+      gpu::getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
   unsigned dstContigPerThread =
-      getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
+      gpu::getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
   // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
   // that we cannot do vectorization.
   unsigned innerDim = rank - 1;
@@ -135,12 +101,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
                                    : srcContigPerThread;
   scratchConfig.outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread;
 
-  if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout)) {
+  if (auto mma = mlir::dyn_cast<gpu::NvidiaMmaEncodingAttr>(srcLayout)) {
     if (mma.getVersionMajor() == 1) {
       // For conversions to MmaV1 (Nvidia V100), this inVec is hardcoded in the
      // codegen.
       scratchConfig.inVec = srcContigPerThread;
-    } else if (mlir::isa<BlockedEncodingAttr>(dstLayout)) {
+    } else if (mlir::isa<gpu::BlockedEncodingAttr>(dstLayout)) {
       // when storing from mma layout and loading in blocked layout vectorizing
       // the load back gives better performance even if there is a
       // transposition.
@@ -186,12 +152,12 @@ class AllocationAnalysis {
   /// Initializes explicitly defined shared memory values for a given operation.
   void getExplicitValueSize(Operation *op) {
     for (Value result : op->getResults()) {
-      auto alloc = result.getDefiningOp<triton::gpu::LocalAllocOp>();
+      auto alloc = result.getDefiningOp<gpu::LocalAllocOp>();
       if (alloc && alloc.isSharedMemoryAlloc()) {
         // Bytes could be a different value once we support padding or other
         // allocation policies.
         auto allocType = alloc.getType();
-        auto shapePerCTA = triton::gpu::getShapePerCTA(allocType);
+        auto shapePerCTA = gpu::getShapePerCTA(allocType);
         auto bytes = product<int64_t>(shapePerCTA) *
                      allocType.getElementTypeBitWidth() / 8;
 
@@ -218,31 +184,31 @@ class AllocationAnalysis {
   /// Initializes temporary shared memory for a given operation.
   void getScratchValueSize(Operation *op) {
     const size_t scratchAlignment = 128;
-    if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
+    if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
       ReduceOpHelper helper(reduceOp);
       unsigned bytes = helper.getScratchSizeInBytes();
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
-    } else if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
+    } else if (auto scanOp = dyn_cast<ScanOp>(op)) {
       ScanLoweringHelper helper(scanOp);
       unsigned bytes = helper.getScratchSizeInBytes();
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
-    } else if (auto histogram = dyn_cast<triton::HistogramOp>(op)) {
+    } else if (auto histogram = dyn_cast<HistogramOp>(op)) {
       auto dstTy = histogram.getType();
-      int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
+      int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
           op->getParentOfType<ModuleOp>());
       auto bytes = std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
                    std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
-    } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
+    } else if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
       auto srcTy = cvtLayout.getSrc().getType();
       auto dstTy = cvtLayout.getType();
       auto srcEncoding = srcTy.getEncoding();
       auto dstEncoding = dstTy.getEncoding();
-      if (mlir::isa<SharedEncodingAttr>(srcEncoding) ||
-          mlir::isa<SharedEncodingAttr>(dstEncoding)) {
+      if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
+          mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
         // Conversions from/to shared memory do not need scratch memory.
         return;
       }
@@ -253,12 +219,12 @@ class AllocationAnalysis {
       auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
       auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
       auto bytes =
-          isa<triton::PointerType>(srcTy.getElementType())
+          isa<PointerType>(srcTy.getElementType())
              ? elems * kPtrBitWidth / 8
              : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
-    } else if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(op)) {
+    } else if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
       auto value = op->getOperand(0);
       // only scalar requires scratch memory
       // make it explicit for readability
@@ -267,12 +233,10 @@ class AllocationAnalysis {
     } else {
       auto smemShape = getRepShapeForAtomic(op->getResult(0));
       auto elems = getNumScratchElements(smemShape);
-      auto elemTy =
-          cast<triton::PointerType>(value.getType()).getPointeeType();
+      auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
+      assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
       auto bytes =
-          isa<triton::PointerType>(elemTy)
-              ? elems * kPtrBitWidth / 8
-              : elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
+          elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
     }
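
A quick arithmetic illustration of the simplified atomic branch (values assumed, not taken from the commit): with pointer element types now ruled out by the assert, scratch size reduces to element count times the clamped element width.

    // Illustration only: scratch bytes for a scalar atomic RMW on f32.
    int elems = 1;                                 // scalar -> one element
    int bytes = elems * std::max<int>(8, 32) / 8;  // = 4 bytes of scratch
    // std::max(8, bitwidth) keeps at least one byte per element for sub-byte
    // widths; the old kPtrBitWidth branch is gone because the assert excludes
    // pointer element types here.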

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 25 additions & 11 deletions
@@ -15,12 +15,11 @@ using namespace mlir::triton::gpu;
 // blocked -> shared.
 // Swizzling in shared memory to avoid bank conflict. Normally used for
 // A/B operands of dots.
-void lowerDistributedToShared(Location loc, Value src, Value dst,
-                              Value adaptorSrc,
-                              const SharedMemoryObject &smemObj,
-                              const LLVMTypeConverter *typeConverter,
-                              ConversionPatternRewriter &rewriter,
-                              const TargetInfoBase &targetInfo) {
+void lowerDistributedToShared(
+    Location loc, Value src, Value dst, Value adaptorSrc,
+    const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
+    ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
+    std::pair<size_t, Type> *const llvmOpCount = nullptr) {
   auto srcTy = cast<RankedTensorType>(src.getType());
   auto dstTy = cast<MemDescType>(dst.getType());
   auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
@@ -33,7 +32,7 @@ void lowerDistributedToShared(Location loc, Value src, Value dst,
   auto dstStrides = smemObj.getStrides();
   auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
   storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
-                           loc, rewriter, targetInfo);
+                           loc, rewriter, targetInfo, llvmOpCount);
 }
 
 struct LocalAllocOpConversion
@@ -200,12 +199,15 @@ struct LocalStoreOpConversion
 public:
   using ConvertOpToLLVMPattern<
       triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;
+  using BackendCallbackType =
+      decltype(BackendCallbacks::localStoreOpConversion);
 
   LocalStoreOpConversion(const LLVMTypeConverter &converter,
                          const TargetInfoBase &targetInfo,
+                         BackendCallbackType backendCallback,
                          PatternBenefit benefit = 1)
       : ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
-        targetInfo(targetInfo) {}
+        targetInfo(targetInfo), backendCallback(backendCallback) {}
 
   LogicalResult
   matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
@@ -215,24 +217,36 @@ struct LocalStoreOpConversion
         getTypeConverter()->convertType(op.getDst().getType().getElementType());
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);
+
+    std::pair<size_t, Type> llvmOpCount;
     lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(),
                              adaptor.getSrc(), smemObj, getTypeConverter(),
-                             rewriter, targetInfo);
+                             rewriter, targetInfo, &llvmOpCount);
+
+    if (backendCallback)
+      (backendCallback)(op, llvmOpCount.first, llvmOpCount.second);
+
     rewriter.eraseOp(op);
     return success();
   }
 
 private:
   const TargetInfoBase &targetInfo;
+  BackendCallbackType backendCallback;
 };
 
 } // namespace
 
 void mlir::triton::populateMemoryOpToLLVMPattern(
     LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, PatternBenefit benefit,
+    std::optional<BackendCallbacks> backendCallbacks) {
   patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
   patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
   patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
-  patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);
+
+  auto backendCall =
+      backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr;
+  patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, backendCall,
+                                       benefit);
 }

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 7 additions & 1 deletion
@@ -404,7 +404,8 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
                               Type elemLlvmTy, ArrayRef<Value> srcVals,
                               Value smemBase, ArrayRef<Value> dstStrides,
                               Location loc, RewriterBase &rewriter,
-                              const TargetInfoBase &target) {
+                              const TargetInfoBase &target,
+                              std::pair<size_t, Type> *const llvmOpCount) {
   bool success = emitTransferBetweenRegistersAndShared(
       srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
       dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
@@ -418,7 +419,12 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
         store(vec, vecAddr)
             .setAlignment(vecTy.getNumElements() *
                           elemLlvmTy.getIntOrFloatBitWidth() / 8);
+        if (llvmOpCount) {
+          ++(llvmOpCount->first);
+          llvmOpCount->second = vecTy;
+        }
       });
+
   if (!success)
     llvm::report_fatal_error("Failed to emit transfer from register to shared");
 }
