3030#include " llvm/ADT/STLExtras.h"
3131#include " llvm/ADT/SmallVector.h"
3232#include " llvm/ADT/TypeSwitch.h"
33+ #include " llvm/Support/Casting.h"
3334#include " llvm/Support/Debug.h"
3435#include " llvm/Support/InterleavedRange.h"
3536#include " llvm/Support/raw_ostream.h"
@@ -103,6 +104,7 @@ struct LayoutInfo {
103104private:
104105 LaneLayout laneLayout;
105106 LaneData laneData;
107+ xegpu::LayoutAttr layoutAttr;
106108
107109public:
108110 LayoutInfo () = default ;
@@ -186,7 +188,7 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
186188// / Helper Function to get the default layout for uniform values like constants.
187189// / For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
188190// / For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
189- static LayoutInfo getDefaultLayoutInfo (unsigned rank) {
191+ static LayoutInfo getDefaultSIMTLayoutInfo (unsigned rank) {
190192 assert ((rank == 1 || rank == 2 ) && " Expected 1D or 2D vector." );
191193 if (rank == 1 )
192194 return LayoutInfo (LaneLayout ({xegpu::targetinfo::subgroupSize}),
@@ -196,7 +198,7 @@ static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
196198}
197199
198200// / Helper to get the default layout for a vector type.
199- static LayoutInfo getDefaultLayoutInfo (VectorType vectorTy) {
201+ static LayoutInfo getDefaultSIMTLayoutInfo (VectorType vectorTy) {
200202 // Expecting a 1D or 2D vector.
201203 assert ((vectorTy.getRank () == 1 || vectorTy.getRank () == 2 ) &&
202204 " Expected 1D or 2D vector." );
@@ -205,7 +207,7 @@ static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
205207 " Expected int or float element type." );
206208 // If the rank is 1, then return default layout for 1D vector.
207209 if (vectorTy.getRank () == 1 )
208- return getDefaultLayoutInfo (1 );
210+ return getDefaultSIMTLayoutInfo (1 );
209211 // Packing factor is determined by the element type bitwidth.
210212 int packingFactor = 1 ;
211213 unsigned bitwidth = vectorTy.getElementType ().getIntOrFloatBitWidth ();
@@ -221,8 +223,8 @@ static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
221223// / `packedSizeInBitsForDefault`
222224// / * For B operand, the data must be packed in minimum
223225// / `packedSizeInBitsForDpasB`
224- static LayoutInfo getLayoutInfoForDPASOperand (VectorType vectorTy,
225- unsigned operandNum) {
226+ static LayoutInfo getSIMTLayoutInfoForDPASOperand (VectorType vectorTy,
227+ unsigned operandNum) {
226228 Type elementTy = vectorTy.getElementType ();
227229 assert (elementTy.isIntOrFloat () &&
228230 " Expected int or float type in DPAS operands" );
@@ -237,7 +239,7 @@ static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
237239 return LayoutInfo (layout, data);
238240 }
239241 // Otherwise, return the default layout for the vector type.
240- return getDefaultLayoutInfo (vectorTy);
242+ return getDefaultSIMTLayoutInfo (vectorTy);
241243}
242244
243245// ===----------------------------------------------------------------------===//
@@ -360,17 +362,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
360362 // All other ops.
361363 .Default ([&](Operation *op) {
362364 for (const LayoutInfoLattice *r : results) {
363- for (LayoutInfoLattice *operand : operands ) {
364- // Propagate the layout of the result to the operand.
365- if (r-> getValue (). isAssigned ())
365+ if (r-> getValue (). isAssigned () ) {
366+ for (LayoutInfoLattice *operand : operands) {
367+ // Propagate the layout of the result to the operand.
366368 meet (operand, *r);
369+ }
367370 }
368371 }
369372 });
370373 // Add a dependency from each result to program point after the operation.
371- for (const LayoutInfoLattice *r : results) {
374+ for (const LayoutInfoLattice *r : results)
372375 addDependency (const_cast <LayoutInfoLattice *>(r), getProgramPointAfter (op));
373- }
376+
374377 return success ();
375378}
376379
@@ -380,7 +383,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
380383 // Here we assign the default layout to the tensor descriptor operand of
381384 // prefetch.
382385 auto tdescTy = prefetch.getTensorDescType ();
383- auto prefetchLayout = getDefaultLayoutInfo (
386+ auto prefetchLayout = getDefaultSIMTLayoutInfo (
384387 VectorType::get (tdescTy.getShape (), tdescTy.getElementType ()));
385388 // Propagate the layout to the source tensor descriptor.
386389 propagateIfChanged (operands[0 ], operands[0 ]->meet (prefetchLayout));
@@ -395,11 +398,13 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
395398 if (!resultLayout.isAssigned ())
396399 return ;
397400 // We only consider 2D -> 1D reductions at this point.
398- assert (resultLayout.getLayout ().size () == 1 &&
399- " Expected 1D layout for reduction result." );
401+ if (resultLayout.getLayout ().size () != 1 ) {
402+ reduction.emitWarning (" Expected 1D layout for reduction result. " );
403+ return ;
404+ }
400405 // Given that the result is 1D, the layout of the operand should be 2D with
401406 // default layout.
402- LayoutInfo operandLayout = getDefaultLayoutInfo (2 );
407+ LayoutInfo operandLayout = getDefaultSIMTLayoutInfo (2 );
403408 propagateIfChanged (operands[0 ], operands[0 ]->meet (operandLayout));
404409 // Accumulator should have the same layout as the result.
405410 propagateIfChanged (operands[1 ], operands[1 ]->meet (resultLayout));
@@ -425,22 +430,23 @@ void LayoutInfoPropagation::visitDpasOp(
425430 ArrayRef<const LayoutInfoLattice *> results) {
426431 VectorType aTy = dpas.getLhsType ();
427432 VectorType bTy = dpas.getRhsType ();
428- propagateIfChanged (operands[ 0 ],
429- operands[0 ]->meet (getLayoutInfoForDPASOperand (aTy, 0 )));
430- propagateIfChanged (operands[ 1 ],
431- operands[1 ]->meet (getLayoutInfoForDPASOperand (bTy, 1 )));
433+ propagateIfChanged (
434+ operands[ 0 ], operands[0 ]->meet (getSIMTLayoutInfoForDPASOperand (aTy, 0 )));
435+ propagateIfChanged (
436+ operands[ 1 ], operands[1 ]->meet (getSIMTLayoutInfoForDPASOperand (bTy, 1 )));
432437 if (operands.size () > 2 ) {
433438 VectorType cTy = dpas.getAccType ();
434- propagateIfChanged (operands[2 ],
435- operands[2 ]->meet (getLayoutInfoForDPASOperand (cTy, 2 )));
439+ propagateIfChanged (
440+ operands[2 ],
441+ operands[2 ]->meet (getSIMTLayoutInfoForDPASOperand (cTy, 2 )));
436442 }
437443}
438444
439445// / Set the layout for the value and tensor descriptor operands in StoreNdOp.
440446void LayoutInfoPropagation::visitStoreNdOp (
441447 xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
442448 ArrayRef<const LayoutInfoLattice *> results) {
443- LayoutInfo storeLayout = getDefaultLayoutInfo (store.getValueType ());
449+ LayoutInfo storeLayout = getDefaultSIMTLayoutInfo (store.getValueType ());
444450 // Both operands should have the same layout
445451 for (LayoutInfoLattice *operand : operands) {
446452 propagateIfChanged (operand, operand->meet (storeLayout));
@@ -539,7 +545,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
539545 tensorDescLayout = valueLayout.getTransposedLayout ({1 , 0 });
540546 }
541547 // Mask operand should have 1D default layout.
542- LayoutInfo maskLayout = getDefaultLayoutInfo (1 );
548+ LayoutInfo maskLayout = getDefaultSIMTLayoutInfo (1 );
543549 // Propagate the new layout to the tensor descriptor operand.
544550 propagateIfChanged (operands[0 ], operands[0 ]->meet (tensorDescLayout));
545551 // Propagate the new layout to the mask operand.
@@ -556,7 +562,7 @@ void LayoutInfoPropagation::visitCreateDescOp(
556562 if (!descLayout.isAssigned ())
557563 return ;
558564 // For offset operand propagate 1D default layout.
559- LayoutInfo layout = getDefaultLayoutInfo (1 );
565+ LayoutInfo layout = getDefaultSIMTLayoutInfo (1 );
560566 propagateIfChanged (operands[1 ], operands[1 ]->meet (layout));
561567}
562568
@@ -575,7 +581,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
575581 " Expected the first dimension of 2D tensor descriptor to be equal to "
576582 " subgroup size." );
577583
578- LayoutInfo valueLayout = getDefaultLayoutInfo (storeScatter.getValueType ());
584+ LayoutInfo valueLayout =
585+ getDefaultSIMTLayoutInfo (storeScatter.getValueType ());
579586 LayoutInfo storeScatterLayout = valueLayout;
580587 if (storeScatter.getTranspose ()) {
581588 // StoreScatteOp allows transpose effect. However, at the stage of this
@@ -590,7 +597,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
590597 // Propagate the tensor descriptor layout.
591598 propagateIfChanged (operands[1 ], operands[1 ]->meet (storeScatterLayout));
592599 // Use default 1D layout for mask operand.
593- LayoutInfo maskLayout = getDefaultLayoutInfo (1 );
600+ LayoutInfo maskLayout = getDefaultSIMTLayoutInfo (1 );
594601 propagateIfChanged (operands[2 ], operands[2 ]->meet (maskLayout));
595602}
596603
0 commit comments