
Commit c44a95b

Merge commit '9e900890fcda9c017cbc731768de4c9e1044f017'
2 parents: 3ffaa2f + 9e90089

37 files changed: 486 additions & 940 deletions


bin/RegisterTritonDialects.h

Lines changed: 0 additions & 2 deletions
@@ -88,8 +88,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUStreamPipeline();
   mlir::registerTritonAMDGPUStreamPipelineV2();
   mlir::registerTritonAMDGPUCanonicalizePointers();
-  mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
-  mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
 
   // TODO: register Triton & TritonGPU passes
   registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 4 additions & 22 deletions
@@ -27,33 +27,15 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
 constexpr int patternBenefitClampOptimizedPattern = 20;
 constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;
 
-struct BackendCallbacks {
-  /**
-   * A backend-specific callback for appending auxiliary data during
-   * `LocalStoreOp` conversion.
-   *
-   * @param[in] op The reference to the re-written `LocalStoreOp`.
-   * @param[in] count The number of issued LLVM instructions.
-   * @param[in] type The input type of issued LLVM instructions.
-   */
-  std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
-                     Type llvmOpType)>
-      localStoreOpConversion = nullptr;
-};
-
 void populateElementwiseOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
    PatternBenefit benefit);
 
-// The given callback is invoked at the end of a successful rewrite. The
-// callback receives 1) the current source op, 2) the number of issued LLVM
-// instructions and 3) their input types. Each MLIR backend can provide a
-// callback and, thus, handle backend-specific behaviors.
-void populateMemoryOpToLLVMPattern(
-    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
-    RewritePatternSet &patterns, PatternBenefit benefit,
-    std::optional<BackendCallbacks> backendCallbacks = std::nullopt);
+void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
+                                   const TargetInfoBase &targetInfo,
+                                   RewritePatternSet &patterns,
+                                   PatternBenefit benefit);
 
 void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                    RewritePatternSet &patterns,
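
Note: with BackendCallbacks gone, backends register the memory-op patterns without a callback argument. A minimal sketch of a call site against the new declaration; the wrapper name is hypothetical, and the converter, target info, and pattern set are assumed to exist in the caller:

#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"

// Hypothetical helper: the std::optional<BackendCallbacks> parameter no
// longer exists, so the call takes only converter, target info, patterns,
// and benefit.
static void addMemoryOpLowerings(mlir::LLVMTypeConverter &typeConverter,
                                 const mlir::triton::TargetInfoBase &targetInfo,
                                 mlir::RewritePatternSet &patterns) {
  mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo,
                                              patterns,
                                              mlir::PatternBenefit(1));
}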

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 5 deletions
@@ -1366,11 +1366,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
                                            Location loc, RewriterBase &rewriter,
                                            const TargetInfoBase &target);
 
-void storeDistributedToShared(
-    MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
-    ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    std::pair<size_t, Type> *const llvmOpCount = nullptr);
+void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
+                              Type elemLlvmTy, ArrayRef<Value> srcVals,
+                              Value smemBase, ArrayRef<Value> dstStrides,
+                              Location loc, RewriterBase &rewriter,
+                              const TargetInfoBase &target);
 
 inline Value getStructFromSharedMemoryObject(Location loc,
                                              const SharedMemoryObject &smemObj,

lib/Analysis/Allocation.cpp

Lines changed: 6 additions & 1 deletion
@@ -115,7 +115,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
 
   assert(!isMfmaToDotShortcut(srcTy, dstTy));
 
-  auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
+  // FIXME This is NOT entirely correct
+  // This should be getElemOrder, but we don't have such a method
+  // TODO Implement getElemOrder and make sure it's consistent with
+  // getContigPerThread
+  auto inOrd = gpu::getThreadOrder(srcLayout);
+  auto outOrd = gpu::getThreadOrder(dstLayout);
   scratchConfig.order = outOrd;
 
   unsigned srcContigPerThread =

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 32 additions & 0 deletions
@@ -404,6 +404,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
       return true;
     }
+    if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
+      if (auto nvidiaMma =
+              dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
+        if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
+          return false;
+        }
+        if (useLegacyMMAConversion) {
+          return false;
+        }
+        // FIXME [Dot LL]
+        // Enabling LL path for buggy kWidth path
+        bool largeKWidth =
+            dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
+        return largeKWidth && nvidiaMma.isAmpere();
+      }
+    }
     if (isa<BlockedEncodingAttr>(layout)) {
       return true;
     }
@@ -460,6 +476,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
     }
 
+    // FIXME [Dot LL]
+    // We know it's just for largeKWidth case in Ampere
+    // In this case, we need to pack the outputs into i32
+    if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
+      auto concat = [&](Value a, Value b) {
+        return or_(zext(i32_ty, bitcast(a, i16_ty)),
+                   shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
+      };
+
+      SmallVector<Value> outVals32(outVals.size() / 2);
+      for (int i = 0; i < outVals32.size(); ++i) {
+        outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
+      }
+      outVals = outVals32;
+    }
+
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);
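
Note: the concat lambda above packs adjacent pairs of 16-bit values into single i32 registers, with the even-indexed element in the low half-word. A self-contained sketch of the same bit arithmetic on plain integers (illustrative only, not the MLIR builder code):

#include <cstdint>
#include <vector>

// Mirror of the concat lambda: or_(zext(a), shl(zext(b), 16)).
// The first value lands in bits [0, 16), the second in bits [16, 32).
static uint32_t concat16(uint16_t lo, uint16_t hi) {
  return static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
}

// Pack vals[2*i] and vals[2*i + 1] into packed[i], halving the element count,
// just as the rewrite halves outVals before packLLElements.
static std::vector<uint32_t> packPairs(const std::vector<uint16_t> &vals) {
  std::vector<uint32_t> packed(vals.size() / 2);
  for (size_t i = 0; i < packed.size(); ++i)
    packed[i] = concat16(vals[2 * i], vals[2 * i + 1]);
  return packed;
}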

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 10 additions & 0 deletions
@@ -90,6 +90,16 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     auto dstDotOp =
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (srcBlocked && dstDotOp) {
+      // FIXME [Dot LL]
+      // We support this one via LLs, as the LocalLoad path is buggy
+      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent())) {
+        bool largeKWidth =
+            dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
+        if (mma.isAmpere() && largeKWidth) {
+          return;
+        }
+      }
+
       Attribute sharedMemorySpace =
           triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
       auto tmpType = MemDescType::get(
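
Note: this early return and the check added in ConvertLayoutOpToLLVM.cpp gate on the same condition, kWidth times the element bit width exceeding 64. A tiny worked version of that predicate (the helper name is ours, not from the diff):

// "Large kWidth": one thread's K-slice is wider than 64 bits.
static bool isLargeKWidth(unsigned kWidth, unsigned elemBitWidth) {
  return kWidth * elemBitWidth > 64;
}
// e.g. kWidth = 8 with f16 elements: 8 * 16 = 128 > 64 -> kept on the LL path
//      kWidth = 4 with f16 elements: 4 * 16 = 64       -> decomposed via shared memory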

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 11 additions & 25 deletions
@@ -15,11 +15,12 @@ using namespace mlir::triton::gpu;
 // blocked -> shared.
 // Swizzling in shared memory to avoid bank conflict. Normally used for
 // A/B operands of dots.
-void lowerDistributedToShared(
-    Location loc, Value src, Value dst, Value adaptorSrc,
-    const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
-    ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
-    std::pair<size_t, Type> *const llvmOpCount = nullptr) {
+void lowerDistributedToShared(Location loc, Value src, Value dst,
+                              Value adaptorSrc,
+                              const SharedMemoryObject &smemObj,
+                              const LLVMTypeConverter *typeConverter,
+                              ConversionPatternRewriter &rewriter,
+                              const TargetInfoBase &targetInfo) {
   auto srcTy = cast<RankedTensorType>(src.getType());
   auto dstTy = cast<MemDescType>(dst.getType());
   auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
@@ -32,7 +33,7 @@ void lowerDistributedToShared(
   auto dstStrides = smemObj.getStrides();
   auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
   storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
-                           loc, rewriter, targetInfo, llvmOpCount);
+                           loc, rewriter, targetInfo);
 }
 
 struct LocalAllocOpConversion
@@ -184,15 +185,12 @@ struct LocalStoreOpConversion
 public:
   using ConvertOpToLLVMPattern<
       triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;
-  using BackendCallbackType =
-      decltype(BackendCallbacks::localStoreOpConversion);
 
   LocalStoreOpConversion(const LLVMTypeConverter &converter,
                          const TargetInfoBase &targetInfo,
-                         BackendCallbackType backendCallback,
                          PatternBenefit benefit = 1)
       : ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
-        targetInfo(targetInfo), backendCallback(backendCallback) {}
+        targetInfo(targetInfo) {}
 
   LogicalResult
   matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
@@ -202,36 +200,24 @@ struct LocalStoreOpConversion
         getTypeConverter()->convertType(op.getDst().getType().getElementType());
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);
-
-    std::pair<size_t, Type> llvmOpCount;
     lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(),
                              adaptor.getSrc(), smemObj, getTypeConverter(),
-                             rewriter, targetInfo, &llvmOpCount);
-
-    if (backendCallback)
-      (backendCallback)(op, llvmOpCount.first, llvmOpCount.second);
-
+                             rewriter, targetInfo);
     rewriter.eraseOp(op);
     return success();
   }
 
 private:
   const TargetInfoBase &targetInfo;
-  BackendCallbackType backendCallback;
 };
 
 } // namespace
 
 void mlir::triton::populateMemoryOpToLLVMPattern(
     LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
-    RewritePatternSet &patterns, PatternBenefit benefit,
-    std::optional<BackendCallbacks> backendCallbacks) {
+    RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
   patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
   patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
-
-  auto backendCall =
-      backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr;
-  patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, backendCall,
-                                       benefit);
+  patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);
 }

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 1 addition & 7 deletions
@@ -404,8 +404,7 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
                               Type elemLlvmTy, ArrayRef<Value> srcVals,
                               Value smemBase, ArrayRef<Value> dstStrides,
                               Location loc, RewriterBase &rewriter,
-                              const TargetInfoBase &target,
-                              std::pair<size_t, Type> *const llvmOpCount) {
+                              const TargetInfoBase &target) {
   bool success = emitTransferBetweenRegistersAndShared(
       srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
       dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
@@ -419,12 +418,7 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
         store(vec, vecAddr)
             .setAlignment(vecTy.getNumElements() *
                           elemLlvmTy.getIntOrFloatBitWidth() / 8);
-        if (llvmOpCount) {
-          ++(llvmOpCount->first);
-          llvmOpCount->second = vecTy;
-        }
       });
-
   if (!success)
     llvm::report_fatal_error("Failed to emit transfer from register to shared");
 }
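
Note: the alignment set on the vectorized store is simply the vector width in bytes, lanes times element bit width divided by 8. A small worked example of that arithmetic (values are illustrative):

// Mirrors vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8.
static unsigned storeAlignmentBytes(unsigned numElements, unsigned elemBits) {
  return numElements * elemBits / 8;
}
// e.g. a 4 x f16 vector: 4 * 16 / 8 = 8-byte aligned store
//      an 8 x f32 vector: 8 * 32 / 8 = 32-byte aligned store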

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 34 additions & 30 deletions
@@ -11,6 +11,7 @@
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -237,8 +238,31 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
   return resOrder;
 }
 
+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
+                                            bool kMajor) {
+  // kMajor: if true, the matrix is fastest-running on k,
+  // otherwise it is on m (resp. n)
+  // opIdx=0: [batch, m, k] if rank == 3 else [m, k]
+  // opIdx=1: [batch, k, n] if rank == 3 else [k, n]
+  // batch (if rank == 3) is always the slowest running dimension
+  assert(rank == 2 || rank == 3);
+  assert(opIdx == 0 || opIdx == 1);
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  // If opIdx is 1 and kMajor is true, the order is [0, 1]
+  // (resp. [1, 2, 0] if rank == 3)
+  // Same if opIdx is 0 and kMajor is false
+  if (bool(opIdx) == kMajor) {
+    std::swap(order[0], order[1]);
+  }
+  return order;
+}
+
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
   auto order = getOrder(layout);
+  // FIXME: This mmaLayout if should just return
+  // getOrderForDotOperand(0, order.size(), kMajor=false)
+  // as mma has the same order as DotOperand(opIdx=0)
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {
       // Hopper MMA instructions force a warp order of [0, 1]. See docs:
@@ -247,30 +271,9 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
       order.erase(it);
       order.insert(order.begin(), 0);
     }
-  }
-  return order;
-}
-
-SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
-  SmallVector<unsigned> order(rank);
-  // The 'order' field typically represents a descending sorted array of
-  // dimensions based on contiguity. For instance, in axisInfo utilities that
-  // retrieve tensor contiguity, it's assumed that the dimension with the
-  // highest contiguity corresponds to order[0].
-  //
-  // The relation between contiguity and order is only relevant if the layout
-  // interfaces with HBM, as is the case when we load tensor from HBM to
-  // registers in the dot layout to bypass LDS. When bypassing LDS, we make the
-  // following assumptions about tensor layouts:
-  // - Tensor A (opIdx == 0) is considered to be row-major.
-  // - Tensor B (opIdx == 1) is considered to be column-major.
-  //
-  // Based on these assumptions, we define the following orders:
-  // - For opIdx == 0, we assume an order of [1, 0].
-  // - For opIdx == 1, we assume an order of [0, 1].
-  std::iota(order.rbegin(), order.rend(), 0);
-  if (opIdx == 1) {
-    std::swap(order[0], order[1]);
+  } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
+    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
+                                  /*kMajor*/ false);
   }
   return order;
 }
@@ -287,13 +290,12 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
-    SmallVector<unsigned> order(rank);
+    auto rank = dotLayout.getWarpsPerCTA().size();
     if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
-      return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
-    } else {
-      std::iota(order.rbegin(), order.rend(), 0);
+      return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
     }
+    SmallVector<unsigned> order(rank);
+    std::iota(order.rbegin(), order.rend(), 0);
     return order;
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -1059,7 +1061,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+                               /*kMajor*/ true);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
     ArrayRef<int64_t> tensorShape) const {
@@ -2042,6 +2045,7 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
                                                         int opIdx) const {
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();
+
   SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
   int numRepBatch =
       rank == 3
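
Note: the new getOrderForDotOperand is small enough to check by hand. A standalone sketch with the cases spelled out (plain C++ outside MLIR, using std::vector in place of SmallVector):

#include <cassert>
#include <numeric>
#include <vector>

// Same logic as getOrderForDotOperand: start from the row-major order
// [rank-1, ..., 1, 0] and swap the two fastest dims when bool(opIdx) == kMajor.
std::vector<unsigned> orderForDotOperand(unsigned opIdx, unsigned rank,
                                         bool kMajor) {
  assert(rank == 2 || rank == 3);
  assert(opIdx == 0 || opIdx == 1);
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0);
  if (bool(opIdx) == kMajor)
    std::swap(order[0], order[1]);
  return order;
}

// Spelled-out cases for rank == 2:
//   opIdx=0 ([m, k]), kMajor=true  -> {1, 0}  (k fastest)
//   opIdx=0,          kMajor=false -> {0, 1}  (m fastest)
//   opIdx=1 ([k, n]), kMajor=true  -> {0, 1}  (k fastest)
//   opIdx=1,          kMajor=false -> {1, 0}  (n fastest)
// For rank == 3 the batch dim stays slowest, e.g. opIdx=1, kMajor=true -> {1, 2, 0}.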

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 9 additions & 2 deletions
@@ -827,8 +827,15 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
     return dotOperandMfmaToLinearLayout(*this, shape);
-  }
-  if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
+  } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+    // FIXME [Dot LL]
+    // Do this unconditionally
+    auto largeKWidth = getKWidth() == 8;
+    if (mma.isAmpere() && largeKWidth) {
+      return ampereDotToLinearLayout(shape, *this);
+    }
+  } else if (auto dpasLayout =
+                 llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
     return dotOperandDpasToLinearLayout(*this, shape);
   }
 