@@ -277,7 +277,8 @@ Value getSmemVecAddr(const LinearLayout ®Layout,
277277 const SharedMemoryObject &smemObj,
278278 triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
279279 Value regId, Value laneId, Value warpId, Value blockId,
280- Location loc, RewriterBase &rewriter) {
280+ Location loc, RewriterBase &rewriter,
281+ ArrayRef<int64_t > overwriteAllocSize) {
281282 auto b = TritonLLVMOpBuilder (loc, rewriter);
282283 MLIRContext *ctx = rewriter.getContext ();
283284 StringAttr kBlock = str_attr (" block" );
@@ -292,7 +293,8 @@ Value getSmemVecAddr(const LinearLayout ®Layout,
292293
293294 auto smemBase = smemObj.getBase ();
294295 auto smemOffsets = smemObj.getOffsets ();
295- auto smemStrides = smemObj.getStrides (sharedTy, loc, rewriter);
296+ auto smemStrides =
297+ smemObj.getStrides (sharedTy, loc, rewriter, overwriteAllocSize);
296298 Value smemOffset;
297299 // When loading or storing to shared memory, we consider two cases for
298300 // performance reasons:
@@ -410,7 +412,7 @@ bool emitTransferBetweenRegistersAndShared(
410412 std::optional<int32_t > maxVecElems, const SharedMemoryObject &smemObj,
411413 Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
412414 std::function<void (VectorType, Value /* shmemAddr*/ )> perVectorCallback,
413- bool forceLane0) {
415+ bool forceLane0, ArrayRef<int64_t> overwriteAllocSize ) {
414416 MLIRContext *ctx = rewriter.getContext ();
415417 auto b = TritonLLVMOpBuilder (loc, rewriter);
416418
@@ -485,9 +487,10 @@ bool emitTransferBetweenRegistersAndShared(
485487 SmallVector<Value> ret;
486488 for (int i = 0 ; i < numElems / vecElems; i++) {
487489 auto regId = b.i32_val (i * vecElems);
488- auto vecAddr = getSmemVecAddr (
489- regLayout, regToSharedLayout, invertAllocSharedLayout, smemObj,
490- sharedTy, elemLlvmTy, regId, laneId, warpId, blockId, loc, rewriter);
490+ auto vecAddr =
491+ getSmemVecAddr (regLayout, regToSharedLayout, invertAllocSharedLayout,
492+ smemObj, sharedTy, elemLlvmTy, regId, laneId, warpId,
493+ blockId, loc, rewriter, overwriteAllocSize);
491494 perVectorCallback (vecTy, vecAddr);
492495 }
493496 return true ;
@@ -499,12 +502,12 @@ bool emitTransferBetweenRegistersAndShared(
499502 const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
500503 const TargetInfoBase &target,
501504 std::function<void (VectorType, Value /* shmemAddr*/ )> perVectorCallback,
502- bool forceLane0) {
505+ bool forceLane0, ArrayRef<int64_t> overwriteAllocSize ) {
503506 auto regLayout = triton::gpu::toLinearLayout (registerTy.getShape (),
504507 registerTy.getEncoding ());
505508 return emitTransferBetweenRegistersAndShared (
506509 regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
507- target, perVectorCallback, forceLane0);
510+ target, perVectorCallback, forceLane0, overwriteAllocSize );
508511}
509512
510513SmallVector<Value> loadSharedToDistributed (triton::gpu::LocalLoadOp localLoadOp,
@@ -515,11 +518,28 @@ SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
515518 auto srcTy = localLoadOp.getSrc ().getType ();
516519 auto dstTy = localLoadOp.getResult ().getType ();
517520
521+ // We overwrite the alloc size if we are a subview to fix subviews in the
522+ // fastest dim
523+ SmallVector<int64_t > overwriteSmemAllocSize;
524+ auto src = localLoadOp.getSrc ();
525+ if (auto subView = src.getDefiningOp <triton::gpu::MemDescSubviewOp>()) {
526+ auto subViewSrcTy =
527+ dyn_cast<triton::gpu::MemDescType>(subView.getSrc ().getType ());
528+ if (subViewSrcTy) {
529+ auto origAllocSize = subViewSrcTy.getAllocShape ();
530+ auto srcAllocSize = srcTy.getAllocShape ();
531+ if (origAllocSize.size () == 3 && srcAllocSize.size () == 2 ) {
532+ overwriteSmemAllocSize = to_vector (origAllocSize.drop_front ());
533+ }
534+ }
535+ }
536+
518537 auto b = TritonLLVMOpBuilder (loc, rewriter);
519538 SmallVector<Value> ret;
520539 bool success = emitTransferBetweenRegistersAndShared (
521540 dstTy, srcTy, elemLlvmTy, /* maxVecElems=*/ std::nullopt , smemObj, loc,
522- rewriter, target, [&](VectorType vecTy, Value vecAddr) {
541+ rewriter, target,
542+ [&](VectorType vecTy, Value vecAddr) {
523543 auto vecVal = b.load (vecTy, vecAddr);
524544 target.localLoadOpAnnotation (localLoadOp, vecVal);
525545 vecVal.setAlignment (vecTy.getNumElements () *
@@ -528,7 +548,8 @@ SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
528548 for (int v = 0 ; v < vecTy.getNumElements (); v++) {
529549 ret.push_back (b.extract_element (elemLlvmTy, vecVal, b.i32_val (v)));
530550 }
531- });
551+ },
552+ false , overwriteSmemAllocSize);
532553 if (!success)
533554 llvm::report_fatal_error (" Failed to emit transfer from shared to register" );
534555
0 commit comments