@@ -392,8 +392,10 @@ struct MemDescSubviewOpConversion
392392 Location loc = op->getLoc ();
393393 auto b = TritonLLVMOpBuilder (loc, rewriter);
394394 auto srcTy = op.getSrc ().getType ();
395+ auto destTy = op.getResult ().getType ();
395396 auto llvmElemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
396397 auto layoutOrder = getOrder (srcTy);
398+ auto enc = srcTy.getEncoding ();
397399
398400 // newBase = base + offset
399401 auto smemObj = getSharedMemoryObjectFromStruct (loc, adaptor.getSrc (),
@@ -408,13 +410,49 @@ struct MemDescSubviewOpConversion
408410 for (int i = rankReduced; i < opOffsetVals.size (); i++) {
409411 offsetVals.push_back (b.add (opOffsetVals[i], smemObj.getOffsets ()[i]));
410412 }
411- // Compute the offset based on the original strides of the shared memory
412- // object
413- auto offset = dot (rewriter, loc, opOffsetVals, opSmemStrides);
414- auto elemPtrTy = smemObj.getBase ().getType ();
415- smemObj = SharedMemoryObject (
416- b.gep (elemPtrTy, llvmElemTy, smemObj.getBase (), offset), llvmElemTy,
417- offsetVals);
413+ Value offset = b.undef (i32_ty);
414+ auto allocShape = srcTy.getAllocShape ();
415+ bool isSimpleSubview =
416+ allocShape.take_back (destRank) == destTy.getShape () ||
417+ !isa<NVMMASharedEncodingAttr>(enc);
418+ if (!isSimpleSubview) {
419+ auto nvmmaEnc = cast<NVMMASharedEncodingAttr>(enc);
420+ assert (destRank >= 2 &&
421+ " Shape size should be >= 2 when using NVMMAShared encoding" );
422+ auto swizzleStride = b.i32_val ((nvmmaEnc.getSwizzlingByteWidth () * 8 ) /
423+ llvmElemTy.getIntOrFloatBitWidth ());
424+ offset = b.i32_val (0 );
425+ for (auto i = 0 ; i < opOffsetVals.size () - 2 ; ++i) {
426+ offset = b.add (offset, b.mul (opOffsetVals[i], opSmemStrides[i]));
427+ }
428+ // newOffset = offset - (stridedOff * swizzledStride + contigOff /
429+ // swizzledStride * tileSize + contigOff % swizzledStride)
430+ // + stridedInc * swizzledStride + contigInc / swizzledStride *
431+ // tileSize + contigInc % swizzledStride
432+ auto stridedDim = destRank - 1 - layoutOrder[0 ];
433+ auto contigDim = destRank - 1 - layoutOrder[1 ];
434+ auto stridedOff = smemObj.getOffsets ()[stridedDim];
435+ auto contigOff = smemObj.getOffsets ()[contigDim];
436+ auto stridedInc = offsetVals[stridedDim];
437+ auto contigInc = offsetVals[contigDim];
438+ int allocStridedDim = allocShape.size () - 1 - layoutOrder[0 ];
439+ auto tileSize =
440+ b.mul (b.i32_val (allocShape[allocStridedDim]), swizzleStride);
441+ offset = b.sub (offset, b.mul (stridedOff, swizzleStride));
442+ offset = b.sub (offset, b.mul (b.udiv (contigOff, swizzleStride), tileSize));
443+ offset = b.sub (offset, b.urem (contigOff, swizzleStride));
444+ offset = b.add (offset, b.mul (stridedInc, swizzleStride));
445+ offset = b.add (offset, b.mul (b.udiv (contigInc, swizzleStride), tileSize));
446+ offset = b.add (offset, b.urem (contigInc, swizzleStride));
447+ } else {
448+ // Compute the offset based on the original strides of the shared memory
449+ // object
450+ offset = dot (rewriter, loc, opOffsetVals, opSmemStrides);
451+ }
452+ auto base = smemObj.getBase ();
453+ auto elemPtrTy = base.getType ();
454+ smemObj = SharedMemoryObject (b.gep (elemPtrTy, llvmElemTy, base, offset),
455+ llvmElemTy, offsetVals);
418456 auto retVal = getStructFromSharedMemoryObject (loc, smemObj, rewriter);
419457 rewriter.replaceOp (op, retVal);
420458 return success ();
0 commit comments