@@ -144,7 +144,8 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
144144 triton::gpu::SharedMemorySpaceAttr::get (forOp.getContext ());
145145 ttg::MemDescType subviewTy = ttg::MemDescType::get (
146146 allocTy.getShape ().drop_front (), allocTy.getElementType (),
147- allocTy.getEncoding (), sharedMemorySpace, /* mutableMemory=*/ true );
147+ allocTy.getEncoding (), sharedMemorySpace, /* mutableMemory=*/ true ,
148+ /* allocShape=*/ allocTy.getAllocShape ());
148149 auto view = builder.createWithStage <ttg::MemDescSubviewOp>(
149150 loc, stage, clusterId, subviewTy, alloc, copyOffsets);
150151 Operation *copy = builder.createWithStage <ttg::AsyncCopyGlobalToLocalOp>(
@@ -232,7 +233,8 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp,
232233 copyOffsets[0 ] = insertIdx;
233234 ttg::MemDescType subviewTy = ttg::MemDescType::get (
234235 allocTy.getShape ().drop_front (), allocTy.getElementType (),
235- allocTy.getEncoding (), sharedMemorySpace, /* mutableMemory=*/ true );
236+ allocTy.getEncoding (), sharedMemorySpace, /* mutableMemory=*/ true ,
237+ /* allocShape=*/ allocTy.getAllocShape ());
236238 auto view = builder.createWithStage <ttg::MemDescSubviewOp>(
237239 loc, stage, clusterId, subviewTy, alloc, copyOffsets);
238240
@@ -526,7 +528,7 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
526528 bufferShape.insert (bufferShape.begin (), distance);
527529 Type memdescType = ttg::MemDescType::get (bufferShape, ty.getElementType (),
528530 sharedEnc, sharedMemorySpace,
529- /* mutableMemory*/ true );
531+ /* mutableMemory= */ true );
530532 Value alloc =
531533 builder.create <ttg::LocalAllocOp>(loadOp->getLoc (), memdescType, Value ());
532534 return alloc;
@@ -544,12 +546,13 @@ static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) {
544546 /* CTASplitNum=*/ {1 }, /* CTAOrder=*/ {0 });
545547 auto barrierEncoding =
546548 ttg::SharedEncodingAttr::get (context, 1 , 1 , 1 , {0 }, barrierCTALayout);
547- Type barrierMemDescType = ttg::MemDescType::get (
549+ auto barrierMemDescType = ttg::MemDescType::get (
548550 {distance}, builder.getI64Type (), barrierEncoding, sharedMemorySpace,
549551 /* mutableMemory=*/ true );
550- Type singleBarrierMemDescType =
551- ttg::MemDescType::get ({1 }, builder.getI64Type (), barrierEncoding,
552- sharedMemorySpace, /* mutableMemory=*/ true );
552+ Type singleBarrierMemDescType = ttg::MemDescType::get (
553+ {1 }, builder.getI64Type (), barrierEncoding, sharedMemorySpace,
554+ /* mutableMemory=*/ true ,
555+ /* allocShape=*/ barrierMemDescType.getAllocShape ());
553556 Value barrierAlloc =
554557 builder.create <ttg::LocalAllocOp>(loc, barrierMemDescType, Value ());
555558 for (unsigned i = 0 ; i < distance; i++) {
@@ -650,11 +653,11 @@ static void createTMABarrierAndWait(
650653 OpBuilderWithStage builder (forOp);
651654 Attribute sharedMemorySpace =
652655 ttg::SharedMemorySpaceAttr::get (builder.getContext ());
656+ auto allocTy = cast<ttg::MemDescType>(barrierAlloc.getType ());
653657 ttg::MemDescType barrierTy = ttg::MemDescType::get (
654- {1 }, builder.getI64Type (),
655- cast<ttg::MemDescType>(barrierAlloc.getType ()).getEncoding (),
656- sharedMemorySpace,
657- /* mutableMemory=*/ true );
658+ {1 }, builder.getI64Type (), allocTy.getEncoding (), sharedMemorySpace,
659+ /* mutableMemory=*/ true ,
660+ /* allocShape=*/ allocTy.getAllocShape ());
658661 builder.setInsertionPoint (group[0 ]->loadOp );
659662 Value barrier = builder.createWithStage <ttg::MemDescSubviewOp>(
660663 loc, stage, cluster, barrierTy, barrierAlloc,
@@ -835,14 +838,14 @@ static void invalidateBarriers(OpBuilder &builder,
835838 Attribute sharedMemorySpace =
836839 ttg::SharedMemorySpaceAttr::get (builder.getContext ());
837840 for (Value barrier : barriers) {
838- int numBarriers = cast<ttg::MemDescType>(barrier.getType ()).getShape ()[0 ];
841+ auto allocTy = cast<ttg::MemDescType>(barrier.getType ());
842+ int numBarriers = allocTy.getShape ()[0 ];
839843 for (int i = 0 ; i < numBarriers; i++) {
840844 Value idx = builder.create <arith::ConstantIntOp>(barrier.getLoc (), i, 32 );
841845 ttg::MemDescType barrierTy = ttg::MemDescType::get (
842- {1 }, builder.getI64Type (),
843- cast<ttg::MemDescType>(barrier.getType ()).getEncoding (),
844- sharedMemorySpace,
845- /* mutableMemory=*/ true );
846+ {1 }, builder.getI64Type (), allocTy.getEncoding (), sharedMemorySpace,
847+ /* mutableMemory=*/ true ,
848+ /* allocShape=*/ allocTy.getShape ());
846849 Value barrierView = builder.create <ttg::MemDescSubviewOp>(
847850 barrier.getLoc (), barrierTy, barrier, idx);
848851 builder.create <ttng::InvalBarrierOp>(barrier.getLoc (), barrierView);
0 commit comments