@@ -421,43 +421,36 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
     SmallVector<Value> newOps;
 
     if (tdescTy.getRank() == 2) {
-      SmallVector<int64_t> oneDShape(targetShape->begin(), targetShape->end() - 1);
-      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, oneDShape);
-      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, oneDShape, loc, rewriter);
-      // Assume tdescTy, targetShape, and convertedIndiceVec are defined
-      int64_t outerDim = tdescTy.getShape().back();
-      int64_t innerDim = targetShape->back();
-      int64_t numInnerLoops = outerDim / innerDim;
+      SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
+      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, shape1D);
+      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, shape1D, loc, rewriter);
 
-      // Get element size in bytes
-      int64_t elemSize = tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
 
       for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
         for (int64_t i = 0; i < numInnerLoops; ++i) {
           // Compute the offset
-          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * innerDim);
+          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunk);
           Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
           Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
 
-          auto chunkSizeAttr = rewriter.getI64IntegerAttr(innerDim);
-          auto newOp = rewriter.create<xegpu::CreateDescOp>(
+          auto newOp = rewriter.create<xegpu::CreateDescOp>(
               loc, newTdescTy, op.getSource(), offsetIndice);
 
           newOps.push_back(newOp);
         }
       }
-    } else if (tdescTy.getRank() == 1) {
+    } else {
       convertedIndiceTypes = getUnrolledTypes(indiceVecTy, *targetShape);
       convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
       for (auto indice : convertedIndiceVec) {
         auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
                                                           op.getSource(), indice);
         newOps.push_back(newOp);
       }
-    } else {
-      // Unsupported rank for tensor descriptor
-      return failure();
-    }
+    }
 
     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
@@ -493,10 +486,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (tdescTy.getRank() == 2) {
+      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numInnerLoops; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+      if (targetShape && targetShape->size() > 1) {
+        std::swap((*targetShape)[0], (*targetShape)[1]);
+        newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+      }
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
 
     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -505,9 +517,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
           op.getL2HintAttr(), op.getL3HintAttr());
       newOps.push_back(newOp);
     }
-
+
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
     rewriter.replaceOp(op, castOp);
     return success();
   }
@@ -521,7 +532,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -551,29 +562,48 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
       return failure();
 
     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
-
-    SmallVector<Type> convertedValTypes =
-        getUnrolledTypes(valueTy, *targetShape);
+
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
-
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (tdescTy.getRank() == 2) {
+      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numInnerLoops; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
 
     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
@@ -597,7 +627,7 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -611,17 +641,36 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
 
     TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
     VectorType offsetVecTy = offsetVec.getType();
-    SmallVector<Type> convertedOffsetTypes =
-        getUnrolledTypes(offsetVecTy, *targetShape);
-    SmallVector<Value> convertedOffsetVec =
-        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
-
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsetVec;
     SmallVector<Value> newOps;
+
+    if (tdescTy.getRank() == 2) {
+      SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
+      SmallVector<Value> convertedOffsetVec1D = pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+      for (auto offset : convertedOffsetVec1D) {
+        for (int64_t i = 0; i < numInnerLoops; ++i) {
+          convertedOffsetVec.push_back(offset);
+        }
+      }
+
+    } else {
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
+      convertedOffsetVec = pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
+
     for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
       auto newOp =
           rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
       newOps.push_back(newOp);
     }
+
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
     return success();