
Commit 98b1409

lezcano and apgoucher authored
[BACKEND] Don't accumulate the offsets from memdesc_subview on the base (#7515)
Layouts coming from `memdesc_subview` are at heart affine layouts. This is because partial evaluation of a linear map on some variables gives you an affine map. More concretely, if `c` is a constant and `x` is a variable, we have that `A(x ^ c) = A(x) ^ A(c)`, where `A(c)` is a constant. In `memdesc_subview`, `A` is the map from the matrix into shared memory (the inverse of the shared memory linear layout) and `c` are the offsets given in the IR of `memdesc_subview`.

Previously `memdesc_subview` would advance the pointer as `ptr += A(c)`. This is incorrect, as the actual formula for the address is `ptr + (A(c) ^ A(x))`, so we compensated for this when computing the address by subtracting the offsets and adding them back as an xor:
https://github.com/triton-lang/triton/blob/8e52b2e483eb072149801443e2e33b0b72d32bc5/lib/Conversion/TritonGPUToLLVM/Utility.cpp#L604-L605

In this PR we untangle these issues by:
1. Just advancing the base pointer when we index over the pipelining dimensions (this will be split out into its own op in a future PR).
2. Exposing two methods on `SharedMemoryObject` to compute the affine part of the layout (what we called `A(c)` above) and to compute the bits that `A(c)` may have set. This second part is useful to perform optimisations.
3. Decreeing that a shared layout will always return the layout of the full allocShape (minus the pipelining dimension). This means that `toLinearLayout(MemDescType)` may return a layout with a different output shape than the tensor it represents. This should not be a problem because of the next point.
4. To account for the previous point, we generalise `invertAndCompose` to solve systems `AX = B` where `A` may have an output dimension larger than `B`. In this case we still compute the linear map composition `A^{-1}B`, which is well defined as `Im(B) \subset Im(A)`.

In the future we could consider always describing shared layouts as maps from the tensor into the offsets. This would allow us to represent the linear part of the affine map above via linear layouts, and we could simply return a layout of the correct shape from `toLinearLayout(MemDescType)`, rather than an oversized one. This also makes intuitive sense: we never want to store the same element in two different parts of shared memory, so this map is always well defined, while its inverse may not be, as is the case here.

We also take the chance to give a good clean-up to `emitTransferBetweenRegistersAndShared` now that the logic is better defined. The generated code should be comparable to or better than the previous one.

---------

Co-authored-by: apgoucher <[email protected]>
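As a concrete illustration of the identity above (a minimal sketch, not code from this PR; `applyLinear` and `bases` are hypothetical names), the snippet below models `A` as an xor-linear map over GF(2) given by the images of its input bits, checks that `A(x ^ c) == A(x) ^ A(c)`, and shows why folding `A(c)` into the base pointer with `+` differs from xor-ing it into the computed offset:

```cpp
// Hypothetical sketch: an xor-linear map A over GF(2), given by the images of
// the input basis bits. Not part of the Triton codebase.
#include <cassert>
#include <cstdint>
#include <vector>

// A(x) is the xor of bases[i] over every bit i that is set in x.
uint32_t applyLinear(const std::vector<uint32_t> &bases, uint32_t x) {
  uint32_t out = 0;
  for (size_t i = 0; i < bases.size(); ++i)
    if (x & (1u << i))
      out ^= bases[i];
  return out;
}

int main() {
  // An arbitrary 4-bit offset map, standing in for a swizzled shared layout.
  std::vector<uint32_t> bases = {0b0001, 0b0110, 0b0100, 0b1010};
  uint32_t c = 0b0101; // constant memdesc_subview offset
  for (uint32_t x = 0; x < 16; ++x) {
    // Partially evaluating A on the constant c yields an affine map:
    // A(x ^ c) = A(x) ^ A(c).
    assert(applyLinear(bases, x ^ c) ==
           (applyLinear(bases, x) ^ applyLinear(bases, c)));
  }
  // The correct address is therefore ptr + (A(x) ^ A(c)). Folding A(c) into
  // the base as ptr + A(c) + A(x) differs whenever A(x) and A(c) share a set
  // bit: here A(1) = 1 and A(c) = 5, so 1 + 5 = 6 while 1 ^ 5 = 4.
  return 0;
}
```

This discrepancy between `+` and `^` is exactly what the old subtract-then-xor compensation in `Utility.cpp` had to paper over.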
1 parent d183197 commit 98b1409

File tree

21 files changed: +314 -438 lines changed


include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 26 additions & 28 deletions
@@ -352,6 +352,21 @@ class SharedMemoryObject {
   SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
                                 RewriterBase &rewriter) const;
 
+  // Returns a mask representing all the bits of the memdesc offsets that
+  // may be modified by an affine offset coming from a memdesc_subview.
+  // The offsets are considered to be in the type of the memdesc.
+  // For padded layouts, we return the offsets without padding.
+  static uint64_t getMaskSpanOffsets(triton::gpu::MemDescType srcTy);
+
+  // Returns whether the shared memory access had a memdesc_subview
+  // that is rank-preserving (soon to be called memdesc_slice)
+  static bool isAffineSharedMemoryAccess(triton::gpu::MemDescType srcTy) {
+    return getMaskSpanOffsets(srcTy) != 0;
+  }
+
+  Value getShmemOffset(Location loc, RewriterBase &rewriter,
+                       triton::gpu::MemDescType srcTy) const;
+
   // TODO(Keren): deprecate the method once AMD backend has cleaned up
   Value getCSwizzleOffset(int dim) const {
     assert(dim >= 0 && dim < offsets.size());
@@ -462,7 +477,6 @@ std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc);
 // -----------------------------------------------------------------------
 using LLVM::SharedMemoryObject;
 using ::mlir::LLVM::delinearize;
-using ::mlir::LLVM::SharedMemoryObject;
 using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
 using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
@@ -474,24 +488,6 @@ using ::mlir::triton::gpu::SliceEncodingAttr;
 Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
           ArrayRef<Value> strides);
 
-/// Extend 2d shared object to 3d.
-///
-/// If tensor has 3 dimensions, returns original shared object.
-/// If tensor shape is [M, N], return shared object describing shape [1, M, N]
-///
-/// This Function is used to simplify processing of 2d and 3d dot operands,
-/// particularly in the conversion of local_load operation.
-///
-/// \param rewriter
-/// \param loc
-/// \param smemObj
-/// \param shape shape of a tensor represented by smemObj
-/// \returns shared object describing 3d tensor
-SharedMemoryObject
-getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
-                              SharedMemoryObject smemObj,
-                              ArrayRef<int64_t> shape);
-
 // "Applies" the given layout by computing layout(indices) and returning the
 // resulting Values.
 //
@@ -568,7 +564,8 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
 SmallVector<Value>
 lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
                 ArrayRef<Value> valsArray, // Input for store, output for load
-                Type llvmElemTy, Value smemBase,
+                Type llvmElemTy, Value smemBase, Value affineOffset,
+                uint64_t maskSpanAffineOffset,
                 ConversionPatternRewriter &rewriter,
                 const TargetInfoBase &targetInfo);
 
@@ -578,20 +575,21 @@ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
 SmallVector<Value> lowerLdSt(
     Location loc, MLIRContext *ctx, LinearLayout cvt,
     ArrayRef<Value> valsArray, // Input for store, output for load
-    Type llvmElemTy, Value smemBase, ConversionPatternRewriter &rewriter,
+    Type llvmElemTy, Value smemBase, Value affineOffset,
+    uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
     const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
     std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
                                      ArrayRef<Value>, Value, int, VectorType)>
         lowerInst);
 
 // Lower local_load/local_store via ld.shared/st.shared
-SmallVector<Value> lowerLocalLdSt(Location loc, MLIRContext *ctx,
-                                  // Map from registers to offset
-                                  LinearLayout cvt, ArrayRef<Value> valsArray,
-                                  // Input for store, output for load
-                                  Type llvmElemTy, Value smemBase,
-                                  ConversionPatternRewriter &rewriter,
-                                  const TargetInfoBase &targetInfo);
+SmallVector<Value>
+lowerLocalLdSt(Location loc, MLIRContext *ctx,
+               LinearLayout cvt,          // Map from registers to offset
+               ArrayRef<Value> valsArray, // Input for store, empty for load
+               Type llvmElemTy, triton::gpu::MemDescType srcTy,
+               SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter,
+               const TargetInfoBase &targetInfo);
 
 SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                     RewriterBase &rewriter);
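For intuition on `getMaskSpanOffsets` (an illustrative sketch under stated assumptions, not the implementation in this PR): for an xor-linear map, the image of any constant offset `c` can only have bits inside the OR of the images of the individual offset bits, so a mask of the bits that `A(c)` may touch can be obtained by OR-ing the basis images of the dimensions a `memdesc_subview` is allowed to index. The helper name `maskSpan` and the flat basis representation are hypothetical.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: given the shared-memory images of the offset bits that
// a memdesc_subview may set, return a mask of all offset bits that the affine
// part A(c) can possibly flip. Bits outside this mask are unaffected by the
// subview offsets, so addressing on them can be resolved at compile time.
uint64_t maskSpan(const std::vector<uint64_t> &subviewBasisImages) {
  uint64_t mask = 0;
  for (uint64_t image : subviewBasisImages)
    mask |= image; // A(c) xors a subset of these images, so its bits lie here.
  return mask;
}
```

A zero mask means no subview-induced affine offset is possible, which matches `isAffineSharedMemoryAccess` returning `getMaskSpanOffsets(srcTy) != 0` in the diff above.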

include/triton/Tools/LayoutUtils.h

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ std::optional<ColumnAction> regPermForDivide(const LinearLayout &A,
 ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout);
 
 std::pair<int64_t, ColumnAction>
-actionAdditiveStrides(const LinearLayout &layout);
+actionAdditiveStrides(const LinearLayout &layout, uint64_t maskSpanOffsets);
 
 // For a layout A with A.hasInDim(kReg), repeat the values so that they have
 // the same broadcasting as layout

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 5 additions & 2 deletions
@@ -202,6 +202,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 
     assert(permutedInVals.size() == tileSize * nReps);
     SmallVector<Value> outVals;
+    auto affineOffset = b.i32_val(0);
+    auto maskSpanAffineOffset = 0;
     for (int i = 0; i < nReps; ++i) {
       if (i > 0)
         b.barrier();
@@ -210,11 +212,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
           ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
       // Store
       lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
-                      rewriter, targetInfo);
+                      affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
       b.barrier();
       // Load
       SmallVector<Value> tileOutVals = lowerLdStShared(
-          loc, ctx, loadCvt, {}, llvmElemTy, smemBase, rewriter, targetInfo);
+          loc, ctx, loadCvt, {}, llvmElemTy, smemBase, affineOffset,
+          maskSpanAffineOffset, rewriter, targetInfo);
       llvm::append_range(outVals, tileOutVals);
     }
 
lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 7 additions & 8 deletions
@@ -53,8 +53,8 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
   auto kWarp = str_attr("warp");
   auto kOffset = str_attr("offset");
   cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
-  lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, smemObj.getBase(), rewriter,
-                 targetInfo);
+  lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj,
+                 rewriter, targetInfo);
 
   return success();
 }
@@ -177,10 +177,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto regTy = cast<RankedTensorType>(regVal.getType());
     auto typeConverter = getTypeConverter();
 
-    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
-        loc, adaptor.getSrc(),
-        typeConverter->convertType(memDescTy.getElementType()), rewriter);
-    auto llvmElemTy = typeConverter->convertType(regTy.getElementType());
+    auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
+    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
+                                                         llvmElemTy, rewriter);
 
     // See [Legacy local_load/local_store]
     if (!targetInfo.isCuda()) {
@@ -206,8 +205,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto kOffset = str_attr("offset");
     cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
 
-    auto outVals = lowerLocalLdSt(op.getLoc(), ctx, cvt, {}, llvmElemTy,
-                                  smemObj.getBase(), rewriter, targetInfo);
+    auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
+                                  smemObj, rewriter, targetInfo);
 
     Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
     rewriter.replaceOp(op, result);