Skip to content

Commit 6a6fb70

Browse files
AlexAUTravil-mobile
authored and committed
WA (workaround) for incorrect strides in subview
1 parent a981b01 commit 6a6fb70

File tree

2 files changed

+39
-14
lines changed

2 files changed

+39
-14
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,9 +376,13 @@ class SharedMemoryObject {
376376
return types;
377377
}
378378

379-
SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
380-
RewriterBase &rewriter) const {
379+
SmallVector<Value>
380+
getStrides(triton::gpu::MemDescType memDesc, Location loc,
381+
RewriterBase &rewriter,
382+
ArrayRef<int64_t> overwriteAllocSize = {}) const {
381383
auto allocShape = memDesc.getAllocShape();
384+
if (!overwriteAllocSize.empty())
385+
allocShape = overwriteAllocSize;
382386
auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA(
383387
memDesc.getEncoding(), allocShape);
384388
auto layoutOrder = triton::gpu::getOrder(memDesc);
@@ -699,14 +703,14 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
699703
const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
700704
const TargetInfoBase &target,
701705
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback,
702-
bool forceLane0 = false);
706+
bool forceLane0 = false, ArrayRef<int64_t> overwriteAllocSize = {});
703707

704708
[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
705709
LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
706710
std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
707711
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
708712
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback,
709-
bool forceLane0 = false);
713+
bool forceLane0 = false, ArrayRef<int64_t> overwriteAllocSize = {});
710714

711715
SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
712716
Type elemLlvmTy,

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,8 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
277277
const SharedMemoryObject &smemObj,
278278
triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
279279
Value regId, Value laneId, Value warpId, Value blockId,
280-
Location loc, RewriterBase &rewriter) {
280+
Location loc, RewriterBase &rewriter,
281+
ArrayRef<int64_t> overwriteAllocSize) {
281282
auto b = TritonLLVMOpBuilder(loc, rewriter);
282283
MLIRContext *ctx = rewriter.getContext();
283284
StringAttr kBlock = str_attr("block");
@@ -292,7 +293,8 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
292293

293294
auto smemBase = smemObj.getBase();
294295
auto smemOffsets = smemObj.getOffsets();
295-
auto smemStrides = smemObj.getStrides(sharedTy, loc, rewriter);
296+
auto smemStrides =
297+
smemObj.getStrides(sharedTy, loc, rewriter, overwriteAllocSize);
296298
Value smemOffset;
297299
// When loading or storing to shared memory, we consider two cases for
298300
// performance reasons:
@@ -410,7 +412,7 @@ bool emitTransferBetweenRegistersAndShared(
410412
std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
411413
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
412414
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback,
413-
bool forceLane0) {
415+
bool forceLane0, ArrayRef<int64_t> overwriteAllocSize) {
414416
MLIRContext *ctx = rewriter.getContext();
415417
auto b = TritonLLVMOpBuilder(loc, rewriter);
416418

@@ -485,9 +487,10 @@ bool emitTransferBetweenRegistersAndShared(
485487
SmallVector<Value> ret;
486488
for (int i = 0; i < numElems / vecElems; i++) {
487489
auto regId = b.i32_val(i * vecElems);
488-
auto vecAddr = getSmemVecAddr(
489-
regLayout, regToSharedLayout, invertAllocSharedLayout, smemObj,
490-
sharedTy, elemLlvmTy, regId, laneId, warpId, blockId, loc, rewriter);
490+
auto vecAddr =
491+
getSmemVecAddr(regLayout, regToSharedLayout, invertAllocSharedLayout,
492+
smemObj, sharedTy, elemLlvmTy, regId, laneId, warpId,
493+
blockId, loc, rewriter, overwriteAllocSize);
491494
perVectorCallback(vecTy, vecAddr);
492495
}
493496
return true;
@@ -499,12 +502,12 @@ bool emitTransferBetweenRegistersAndShared(
499502
const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
500503
const TargetInfoBase &target,
501504
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback,
502-
bool forceLane0) {
505+
bool forceLane0, ArrayRef<int64_t> overwriteAllocSize) {
503506
auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(),
504507
registerTy.getEncoding());
505508
return emitTransferBetweenRegistersAndShared(
506509
regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
507-
target, perVectorCallback, forceLane0);
510+
target, perVectorCallback, forceLane0, overwriteAllocSize);
508511
}
509512

510513
SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
@@ -515,11 +518,28 @@ SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
515518
auto srcTy = localLoadOp.getSrc().getType();
516519
auto dstTy = localLoadOp.getResult().getType();
517520

521+
// We overwrite the alloc size if we are a subview to fix subviews in the
522+
// fastest dim
523+
SmallVector<int64_t> overwriteSmemAllocSize;
524+
auto src = localLoadOp.getSrc();
525+
if (auto subView = src.getDefiningOp<triton::gpu::MemDescSubviewOp>()) {
526+
auto subViewSrcTy =
527+
dyn_cast<triton::gpu::MemDescType>(subView.getSrc().getType());
528+
if (subViewSrcTy) {
529+
auto origAllocSize = subViewSrcTy.getAllocShape();
530+
auto srcAllocSize = srcTy.getAllocShape();
531+
if (origAllocSize.size() == 3 && srcAllocSize.size() == 2) {
532+
overwriteSmemAllocSize = to_vector(origAllocSize.drop_front());
533+
}
534+
}
535+
}
536+
518537
auto b = TritonLLVMOpBuilder(loc, rewriter);
519538
SmallVector<Value> ret;
520539
bool success = emitTransferBetweenRegistersAndShared(
521540
dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc,
522-
rewriter, target, [&](VectorType vecTy, Value vecAddr) {
541+
rewriter, target,
542+
[&](VectorType vecTy, Value vecAddr) {
523543
auto vecVal = b.load(vecTy, vecAddr);
524544
target.localLoadOpAnnotation(localLoadOp, vecVal);
525545
vecVal.setAlignment(vecTy.getNumElements() *
@@ -528,7 +548,8 @@ SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
528548
for (int v = 0; v < vecTy.getNumElements(); v++) {
529549
ret.push_back(b.extract_element(elemLlvmTy, vecVal, b.i32_val(v)));
530550
}
531-
});
551+
},
552+
false, overwriteSmemAllocSize);
532553
if (!success)
533554
llvm::report_fatal_error("Failed to emit transfer from shared to register");
534555

0 commit comments

Comments
 (0)