Skip to content

Commit b766535

Browse files
committed
clean-up
1 parent 2e72e49 commit b766535

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -498,12 +498,12 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
498498
if (originalChunkSize > 1) {
499499
targetMaskShape.pop_back();
500500
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
501-
SmallVector<Value> convertedMasks1D = pack(
502-
op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
503501
int64_t blockedChunkSize = targetShape->back();
504502
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
505503

506-
for (auto mask : convertedMasks1D)
504+
// Reuse each unrolled mask for every chunk along the chunk_size dimension.
505+
for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
506+
loc, rewriter))
507507
convertedMasks.append(numNewChunks, mask);
508508

509509
newValueTy = valueTy.cloneWith(*targetShape, elemTy);
@@ -570,7 +570,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
570570
if (!targetShape)
571571
return failure();
572572

573-
SmallVector<int64_t> targetIndiceShape(*targetShape);
573+
SmallVector<int64_t> targetMaskShape(*targetShape);
574574
int64_t originalChunkSize = tdescTy.getChunkSize();
575575

576576
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
@@ -584,18 +584,19 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
584584
SmallVector<Value> convertedMasks;
585585

586586
if (originalChunkSize > 1) {
587+
targetMaskShape.pop_back();
587588
int64_t blockedChunkSize = targetShape->back();
588589
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
589-
convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
590-
SmallVector<Value> convertedMasks1D = pack(
591-
op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
590+
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
592591

593-
for (auto mask : convertedMasks1D)
592+
// Reuse each unrolled mask for every chunk along the chunk_size dimension.
593+
for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
594+
loc, rewriter))
594595
convertedMasks.append(numNewChunks, mask);
595596
} else {
596-
convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
597-
convertedMasks =
598-
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
597+
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
598+
convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
599+
loc, rewriter);
599600
}
600601

601602
SmallVector<Type> convertedValTypes =
@@ -646,16 +647,14 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
646647
SmallVector<Value> newOps;
647648
int64_t originalChunkSize = tdescTy.getChunkSize();
648649
if (originalChunkSize > 1) {
649-
SmallVector<int64_t> shape1D(targetShape->begin(),
650-
targetShape->end() - 1);
651-
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
652-
SmallVector<Value> convertedOffsetVec1D =
653-
pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
650+
auto targetOffsetShape = ArrayRef<int64_t>(*targetShape).drop_back();
651+
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);
654652

655653
int64_t blockedChunkSize = targetShape->back();
656654
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
657-
658-
for (auto offset : convertedOffsetVec1D)
655+
// Reuse each unrolled offset for every chunk along the chunk_size dimension.
656+
for (auto offset : pack(offsetVec, convertedOffsetTypes,
657+
targetOffsetShape, loc, rewriter))
659658
convertedOffsetVec.append(numNewChunks, offset);
660659

661660
} else {

0 commit comments

Comments
 (0)