@@ -426,7 +426,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
426426
427427 SmallVector<Value> newOps;
428428
429- // more indices is need when chunkSize > 1. Since a big load from one
429+ // More indices are needed when chunkSize > 1. Since a big load from one
430430 // address could be break into multiple small loads.
431431 if (originalChunkSize > 1 ) {
432432 int64_t blockedChunkSize = targetShape->back ();
@@ -504,15 +504,12 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
504504 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
505505
506506 for (auto mask : convertedMasks1D) {
507- for (int64_t i = 0 ; i < numNewChunks; ++i) {
507+ for (int64_t i = 0 ; i < numNewChunks; ++i)
508508 convertedMasks.push_back (mask);
509- }
510509 }
511510 // This is to handle the transpose effect when chunkSize > 1.
512- if (targetShape && targetShape->size () > 1 ) {
513- std::swap ((*targetShape)[0 ], (*targetShape)[1 ]);
514- newValueTy = valueTy.cloneWith (*targetShape, elemTy);
515- }
511+ std::swap ((*targetShape)[0 ], (*targetShape)[1 ]);
512+ newValueTy = valueTy.cloneWith (*targetShape, elemTy);
516513 } else {
517514 convertedMaskTypes = getUnrolledTypes (maskTy, targetMaskShape);
518515 convertedMasks = pack (op.getMask (), convertedMaskTypes, targetMaskShape,
@@ -540,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
540537 Location loc = op.getLoc ();
541538 xegpu::TensorDescType tdescTy = op.getTensorDescType ();
542539
543- // check if the tensor descriptor type is a 1d vector type
544- if (tdescTy.getRank () > 2 )
540+ if (!tdescTy.isScattered ())
545541 return failure ();
546542
547543 std::optional<SmallVector<int64_t >> targetShape = getTargetShape (op);
0 commit comments