Skip to content

Commit a0ce2d9

Browse files
committed
Merge commit '1793a04d9cf80adfac9399fd67eac1bd10077b8b'
2 parents 12fdf3f + 1793a04 commit a0ce2d9

File tree

41 files changed

+1436
-954
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+1436
-954
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,21 @@ class SharedMemoryObject {
352352
SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
353353
RewriterBase &rewriter) const;
354354

355+
// Returns a mask representing all the bits of the memdesc offsets that
356+
// may be modified by an affine offset coming from a memdesc_subview.
357+
// The offsets are considered to be in the type of the memdesc.
358+
// For padded layouts, we return the offsets without padding.
359+
static uint64_t getMaskSpanOffsets(triton::gpu::MemDescType srcTy);
360+
361+
// Returns whether the shared memory access had a memdesc_subview
362+
// that is rank-preserving (soon to be called memdesc_slice)
363+
static bool isAffineSharedMemoryAccess(triton::gpu::MemDescType srcTy) {
364+
return getMaskSpanOffsets(srcTy) != 0;
365+
}
366+
367+
Value getShmemOffset(Location loc, RewriterBase &rewriter,
368+
triton::gpu::MemDescType srcTy) const;
369+
355370
// TODO(Keren): deprecate the method once AMD backend has cleaned up
356371
Value getCSwizzleOffset(int dim) const {
357372
assert(dim >= 0 && dim < offsets.size());
@@ -462,7 +477,6 @@ std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc);
462477
// -----------------------------------------------------------------------
463478
using LLVM::SharedMemoryObject;
464479
using ::mlir::LLVM::delinearize;
465-
using ::mlir::LLVM::SharedMemoryObject;
466480
using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
467481
using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
468482
using ::mlir::triton::gpu::BlockedEncodingAttr;
@@ -474,24 +488,6 @@ using ::mlir::triton::gpu::SliceEncodingAttr;
474488
Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
475489
ArrayRef<Value> strides);
476490

477-
/// Extend 2d shared object to 3d.
478-
///
479-
/// If tensor has 3 dimensions, returns original shared object.
480-
/// If tensor shape is [M, N], return shared object describing shape [1, M, N]
481-
///
482-
/// This Function is used to simplify processing of 2d and 3d dot operands,
483-
/// particularly in the conversion of local_load operation.
484-
///
485-
/// \param rewriter
486-
/// \param loc
487-
/// \param smemObj
488-
/// \param shape shape of a tensor represented by smemObj
489-
/// \returns shared object describing 3d tensor
490-
SharedMemoryObject
491-
getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
492-
SharedMemoryObject smemObj,
493-
ArrayRef<int64_t> shape);
494-
495491
// "Applies" the given layout by computing layout(indices) and returning the
496492
// resulting Values.
497493
//
@@ -568,7 +564,8 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
568564
SmallVector<Value>
569565
lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
570566
ArrayRef<Value> valsArray, // Input for store, output for load
571-
Type llvmElemTy, Value smemBase,
567+
Type llvmElemTy, Value smemBase, Value affineOffset,
568+
uint64_t maskSpanAffineOffset,
572569
ConversionPatternRewriter &rewriter,
573570
const TargetInfoBase &targetInfo);
574571

@@ -578,20 +575,21 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
578575
SmallVector<Value> lowerLdSt(
579576
Location loc, MLIRContext *ctx, LinearLayout cvt,
580577
ArrayRef<Value> valsArray, // Input for store, output for load
581-
Type llvmElemTy, Value smemBase, ConversionPatternRewriter &rewriter,
578+
Type llvmElemTy, Value smemBase, Value affineOffset,
579+
uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
582580
const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
583581
std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
584582
ArrayRef<Value>, Value, int, VectorType)>
585583
lowerInst);
586584

587585
// Lower local_load/local_store via ld.shared/st.shared
588-
SmallVector<Value> lowerLocalLdSt(Location loc, MLIRContext *ctx,
589-
// Map from registers to offset
590-
LinearLayout cvt, ArrayRef<Value> valsArray,
591-
// Input for store, output for load
592-
Type llvmElemTy, Value smemBase,
593-
ConversionPatternRewriter &rewriter,
594-
const TargetInfoBase &targetInfo);
586+
SmallVector<Value>
587+
lowerLocalLdSt(Location loc, MLIRContext *ctx,
588+
LinearLayout cvt, // Map from registers to offset
589+
ArrayRef<Value> valsArray, // Input for store, empty for load
590+
Type llvmElemTy, triton::gpu::MemDescType srcTy,
591+
SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter,
592+
const TargetInfoBase &targetInfo);
595593

596594
SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
597595
RewriterBase &rewriter);
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
2+
#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_
3+
4+
#include "triton/Dialect/Triton/IR/Dialect.h"
5+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
6+
7+
namespace mlir::triton::nvidia_gpu {
8+
9+
LogicalResult verifyBarrierType(Operation *op,
10+
mlir::triton::gpu::MemDescType barrierType);
11+
12+
}
13+
14+
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_

include/triton/Tools/LayoutUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ std::optional<ColumnAction> regPermForDivide(const LinearLayout &A,
126126
ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout);
127127

128128
std::pair<int64_t, ColumnAction>
129-
actionAdditiveStrides(const LinearLayout &layout);
129+
actionAdditiveStrides(const LinearLayout &layout, uint64_t maskSpanOffsets);
130130

131131
// For a layout A with A.hasInDim(kReg), repeat the values so that they have
132132
// the same broadcasting as layout

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
202202

203203
assert(permutedInVals.size() == tileSize * nReps);
204204
SmallVector<Value> outVals;
205+
auto affineOffset = b.i32_val(0);
206+
auto maskSpanAffineOffset = 0;
205207
for (int i = 0; i < nReps; ++i) {
206208
if (i > 0)
207209
b.barrier();
@@ -210,11 +212,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
210212
ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
211213
// Store
212214
lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
213-
rewriter, targetInfo);
215+
affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
214216
b.barrier();
215217
// Load
216218
SmallVector<Value> tileOutVals = lowerLdStShared(
217-
loc, ctx, loadCvt, {}, llvmElemTy, smemBase, rewriter, targetInfo);
219+
loc, ctx, loadCvt, {}, llvmElemTy, smemBase, affineOffset,
220+
maskSpanAffineOffset, rewriter, targetInfo);
218221
llvm::append_range(outVals, tileOutVals);
219222
}
220223

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
5353
auto kWarp = str_attr("warp");
5454
auto kOffset = str_attr("offset");
5555
cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
56-
lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, smemObj.getBase(), rewriter,
57-
targetInfo);
56+
lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj,
57+
rewriter, targetInfo);
5858

5959
return success();
6060
}
@@ -177,10 +177,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
177177
auto regTy = cast<RankedTensorType>(regVal.getType());
178178
auto typeConverter = getTypeConverter();
179179

180-
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
181-
loc, adaptor.getSrc(),
182-
typeConverter->convertType(memDescTy.getElementType()), rewriter);
183-
auto llvmElemTy = typeConverter->convertType(regTy.getElementType());
180+
auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
181+
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
182+
llvmElemTy, rewriter);
184183

185184
// See [Legacy local_load/local_store]
186185
if (!targetInfo.isCuda()) {
@@ -206,8 +205,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
206205
auto kOffset = str_attr("offset");
207206
cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
208207

209-
auto outVals = lowerLocalLdSt(op.getLoc(), ctx, cvt, {}, llvmElemTy,
210-
smemObj.getBase(), rewriter, targetInfo);
208+
auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
209+
smemObj, rewriter, targetInfo);
211210

212211
Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
213212
rewriter.replaceOp(op, result);

0 commit comments

Comments (0)