@@ -62,6 +62,9 @@ static Value createFullMask(PatternRewriter &rewriter, Location loc,
6262 return res.getResult ();
6363}
6464
65+ // Max number of elements a single scatter load/store moves to/from SLM.
66+ constexpr int64_t maxSLMTileSize = 32 ;
67+
6568// Represents VNNI configuration for an operand.
6669struct VnniConfig {
6770 int vnniFactor;
@@ -732,21 +735,19 @@ static SmallVector<Value> createScatterDescriptorTiles(
732735 PatternRewriter &rewriter, Location loc, Value flatMemref,
733736 ArrayRef<int64_t > loadShape2D, ArrayRef<int64_t > tileSize2D,
734737 ArrayRef<int64_t > memrefStrides, Value blockOffset) {
735- int64_t maxLoadSize = 32 ;
736-
737738 assert (memrefStrides.size () == 2 && " Strides must be 2D" );
738739 assert (memrefStrides[1 ] == 1 && " Only row-major strides are supported" );
739740 assert (loadShape2D.size () == 2 && " Load shape must be 2D" );
740- assert (loadShape2D[0 ] * loadShape2D[1 ] % maxLoadSize == 0 &&
741+ assert (loadShape2D[0 ] * loadShape2D[1 ] % maxSLMTileSize == 0 &&
741742 " Load shape must be divisible by max load size" );
742743 assert (tileSize2D.size () == 2 && " Descriptor tile must be 2D" );
743- assert (maxLoadSize % tileSize2D[1 ] == 0 &&
744+ assert (maxSLMTileSize % tileSize2D[1 ] == 0 &&
744745 " Descriptor tile must be divisible by max load size" );
745746
746- int64_t numLoadsPerTile = tileSize2D[0 ] * tileSize2D[1 ] / maxLoadSize ;
747+ int64_t numLoadsPerTile = tileSize2D[0 ] * tileSize2D[1 ] / maxSLMTileSize ;
747748 // This indicates how many rows of a single tile (defined by tileSize2D) are
748749 // loaded per single load operation (single load loads exactly 32 elements).
749- int64_t rowsPerLoad = maxLoadSize / tileSize2D[1 ];
750+ int64_t rowsPerLoad = maxSLMTileSize / tileSize2D[1 ];
750751 int64_t numColTiles = loadShape2D[1 ] / tileSize2D[1 ];
751752
752753 auto memrefType = dyn_cast<MemRefType>(flatMemref.getType ());
@@ -757,7 +758,7 @@ static SmallVector<Value> createScatterDescriptorTiles(
757758 offsetShiftValues.push_back (SmallVector<int64_t >());
758759 for (int i = 0 ; i < rowsPerLoad; i++) {
759760 int64_t offset = i * memrefStrides[0 ];
760- for (int j = 0 ; j < maxLoadSize / rowsPerLoad; j++)
761+ for (int j = 0 ; j < maxSLMTileSize / rowsPerLoad; j++)
761762 offsetShiftValues[colTile].push_back (offset + j +
762763 colTile * tileSize2D[1 ]);
763764 }
@@ -769,9 +770,10 @@ static SmallVector<Value> createScatterDescriptorTiles(
769770 rewriter, loc, SmallVector<int64_t >(32 , skipPerLoad),
770771 rewriter.getIndexType ());
771772
772- auto offsetVecType = VectorType::get ({maxLoadSize}, rewriter.getIndexType ());
773+ auto offsetVecType =
774+ VectorType::get ({maxSLMTileSize}, rewriter.getIndexType ());
773775 auto descType = getTensorDescType (
774- {maxLoadSize }, memrefType.getElementType (),
776+ {maxSLMTileSize }, memrefType.getElementType (),
775777 xegpu::ScatterTensorDescAttr::get (
776778 rewriter.getContext (), xegpu::MemorySpace::SLM, /* chunkSize=*/ 1 ));
777779
@@ -793,8 +795,9 @@ static SmallVector<Value> createScatterDescriptorTiles(
793795 .create <xegpu::CreateDescOp>(loc, descType, flatMemref, offsets0)
794796 .getResult ();
795797 tiles.push_back (desc);
796- for (int j = maxLoadSize; j < loadShape2D[0 ] * loadShape2D[1 ] / numColTiles;
797- j += maxLoadSize) {
798+ for (int j = maxSLMTileSize;
799+ j < loadShape2D[0 ] * loadShape2D[1 ] / numColTiles;
800+ j += maxSLMTileSize) {
798801 auto newTile = rewriter
799802 .create <xegpu::UpdateOffsetOp>(
800803 loc, descType, tiles.back (), offsetPerLoad)
@@ -994,39 +997,37 @@ loadScatterDescTiles(PatternRewriter &rewriter, Location loc,
994997 std::optional<VnniConfig> vnniConf = std::nullopt ,
995998 DenseI64ArrayAttr transpose = nullptr ,
996999 IntegerAttr transpose_bit = nullptr ) {
997- int64_t elementsPerLoad = 32 ;
998-
9991000 // Assume all tiles have the same shape.
10001001 auto tileType = cast<xegpu::TensorDescType>(loadTiles[0 ].getType ());
10011002 assert (llvm::all_of (loadTiles,
10021003 [&](Value tile) { return tile.getType () == tileType; }) &&
10031004 " All load tiles must have the same type." );
10041005 assert (tileType.getShape ().size () == 1 && " Scatter tiles must be 1D" );
1005- assert (tileType.getShape ()[0 ] == elementsPerLoad &&
1006+ assert (tileType.getShape ()[0 ] == maxSLMTileSize &&
10061007 " Scatter tiles must have 32 elements" );
10071008 assert (!vnniConf && " VNNI not supported for scatter loads" );
10081009 assert (!transpose && " Transpose is not supported for scatter loads" );
10091010 assert (!transpose_bit && " Transpose is not supported for scatter loads" );
10101011
10111012 int64_t totalLoadElems = tileType.getShape ()[0 ] * loadTiles.size ();
1012- assert (totalLoadElems % elementsPerLoad == 0 &&
1013+ assert (totalLoadElems % maxSLMTileSize == 0 &&
10131014 " Total load size must be multiple of 32" );
1014- assert (tileShape[0 ] * tileShape[1 ] % elementsPerLoad == 0 &&
1015+ assert (tileShape[0 ] * tileShape[1 ] % maxSLMTileSize == 0 &&
10151016 " Tile shape must be multiple of 32" );
10161017
1017- int64_t loadsPerTile = tileShape[0 ] * tileShape[1 ] / elementsPerLoad ;
1018- int64_t totalNumLoads = totalLoadElems / elementsPerLoad ;
1019- auto mask = createFullMask (rewriter, loc, elementsPerLoad );
1018+ int64_t loadsPerTile = tileShape[0 ] * tileShape[1 ] / maxSLMTileSize ;
1019+ int64_t totalNumLoads = totalLoadElems / maxSLMTileSize ;
1020+ auto mask = createFullMask (rewriter, loc, maxSLMTileSize );
10201021
10211022 SmallVector<Value> result;
10221023 auto elementType = tileType.getElementType ();
10231024 SmallVector<Attribute> accumValues (
1024- loadsPerTile * elementsPerLoad ,
1025+ loadsPerTile * maxSLMTileSize ,
10251026 dyn_cast<Attribute>(rewriter.getZeroAttr (elementType)));
10261027
10271028 VectorType accumVectorType =
1028- VectorType::get ({loadsPerTile, elementsPerLoad }, elementType);
1029- VectorType loadVectorType = VectorType::get ({elementsPerLoad }, elementType);
1029+ VectorType::get ({loadsPerTile, maxSLMTileSize }, elementType);
1030+ VectorType loadVectorType = VectorType::get ({maxSLMTileSize }, elementType);
10301031
10311032 for (int64_t tileIdx = 0 ; tileIdx < totalNumLoads; tileIdx += loadsPerTile) {
10321033 // Accumulator vector for the current tile (its number of elements equals to
@@ -1049,7 +1050,7 @@ loadScatterDescTiles(PatternRewriter &rewriter, Location loc,
10491050 loc, loadOp.getResult (), accumVector, SmallVector<int64_t >{loadIdx});
10501051 }
10511052
1052- if (tileShape[1 ] == elementsPerLoad ) {
1053+ if (tileShape[1 ] == maxSLMTileSize ) {
10531054 // No need to reshape the accumulator vector.
10541055 result.push_back (accumVector);
10551056 continue ;
@@ -1115,24 +1116,22 @@ static void storeScatterDescTiles(PatternRewriter &rewriter, Location loc,
11151116 SmallVector<Value> &results,
11161117 ValueRange storeTiles,
11171118 xegpu::CachePolicyAttr hint) {
1118- int64_t elementsPerStore = 32 ;
1119-
11201119 auto tileType = cast<xegpu::TensorDescType>(storeTiles[0 ].getType ());
11211120 assert (llvm::all_of (storeTiles,
11221121 [&](Value tile) { return tile.getType () == tileType; }) &&
11231122 " All load tiles must have the same type." );
11241123 assert (tileType.getShape ().size () == 1 && " Scatter tiles must be 1D" );
1125- assert (tileType.getShape ()[0 ] == elementsPerStore &&
1124+ assert (tileType.getShape ()[0 ] == maxSLMTileSize &&
11261125 " Scatter tiles must have 32 elements" );
11271126
1128- auto mask = createFullMask (rewriter, loc, elementsPerStore );
1127+ auto mask = createFullMask (rewriter, loc, maxSLMTileSize );
11291128 int64_t descIdx = 0 ;
11301129
11311130 for (auto vec : results) {
11321131 auto vecType = dyn_cast<VectorType>(vec.getType ());
11331132 auto vecShape = vecType.getShape ();
11341133 assert (vecShape.size () == 2 && " Expected 2D vector" );
1135- assert (vecShape[0 ] * vecShape[1 ] % elementsPerStore == 0 &&
1134+ assert (vecShape[0 ] * vecShape[1 ] % maxSLMTileSize == 0 &&
11361135 " Vector shape must be divisible by load size" );
11371136
11381137 // Flatten the vector to 1D
@@ -1142,10 +1141,10 @@ static void storeScatterDescTiles(PatternRewriter &rewriter, Location loc,
11421141 vec);
11431142 // Extract slices of 32 size from 'flatVec' and store them
11441143 for (int64_t loadChunkIdx = 0 ; loadChunkIdx < vecShape[0 ] * vecShape[1 ];
1145- loadChunkIdx += elementsPerStore ) {
1144+ loadChunkIdx += maxSLMTileSize ) {
11461145 auto toStore = rewriter.create <vector::ExtractStridedSliceOp>(
11471146 loc, flatVec, /* offsets=*/ SmallVector<int64_t >({loadChunkIdx}),
1148- /* sizes=*/ SmallVector<int64_t >({elementsPerStore }),
1147+ /* sizes=*/ SmallVector<int64_t >({maxSLMTileSize }),
11491148 /* strides=*/ SmallVector<int64_t >({1 }));
11501149 rewriter.create <xegpu::StoreScatterOp>(loc, toStore, storeTiles[descIdx],
11511150 /* mask=*/ mask,
@@ -1901,7 +1900,7 @@ LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp,
19011900 }
19021901
19031902 // Extract SIMD sized sub-tiles
1904- int64_t maxSizeSIMD = utils::hasSharedMemSpace (output) ? 32 : 256 ;
1903+ int64_t maxSizeSIMD = utils::hasSharedMemSpace (output) ? maxSLMTileSize : 256 ;
19051904 int64_t subTileCols = std::min (outputShape[1 ], maxSizeSIMD);
19061905 int64_t subTileRows =
19071906 std::min (outputShape[0 ], std::max (maxSizeSIMD / subTileCols, 1L ));
0 commit comments