@@ -498,12 +498,12 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
498498 if (originalChunkSize > 1 ) {
499499 targetMaskShape.pop_back ();
500500 convertedMaskTypes = getUnrolledTypes (maskTy, targetMaskShape);
501- SmallVector<Value> convertedMasks1D = pack (
502- op.getMask (), convertedMaskTypes, targetMaskShape, loc, rewriter);
503501 int64_t blockedChunkSize = targetShape->back ();
504502 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
505503
506- for (auto mask : convertedMasks1D)
504+ // the mask is reused across the chunk_size dimension
505+ for (auto mask : pack (op.getMask (), convertedMaskTypes, targetMaskShape,
506+ loc, rewriter))
507507 convertedMasks.append (numNewChunks, mask);
508508
509509 newValueTy = valueTy.cloneWith (*targetShape, elemTy);
@@ -570,7 +570,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
570570 if (!targetShape)
571571 return failure ();
572572
573- SmallVector<int64_t > targetIndiceShape (*targetShape);
573+ SmallVector<int64_t > targetMaskShape (*targetShape);
574574 int64_t originalChunkSize = tdescTy.getChunkSize ();
575575
576576 VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask ().getType ());
@@ -584,18 +584,19 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
584584 SmallVector<Value> convertedMasks;
585585
586586 if (originalChunkSize > 1 ) {
587+ targetMaskShape.pop_back ();
587588 int64_t blockedChunkSize = targetShape->back ();
588589 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
589- convertedMaskTypes = getUnrolledTypes (maskTy, (*targetShape)[0 ]);
590- SmallVector<Value> convertedMasks1D = pack (
591- op.getMask (), convertedMaskTypes, (*targetShape)[0 ], loc, rewriter);
590+ convertedMaskTypes = getUnrolledTypes (maskTy, targetMaskShape);
592591
593- for (auto mask : convertedMasks1D)
592+ // the mask is reused across the chunk_size dimension
593+ for (auto mask : pack (op.getMask (), convertedMaskTypes, targetMaskShape,
594+ loc, rewriter))
594595 convertedMasks.append (numNewChunks, mask);
595596 } else {
596- convertedMaskTypes = getUnrolledTypes (maskTy, *targetShape );
597- convertedMasks =
598- pack (op. getMask (), convertedMaskTypes, *targetShape, loc, rewriter);
597+ convertedMaskTypes = getUnrolledTypes (maskTy, targetMaskShape );
598+ convertedMasks = pack (op. getMask (), convertedMaskTypes, targetMaskShape,
599+ loc, rewriter);
599600 }
600601
601602 SmallVector<Type> convertedValTypes =
@@ -646,16 +647,14 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
646647 SmallVector<Value> newOps;
647648 int64_t originalChunkSize = tdescTy.getChunkSize ();
648649 if (originalChunkSize > 1 ) {
649- SmallVector<int64_t > shape1D (targetShape->begin (),
650- targetShape->end () - 1 );
651- convertedOffsetTypes = getUnrolledTypes (offsetVecTy, shape1D);
652- SmallVector<Value> convertedOffsetVec1D =
653- pack (offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
650+ auto targetOffsetShape = ArrayRef<int64_t >(*targetShape).drop_back ();
651+ convertedOffsetTypes = getUnrolledTypes (offsetVecTy, targetOffsetShape);
654652
655653 int64_t blockedChunkSize = targetShape->back ();
656654 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
657-
658- for (auto offset : convertedOffsetVec1D)
655+ // the offset is reused across the chunk_size dimension
656+ for (auto offset : pack (offsetVec, convertedOffsetTypes,
657+ targetOffsetShape, loc, rewriter))
659658 convertedOffsetVec.append (numNewChunks, offset);
660659
661660 } else {
0 commit comments