@@ -204,6 +204,28 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
204204 using Lattice::Lattice;
205205};
206206
207+ // / Helper Function to find a proper instruction multiple for the user-supplied
208+ // / sg-level data shape. `candidates` are uArch allowed shapes.
209+ // / `candidateMultiples` are uArch multiples of such shapes (e.g., block count).
210+ template <typename T>
211+ int getLargestDivisor (T dim, ArrayRef<T> candidates,
212+ ArrayRef<T> candidateMultiples = {}) {
213+ static_assert (std::is_integral<T>::value, " T must be an integer type" );
214+ int largest = -1 ;
215+ SmallVector<T> multiples = {1 };
216+ if (!candidateMultiples.empty ())
217+ multiples =
218+ SmallVector<T>(candidateMultiples.begin (), candidateMultiples.end ());
219+ for (T candidate : candidates) {
220+ for (T multiple : multiples) {
221+ int value = static_cast <int >(candidate * multiple);
222+ if (value != 0 && dim % value == 0 && value > largest)
223+ largest = value;
224+ }
225+ }
226+ return largest;
227+ }
228+
207229// / Helper Functions to get default layouts. A `default layout` is a layout that
208230// / is assigned to a value when the layout is not fixed by some anchor operation
209231// / (like DPAS).
@@ -482,12 +504,23 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
482504 if (!blockWHC)
483505 prefetch.emitWarning (" No known block params found for the element type." );
484506 auto [bWidth, bHeight, bCount] = blockWHC.value ();
485-
486507 SmallVector<int > instData;
508+ int instWidth = getLargestDivisor (
509+ static_cast <int >(tdescTy.getDimSize (tdescTy.getRank () - 1 )), bWidth,
510+ bCount);
511+ if (instWidth == -1 )
512+ prefetch.emitWarning (
513+ " No suitable instruction multiple found for the given shape." );
487514 if (tdescTy.getRank () == 1 )
488- instData = {bWidth.back () * bCount.back ()};
489- else
490- instData = {bHeight.back (), bWidth.back () * bCount.back ()};
515+ instData = {instWidth};
516+ else {
517+ int instHeight = getLargestDivisor (
518+ static_cast <int >(tdescTy.getDimSize (tdescTy.getRank () - 2 )), bHeight);
519+ if (instHeight == -1 )
520+ prefetch.emitWarning (
521+ " No suitable instruction multiple found for the given shape." );
522+ instData = {instHeight, instWidth};
523+ }
491524 auto prefetchLayout = getDefaultSIMTLayoutInfo (
492525 tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize ());
493526 // Propagate the layout to the source tensor descriptor.
@@ -597,10 +630,22 @@ void LayoutInfoPropagation::visitDpasOp(
597630 const auto *uArchInstruction =
598631 dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction (
599632 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
633+
634+ const unsigned dataALen = aTy.getShape ().front ();
635+ auto supportedALen = uArchInstruction->getSupportedM (aTy.getElementType ());
600636 const int maxALen =
601- uArchInstruction->getSupportedM (aTy.getElementType ()).back ();
637+ getLargestDivisor (dataALen, ArrayRef<unsigned >(supportedALen));
638+ if (maxALen == -1 )
639+ dpas.emitWarning (
640+ " No suitable instruction multiple found for the given shape." );
641+
642+ const unsigned dataBLen = bTy.getShape ().back ();
643+ auto supportedBLen = uArchInstruction->getSupportedK (bTy.getElementType ());
602644 const int maxBLen =
603- uArchInstruction->getSupportedK (bTy.getElementType ()).back ();
645+ getLargestDivisor (dataBLen, ArrayRef<unsigned >(supportedBLen));
646+ if (maxBLen == -1 )
647+ dpas.emitWarning (
648+ " No suitable instruction multiple found for the given shape." );
604649 SmallVector<int > instDataA = {maxALen, subgroupSize};
605650 SmallVector<int > instDataB = {subgroupSize, maxBLen};
606651
@@ -614,8 +659,13 @@ void LayoutInfoPropagation::visitDpasOp(
614659 uArchInstruction->getPackedFormatBitSizeB ())));
615660 if (operands.size () > 2 ) {
616661 VectorType cTy = dpas.getAccType ();
662+ const unsigned dataCLen = bTy.getShape ().back ();
663+ auto supportedCLen = uArchInstruction->getSupportedN (bTy.getElementType ());
617664 const int maxCLen =
618- uArchInstruction->getSupportedN (bTy.getElementType ()).back ();
665+ getLargestDivisor (dataCLen, ArrayRef<unsigned >(supportedCLen));
666+ if (maxCLen == -1 )
667+ dpas.emitWarning (
668+ " No suitable instruction multiple found for the given shape." );
619669 SmallVector<int > instDataC = {maxALen, maxCLen};
620670 propagateIfChanged (operands[2 ],
621671 operands[2 ]->meet (getSIMTLayoutInfoForDPASOperand (
@@ -634,16 +684,29 @@ void LayoutInfoPropagation::visitStoreNdOp(
634684 dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
635685 uArch->getInstruction (
636686 xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
687+ VectorType dataTy = store.getValueType ();
637688 auto blockWHC = uArchInstruction->getBlockWidthHeightCount (
638689 store.getValueType ().getElementType ());
639690 if (!blockWHC)
640691 store.emitWarning (" No known block params found for the element type." );
641692 auto [bWidth, bHeight, bCount] = blockWHC.value ();
642693 SmallVector<int > instData;
643- if (store.getValueType ().getRank () == 1 )
644- instData = {bWidth.back () * bCount.back ()};
645- else
646- instData = {bHeight.back (), bWidth.back () * bCount.back ()};
694+ int instWidth = getLargestDivisor (
695+ static_cast <int >(dataTy.getDimSize (dataTy.getRank () - 1 )), bWidth,
696+ bCount);
697+ if (instWidth == -1 )
698+ store.emitWarning (
699+ " No suitable instruction multiple found for the given shape." );
700+ if (dataTy.getRank () == 1 )
701+ instData = {instWidth};
702+ else {
703+ int instHeight = getLargestDivisor (
704+ static_cast <int >(dataTy.getDimSize (dataTy.getRank () - 2 )), bHeight);
705+ if (instHeight == -1 )
706+ store.emitWarning (
707+ " No suitable instruction multiple found for the given shape." );
708+ instData = {instHeight, instWidth};
709+ }
647710 LayoutInfo storeLayout =
648711 getDefaultSIMTLayoutInfo (store.getValueType (), uArch, instData,
649712 uArchInstruction->getPackedFormatBitSize ());
0 commit comments