@@ -407,37 +407,40 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
407407
408408 if (!tdescTy.isScattered ())
409409 return failure ();
410-
410+
411411 std::optional<SmallVector<int64_t >> targetShape = getTargetShape (op);
412412 if (!targetShape)
413413 return failure ();
414-
414+
415415 SmallVector<int64_t > targetIndiceShape (*targetShape);
416416 int64_t originalChunkSize = tdescTy.getChunkSize ();
417417 // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
418418 if (originalChunkSize > 1 )
419419 targetIndiceShape.pop_back ();
420420
421421 auto newTdescTy = getUnrolledTypes (tdescTy, *targetShape)[0 ];
422- SmallVector<Type> convertedIndiceTypes =
422+ SmallVector<Type> convertedIndiceTypes =
423423 getUnrolledTypes (indiceVecTy, targetIndiceShape);
424- SmallVector<Value> convertedIndiceVec =
424+ SmallVector<Value> convertedIndiceVec =
425425 pack (indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
426-
426+
427427 SmallVector<Value> newOps;
428428
429429 // more indices is need when chunkSize > 1. Since a big load from one
430430 // address could be break into multiple small loads.
431431 if (originalChunkSize > 1 ) {
432432 int64_t blockedChunkSize = targetShape->back ();
433- int64_t numNewChunks = originalChunkSize/ blockedChunkSize;
433+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
434434
435- for (auto [indice, indiceType] : llvm::zip (convertedIndiceVec, convertedIndiceTypes)) {
435+ for (auto [indice, indiceType] :
436+ llvm::zip (convertedIndiceVec, convertedIndiceTypes)) {
436437 for (int64_t i = 0 ; i < numNewChunks; ++i) {
437438 // Compute the offset
438- Value inc = rewriter.create <arith::ConstantIndexOp>(loc, i * blockedChunkSize);
439+ Value inc = rewriter.create <arith::ConstantIndexOp>(
440+ loc, i * blockedChunkSize);
439441 Value incVec = rewriter.create <vector::SplatOp>(loc, indiceType, inc);
440- Value offsetIndice = rewriter.create <arith::AddIOp>(loc, indice, incVec);
442+ Value offsetIndice =
443+ rewriter.create <arith::AddIOp>(loc, indice, incVec);
441444
442445 auto newOp = rewriter.create <xegpu::CreateDescOp>(
443446 loc, newTdescTy, op.getSource (), offsetIndice);
@@ -447,11 +450,11 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
447450 }
448451 } else {
449452 for (auto indice : convertedIndiceVec) {
450- auto newOp = rewriter.create <xegpu::CreateDescOp>(loc, newTdescTy,
451- op.getSource (), indice);
453+ auto newOp = rewriter.create <xegpu::CreateDescOp>(
454+ loc, newTdescTy, op.getSource (), indice);
452455 newOps.push_back (newOp);
453456 }
454- }
457+ }
455458
456459 Value castOp = unpack (newOps, tdescTy, *targetShape, loc, rewriter);
457460 rewriter.replaceOp (op, castOp);
@@ -471,11 +474,11 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
471474
472475 if (!tdescTy.isScattered ())
473476 return failure ();
474-
477+
475478 std::optional<SmallVector<int64_t >> targetShape = getTargetShape (op);
476479 if (!targetShape)
477480 return failure ();
478-
481+
479482 SmallVector<int64_t > targetMaskShape (*targetShape);
480483 int64_t originalChunkSize = tdescTy.getChunkSize ();
481484
@@ -489,29 +492,31 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
489492 SmallVector<Value> convertedTdescs = pack (
490493 op.getTensorDesc (), convertedTdescTypes, *targetShape, loc, rewriter);
491494
492- SmallVector<Type> convertedMaskTypes;
493- SmallVector<Value> convertedMasks;
495+ SmallVector<Type> convertedMaskTypes;
496+ SmallVector<Value> convertedMasks;
494497
495498 if (originalChunkSize > 1 ) {
496499 targetMaskShape.pop_back ();
497500 convertedMaskTypes = getUnrolledTypes (maskTy, targetMaskShape);
498- SmallVector<Value> convertedMasks1D = pack (op.getMask (), convertedMaskTypes, targetMaskShape, loc, rewriter);
501+ SmallVector<Value> convertedMasks1D = pack (
502+ op.getMask (), convertedMaskTypes, targetMaskShape, loc, rewriter);
499503 int64_t blockedChunkSize = targetShape->back ();
500- int64_t numNewChunks = originalChunkSize/ blockedChunkSize;
504+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
501505
502506 for (auto mask : convertedMasks1D) {
503507 for (int64_t i = 0 ; i < numNewChunks; ++i) {
504508 convertedMasks.push_back (mask);
505509 }
506510 }
507- // This is to handle the transpose effect when chunkSize > 1.
511+ // This is to handle the transpose effect when chunkSize > 1.
508512 if (targetShape && targetShape->size () > 1 ) {
509513 std::swap ((*targetShape)[0 ], (*targetShape)[1 ]);
510514 newValueTy = valueTy.cloneWith (*targetShape, elemTy);
511515 }
512516 } else {
513517 convertedMaskTypes = getUnrolledTypes (maskTy, targetMaskShape);
514- convertedMasks = pack (op.getMask (), convertedMaskTypes, targetMaskShape, loc, rewriter);
518+ convertedMasks = pack (op.getMask (), convertedMaskTypes, targetMaskShape,
519+ loc, rewriter);
515520 }
516521
517522 SmallVector<Value> newOps;
@@ -521,7 +526,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
521526 op.getL2HintAttr (), op.getL3HintAttr ());
522527 newOps.push_back (newOp);
523528 }
524-
529+
525530 Value castOp = unpack (newOps, op.getType (), *targetShape, loc, rewriter);
526531 rewriter.replaceOp (op, castOp);
527532 return success ();
@@ -576,38 +581,40 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
576581 int64_t originalChunkSize = tdescTy.getChunkSize ();
577582
578583 VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask ().getType ());
579-
584+
580585 SmallVector<Type> convertedTdescTypes =
581586 getUnrolledTypes (tdescTy, *targetShape);
582587 SmallVector<Value> convertedTdescs = pack (
583588 op.getTensorDesc (), convertedTdescTypes, *targetShape, loc, rewriter);
584589
585- SmallVector<Type> convertedMaskTypes;
586- SmallVector<Value> convertedMasks;
590+ SmallVector<Type> convertedMaskTypes;
591+ SmallVector<Value> convertedMasks;
587592
588593 if (originalChunkSize > 1 ) {
589594 int64_t blockedChunkSize = targetShape->back ();
590- int64_t numNewChunks = originalChunkSize/ blockedChunkSize;
595+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
591596 convertedMaskTypes = getUnrolledTypes (maskTy, (*targetShape)[0 ]);
592- SmallVector<Value> convertedMasks1D = pack (op.getMask (), convertedMaskTypes, (*targetShape)[0 ], loc, rewriter);
597+ SmallVector<Value> convertedMasks1D = pack (
598+ op.getMask (), convertedMaskTypes, (*targetShape)[0 ], loc, rewriter);
593599
594600 for (auto mask : convertedMasks1D) {
595601 for (int64_t i = 0 ; i < numNewChunks; ++i) {
596602 convertedMasks.push_back (mask);
597603 }
598604 }
599- // This is to handle the transpose effect when chunkSize > 1.
605+ // This is to handle the transpose effect when chunkSize > 1.
600606 std::swap ((*targetShape)[0 ], (*targetShape)[1 ]);
601607
602608 } else {
603609 convertedMaskTypes = getUnrolledTypes (maskTy, *targetShape);
604- convertedMasks = pack (op.getMask (), convertedMaskTypes, *targetShape, loc, rewriter);
610+ convertedMasks =
611+ pack (op.getMask (), convertedMaskTypes, *targetShape, loc, rewriter);
605612 }
606613
607614 SmallVector<Type> convertedValTypes =
608615 getUnrolledTypes (valueTy, *targetShape);
609616 SmallVector<Value> convertedValues =
610- pack (op.getValue (), convertedValTypes, *targetShape, loc, rewriter);
617+ pack (op.getValue (), convertedValTypes, *targetShape, loc, rewriter);
611618
612619 for (size_t i = 0 ; i < convertedValues.size (); ++i) {
613620 Value v = convertedValues[i];
@@ -630,7 +637,7 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
630637 Location loc = op.getLoc ();
631638 xegpu::TensorDescType tdescTy = op.getTensorDescType ();
632639
633- if (tdescTy.getRank () >2 )
640+ if (tdescTy.getRank () > 2 )
634641 return failure ();
635642
636643 if (!tdescTy.isScattered ())
@@ -652,12 +659,14 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
652659 SmallVector<Value> newOps;
653660 int64_t originalChunkSize = tdescTy.getChunkSize ();
654661 if (originalChunkSize > 1 ) {
655- SmallVector<int64_t > shape1D (targetShape->begin (), targetShape->end () - 1 );
662+ SmallVector<int64_t > shape1D (targetShape->begin (),
663+ targetShape->end () - 1 );
656664 convertedOffsetTypes = getUnrolledTypes (offsetVecTy, shape1D);
657- SmallVector<Value> convertedOffsetVec1D = pack (offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
665+ SmallVector<Value> convertedOffsetVec1D =
666+ pack (offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
658667
659668 int64_t blockedChunkSize = targetShape->back ();
660- int64_t numNewChunks = originalChunkSize/ blockedChunkSize;
669+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
661670
662671 for (auto offset : convertedOffsetVec1D) {
663672 for (int64_t i = 0 ; i < numNewChunks; ++i) {
@@ -667,8 +676,9 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
667676
668677 } else {
669678 convertedOffsetTypes = getUnrolledTypes (offsetVecTy, *targetShape);
670- convertedOffsetVec = pack (offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
671- }
679+ convertedOffsetVec =
680+ pack (offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
681+ }
672682
673683 for (auto [t, o] : llvm::zip (convertedTdesc, convertedOffsetVec)) {
674684 auto newOp =
0 commit comments