@@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
-    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-    VectorType indiceVecTy = indiceVec.getType();
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    // indiceVec is one dim lower than tdescTy when chunkSize is larger than 1.
+    if (originalChunkSize > 1)
+      targetIndiceShape.pop_back();
 
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
     SmallVector<Type> convertedIndiceTypes =
-        getUnrolledTypes(indiceVecTy, *targetShape);
+        getUnrolledTypes(indiceVecTy, targetIndiceShape);
     SmallVector<Value> convertedIndiceVec =
-        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
 
     SmallVector<Value> newOps;
-    for (auto indice : convertedIndiceVec) {
-      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
-                                                        op.getSource(), indice);
-      newOps.push_back(newOp);
+
+    // More indices are needed when chunkSize > 1, since a big load from one
+    // address may be broken into multiple small loads.
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto [indice, indiceType] :
+           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          // Compute the offset
+          Value inc = rewriter.create<arith::ConstantIndexOp>(
+              loc, i * blockedChunkSize);
+          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
+          Value offsetIndice =
+              rewriter.create<arith::AddIOp>(loc, indice, incVec);
+
+          auto newOp = rewriter.create<xegpu::CreateDescOp>(
+              loc, newTdescTy, op.getSource(), offsetIndice);
+
+          newOps.push_back(newOp);
+        }
+      }
+    } else {
+      for (auto indice : convertedIndiceVec) {
+        auto newOp = rewriter.create<xegpu::CreateDescOp>(
+            loc, newTdescTy, op.getSource(), indice);
+        newOps.push_back(newOp);
+      }
     }
 
     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
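Side note on the splitting math in this hunk: each unrolled index slice fans out into originalChunkSize / blockedChunkSize sub-descriptors whose base indices step by blockedChunkSize. Below is a minimal standalone sketch of that arithmetic, not part of the patch; the concrete sizes and base indices are made-up examples, and offsets are assumed to be element indices into the flat source buffer.

```cpp
// Plain-C++ sketch of the chunk-splitting arithmetic (illustrative only).
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  int64_t originalChunkSize = 8; // chunk size on the original descriptor
  int64_t blockedChunkSize = 2;  // targetShape->back()
  int64_t numNewChunks = originalChunkSize / blockedChunkSize; // 4

  // One unrolled slice of base indices, each starting an 8-element chunk.
  std::vector<int64_t> indiceSlice = {0, 64, 128, 192};

  // Mirrors the ConstantIndexOp + SplatOp + AddIOp sequence above: every
  // sub-chunk descriptor starts at base + i * blockedChunkSize.
  for (int64_t i = 0; i < numNewChunks; ++i) {
    for (int64_t base : indiceSlice)
      std::cout << base + i * blockedChunkSize << ' ';
    std::cout << '\n';
  }
  return 0;
}
```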
@@ -444,16 +472,18 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();
 
-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
     Type elemTy = tdescTy.getElementType();
     VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
 
@@ -462,10 +492,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (originalChunkSize > 1) {
+      targetMaskShape.pop_back();
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i)
+          convertedMasks.push_back(mask);
+      }
+      // A chunked load yields a value transposed w.r.t. the tdesc tile.
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
+                            loc, rewriter);
+    }
 
     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -476,7 +525,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     }
 
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
     rewriter.replaceOp(op, castOp);
     return success();
   }
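To make the mask replication and shape swap concrete, here is a hypothetical walk-through assuming a 16x8 tdesc with chunk_size = 8, a vector<16xi1> mask, and targetShape = {8, 2}; none of these numbers come from the patch.

```cpp
// Illustrative shape arithmetic for the chunked load-gather path.
#include <cstdint>
#include <iostream>
#include <utility>

int main() {
  int64_t targetShape[2] = {8, 2}; // 8 addresses x 2-element blocked chunks
  int64_t originalChunkSize = 8;
  int64_t maskLen = 16;            // mask is 1-D over the addresses

  int64_t blockedChunkSize = targetShape[1];                   // 2
  int64_t numNewChunks = originalChunkSize / blockedChunkSize; // 4
  int64_t numMaskSlices = maskLen / targetShape[0];            // 2 1-D slices

  // Each 1-D mask slice is reused for every sub-chunk it covers.
  std::cout << "masks emitted: " << numMaskSlices * numNewChunks << '\n'; // 8

  // Chunked loads yield values transposed w.r.t. the tdesc tile, hence the
  // std::swap before cloning the value type: 8x2 becomes 2x8.
  std::swap(targetShape[0], targetShape[1]);
  std::cout << "value tile: " << targetShape[0] << 'x' << targetShape[1]
            << '\n';
  return 0;
}
```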
@@ -489,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -519,30 +566,51 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();
 
-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
-    SmallVector<Type> convertedValTypes =
-        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
-
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+      // A chunked store takes a value transposed w.r.t. the tdesc tile.
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks =
+          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
 
     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
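A small sketch of the ordering constraint behind moving the value packing below the mask handling: the stored values have to be sliced with the post-swap (transposed) shape, not the original tdesc tile. Assumed numbers only, not taken from the patch.

```cpp
// Why the pack of op.getValue() must follow the shape swap (illustrative).
#include <cassert>
#include <cstdint>
#include <utility>

int main() {
  int64_t targetShape[2] = {8, 2}; // tdesc tile: 8 addresses x 2-elem chunks
  bool chunked = true;             // stands in for originalChunkSize > 1

  if (chunked)
    std::swap(targetShape[0], targetShape[1]);

  // Slicing the stored value with the pre-swap shape would cut it the wrong
  // way; after the swap each value slice is 2x8, matching the transposed
  // layout that each small store_scatter expects.
  assert(targetShape[0] == 2 && targetShape[1] == 8);
  return 0;
}
```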
@@ -565,8 +633,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
+      return failure();
+
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -580,12 +650,32 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
 
     TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
     VectorType offsetVecTy = offsetVec.getType();
-    SmallVector<Type> convertedOffsetTypes =
-        getUnrolledTypes(offsetVecTy, *targetShape);
-    SmallVector<Value> convertedOffsetVec =
-        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
-
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsetVec;
     SmallVector<Value> newOps;
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    if (originalChunkSize > 1) {
+      SmallVector<int64_t> shape1D(targetShape->begin(),
+                                   targetShape->end() - 1);
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
+      SmallVector<Value> convertedOffsetVec1D =
+          pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto offset : convertedOffsetVec1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          convertedOffsetVec.push_back(offset);
+        }
+      }
+
+    } else {
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
+      convertedOffsetVec =
+          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
+
     for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
       auto newOp =
           rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
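Finally, a standalone sketch of the offset replication in this hunk: every sub-descriptor split from one original chunk reuses the same 1-D offset slice, so the update applies to each copy verbatim. The offset values below are illustrative, not from the patch.

```cpp
// Offset replication for chunked update_offset (illustrative values).
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  int64_t numNewChunks = 4; // originalChunkSize / blockedChunkSize
  std::vector<std::vector<int64_t>> convertedOffsetVec1D = {{16, 16, 16, 16}};

  std::vector<std::vector<int64_t>> convertedOffsetVec;
  for (const auto &offset : convertedOffsetVec1D)
    for (int64_t i = 0; i < numNewChunks; ++i)
      convertedOffsetVec.push_back(offset); // same slice per sub-chunk

  std::cout << convertedOffsetVec.size() << '\n'; // 4, one per sub-descriptor
  return 0;
}
```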