Skip to content

Commit 3c3023c

Browse files
committed
address review feedback
1 parent 41b839d commit 3c3023c

File tree

3 files changed

+73
-81
lines changed

3 files changed

+73
-81
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
426426

427427
SmallVector<Value> newOps;
428428

429-
// more indices is need when chunkSize > 1. Since a big load from one
429+
// More indices is need when chunkSize > 1. Since a big load from one
430430
// address could be break into multiple small loads.
431431
if (originalChunkSize > 1) {
432432
int64_t blockedChunkSize = targetShape->back();
@@ -504,15 +504,12 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
504504
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
505505

506506
for (auto mask : convertedMasks1D) {
507-
for (int64_t i = 0; i < numNewChunks; ++i) {
507+
for (int64_t i = 0; i < numNewChunks; ++i)
508508
convertedMasks.push_back(mask);
509-
}
510509
}
511510
// This is to handle the transpose effect when chunkSize > 1.
512-
if (targetShape && targetShape->size() > 1) {
513-
std::swap((*targetShape)[0], (*targetShape)[1]);
514-
newValueTy = valueTy.cloneWith(*targetShape, elemTy);
515-
}
511+
std::swap((*targetShape)[0], (*targetShape)[1]);
512+
newValueTy = valueTy.cloneWith(*targetShape, elemTy);
516513
} else {
517514
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
518515
convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
@@ -540,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
540537
Location loc = op.getLoc();
541538
xegpu::TensorDescType tdescTy = op.getTensorDescType();
542539

543-
// check if the tensor descriptor type is a 1d vector type
544-
if (tdescTy.getRank() > 2)
540+
if (!tdescTy.isScattered())
545541
return failure();
546542

547543
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

0 commit comments

Comments
 (0)