@@ -267,7 +267,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     return success();
   }
 };
-
+/*
 struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
   using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
@@ -298,6 +298,49 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     return success();
   }
 };
+*/
+
+struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
+  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getValueType();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    LDBG("UnrollStoreNdOp: targetShape present? " << (targetShape.has_value() ? "yes" : "no"));
+    if (!targetShape)
+      return failure();
+
+    LDBG("targetShape: ");
+    for (auto v : *targetShape) LDBG(" " << v);
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    LDBG("convertedValTypes size: " << convertedValTypes.size());
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    LDBG("convertedTdescTypes size: " << convertedTdescTypes.size());
+
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    LDBG("convertedValues size: " << convertedValues.size());
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+    LDBG("convertedTdescs size: " << convertedTdescs.size());
+
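+    // Emit one StoreNdOp per unrolled tile: each sub-vector of the value is
+    // stored through its matching sub-descriptor.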
+    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) {
+      LDBG("Creating StoreNdOp with value: " << v << ", tdesc: " << t);
+      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
+                                        op.getL2HintAttr(), op.getL3HintAttr());
+    }
+
+    LDBG("Erasing original StoreNdOp: " << op);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
 
 struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
@@ -402,37 +445,40 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (!tdescTy.isScattered())
       return failure();
-
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
+
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    // indiceVec has one dimension fewer than tdescTy when chunkSize is larger than 1.
+    if (originalChunkSize > 1)
+      targetIndiceShape.pop_back();
 
     auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-    VectorType indiceVecTy = indiceVec.getType();
-
-    SmallVector<Type> convertedIndiceTypes;
-    SmallVector<Value> convertedIndiceVec;
+    SmallVector<Type> convertedIndiceTypes =
+        getUnrolledTypes(indiceVecTy, targetIndiceShape);
+    SmallVector<Value> convertedIndiceVec =
+        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
+
     SmallVector<Value> newOps;
 
-    if (tdescTy.getRank() == 2) {
-      SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
-      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, shape1D);
-      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, shape1D, loc, rewriter);
-
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
+    // More indices are needed when chunkSize > 1, since one large load from a
+    // single address may be broken into multiple smaller loads.
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
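+      // e.g. originalChunkSize = 8 with blockedChunkSize = 2 gives
+      // numNewChunks = 4; each original offset is advanced by 0, 2, 4 and 6
+      // below to address the four smaller chunks.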
 
       for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           // Compute the offset
-          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunk);
+          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunkSize);
           Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
           Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
 
@@ -443,8 +489,6 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
         }
       }
     } else {
-      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, *targetShape);
-      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
       for (auto indice : convertedIndiceVec) {
         auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
                                                           op.getSource(), indice);
@@ -468,15 +512,17 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (!tdescTy.isScattered())
       return failure();
 
-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
+
+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
     Type elemTy = tdescTy.getElementType();
     VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
@@ -489,25 +535,26 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
 
-    if (tdescTy.getRank() == 2) {
-      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
-      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
+    if (originalChunkSize > 1) {
+      targetMaskShape.pop_back();
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
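+      // The mask is 1-D (one lane per offset); the same sub-mask covers every
+      // new chunk split from an original chunk, so it is repeated numNewChunks
+      // times below.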
 
       for (auto mask : convertedMasks1D) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedMasks.push_back(mask);
         }
       }
+      // This is to handle the transpose effect when chunkSize > 1.
       if (targetShape && targetShape->size() > 1) {
         std::swap((*targetShape)[0], (*targetShape)[1]);
         newValueTy = valueTy.cloneWith(*targetShape, elemTy);
       }
     } else {
-      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
-      convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
     }
 
     SmallVector<Value> newOps;
@@ -561,38 +608,38 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (!tdescTy.isScattered())
       return failure();
 
-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
-
+
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
 
-    if (tdescTy.getRank() == 2) {
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
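+      // e.g. originalChunkSize = 8 and blockedChunkSize = 2 give
+      // numNewChunks = 4; each 1-D sub-mask below is reused for all four
+      // smaller chunks carved out of one original chunk.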
       convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
       SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
 
       for (auto mask : convertedMasks1D) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedMasks.push_back(mask);
         }
       }
-
+      // This is to handle the transpose effect when chunkSize > 1.
       std::swap((*targetShape)[0], (*targetShape)[1]);
 
     } else {
@@ -626,8 +673,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (tdescTy.getRank() > 2)
+      return failure();
+
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -644,18 +693,17 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     SmallVector<Type> convertedOffsetTypes;
     SmallVector<Value> convertedOffsetVec;
     SmallVector<Value> newOps;
-
-    if (tdescTy.getRank() == 2) {
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    if (originalChunkSize > 1) {
       SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
       convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
       SmallVector<Value> convertedOffsetVec1D = pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
 
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
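+      // Each original offset is reused numNewChunks times, once for every
+      // smaller chunk split out of the original chunk.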
 
       for (auto offset : convertedOffsetVec1D) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedOffsetVec.push_back(offset);
         }
       }