
Commit 55b0e2f

use constexpr for SLM tile size
Signed-off-by: dchigarev <[email protected]>
1 parent e2b6eb0 commit 55b0e2f

1 file changed: +30 -31 lines

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
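The patch replaces three per-function copies of the magic number 32 (maxLoadSize, elementsPerLoad, elementsPerStore) with a single file-scope constant. Below is a minimal sketch of that pattern; numScatterLoads is a hypothetical helper standing in for the real XeGPU lowering code, not part of the patch.

// Minimal sketch of the refactoring pattern applied by this commit:
// the shared SLM load/store width lives in one file-scope constexpr
// instead of being re-declared locally as 32 in each function.
#include <cassert>
#include <cstdint>

// Max number of elements to load/store from SLM (mirrors the patch).
constexpr int64_t maxSLMTileSize = 32;

// Hypothetical helper showing the divisibility check and load-count
// arithmetic that the patched functions perform with this constant.
static int64_t numScatterLoads(int64_t rows, int64_t cols) {
  assert((rows * cols) % maxSLMTileSize == 0 &&
         "Tile shape must be a multiple of the SLM load width");
  return rows * cols / maxSLMTileSize; // each scatter op moves 32 elements
}

int main() { return numScatterLoads(8, 16) == 4 ? 0 : 1; }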
@@ -62,6 +62,9 @@ static Value createFullMask(PatternRewriter &rewriter, Location loc,
   return res.getResult();
 }
 
+// Max number of elements to load/store from SLM
+constexpr int64_t maxSLMTileSize = 32;
+
 // Represents VNNI configuration for an operand.
 struct VnniConfig {
   int vnniFactor;
@@ -732,21 +735,19 @@ static SmallVector<Value> createScatterDescriptorTiles(
     PatternRewriter &rewriter, Location loc, Value flatMemref,
     ArrayRef<int64_t> loadShape2D, ArrayRef<int64_t> tileSize2D,
     ArrayRef<int64_t> memrefStrides, Value blockOffset) {
-  int64_t maxLoadSize = 32;
-
   assert(memrefStrides.size() == 2 && "Strides must be 2D");
   assert(memrefStrides[1] == 1 && "Only row-major strides are supported");
   assert(loadShape2D.size() == 2 && "Load shape must be 2D");
-  assert(loadShape2D[0] * loadShape2D[1] % maxLoadSize == 0 &&
+  assert(loadShape2D[0] * loadShape2D[1] % maxSLMTileSize == 0 &&
          "Load shape must be divisible by max load size");
   assert(tileSize2D.size() == 2 && "Descriptor tile must be 2D");
-  assert(maxLoadSize % tileSize2D[1] == 0 &&
+  assert(maxSLMTileSize % tileSize2D[1] == 0 &&
          "Descriptor tile must be divisible by max load size");
 
-  int64_t numLoadsPerTile = tileSize2D[0] * tileSize2D[1] / maxLoadSize;
+  int64_t numLoadsPerTile = tileSize2D[0] * tileSize2D[1] / maxSLMTileSize;
   // This indicates how many rows of a single tile (defined by tileSize2D) are
   // loaded per single load operation (single load loads exactly 32 elements).
-  int64_t rowsPerLoad = maxLoadSize / tileSize2D[1];
+  int64_t rowsPerLoad = maxSLMTileSize / tileSize2D[1];
   int64_t numColTiles = loadShape2D[1] / tileSize2D[1];
 
   auto memrefType = dyn_cast<MemRefType>(flatMemref.getType());
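A worked example of the tiling arithmetic above, under assumed shapes (not taken from the patch):

  // Assuming tileSize2D = {8, 16} and loadShape2D = {16, 32}:
  //   numLoadsPerTile = 8 * 16 / maxSLMTileSize = 128 / 32 = 4
  //   rowsPerLoad     = maxSLMTileSize / 16     = 32 / 16  = 2
  //   numColTiles     = 32 / 16                 = 2
  // i.e. each 32-element scatter load covers two rows of one tile.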
@@ -757,7 +758,7 @@ static SmallVector<Value> createScatterDescriptorTiles(
     offsetShiftValues.push_back(SmallVector<int64_t>());
     for (int i = 0; i < rowsPerLoad; i++) {
       int64_t offset = i * memrefStrides[0];
-      for (int j = 0; j < maxLoadSize / rowsPerLoad; j++)
+      for (int j = 0; j < maxSLMTileSize / rowsPerLoad; j++)
         offsetShiftValues[colTile].push_back(offset + j +
                                              colTile * tileSize2D[1]);
     }
@@ -769,9 +770,10 @@ static SmallVector<Value> createScatterDescriptorTiles(
       rewriter, loc, SmallVector<int64_t>(32, skipPerLoad),
       rewriter.getIndexType());
 
-  auto offsetVecType = VectorType::get({maxLoadSize}, rewriter.getIndexType());
+  auto offsetVecType =
+      VectorType::get({maxSLMTileSize}, rewriter.getIndexType());
   auto descType = getTensorDescType(
-      {maxLoadSize}, memrefType.getElementType(),
+      {maxSLMTileSize}, memrefType.getElementType(),
       xegpu::ScatterTensorDescAttr::get(
           rewriter.getContext(), xegpu::MemorySpace::SLM, /*chunkSize=*/1));
 
@@ -793,8 +795,9 @@ static SmallVector<Value> createScatterDescriptorTiles(
             .create<xegpu::CreateDescOp>(loc, descType, flatMemref, offsets0)
             .getResult();
     tiles.push_back(desc);
-    for (int j = maxLoadSize; j < loadShape2D[0] * loadShape2D[1] / numColTiles;
-         j += maxLoadSize) {
+    for (int j = maxSLMTileSize;
+         j < loadShape2D[0] * loadShape2D[1] / numColTiles;
+         j += maxSLMTileSize) {
       auto newTile = rewriter
                          .create<xegpu::UpdateOffsetOp>(
                              loc, descType, tiles.back(), offsetPerLoad)
@@ -994,39 +997,37 @@ loadScatterDescTiles(PatternRewriter &rewriter, Location loc,
                      std::optional<VnniConfig> vnniConf = std::nullopt,
                      DenseI64ArrayAttr transpose = nullptr,
                      IntegerAttr transpose_bit = nullptr) {
-  int64_t elementsPerLoad = 32;
-
   // Assume all tiles have the same shape.
   auto tileType = cast<xegpu::TensorDescType>(loadTiles[0].getType());
   assert(llvm::all_of(loadTiles,
                       [&](Value tile) { return tile.getType() == tileType; }) &&
          "All load tiles must have the same type.");
   assert(tileType.getShape().size() == 1 && "Scatter tiles must be 1D");
-  assert(tileType.getShape()[0] == elementsPerLoad &&
+  assert(tileType.getShape()[0] == maxSLMTileSize &&
          "Scatter tiles must have 32 elements");
   assert(!vnniConf && "VNNI not supported for scatter loads");
   assert(!transpose && "Transpose is not supported for scatter loads");
   assert(!transpose_bit && "Transpose is not supported for scatter loads");
 
   int64_t totalLoadElems = tileType.getShape()[0] * loadTiles.size();
-  assert(totalLoadElems % elementsPerLoad == 0 &&
+  assert(totalLoadElems % maxSLMTileSize == 0 &&
          "Total load size must be multiple of 32");
-  assert(tileShape[0] * tileShape[1] % elementsPerLoad == 0 &&
+  assert(tileShape[0] * tileShape[1] % maxSLMTileSize == 0 &&
          "Tile shape must be multiple of 32");
 
-  int64_t loadsPerTile = tileShape[0] * tileShape[1] / elementsPerLoad;
-  int64_t totalNumLoads = totalLoadElems / elementsPerLoad;
-  auto mask = createFullMask(rewriter, loc, elementsPerLoad);
+  int64_t loadsPerTile = tileShape[0] * tileShape[1] / maxSLMTileSize;
+  int64_t totalNumLoads = totalLoadElems / maxSLMTileSize;
+  auto mask = createFullMask(rewriter, loc, maxSLMTileSize);
 
   SmallVector<Value> result;
   auto elementType = tileType.getElementType();
   SmallVector<Attribute> accumValues(
-      loadsPerTile * elementsPerLoad,
+      loadsPerTile * maxSLMTileSize,
       dyn_cast<Attribute>(rewriter.getZeroAttr(elementType)));
 
   VectorType accumVectorType =
-      VectorType::get({loadsPerTile, elementsPerLoad}, elementType);
-  VectorType loadVectorType = VectorType::get({elementsPerLoad}, elementType);
+      VectorType::get({loadsPerTile, maxSLMTileSize}, elementType);
+  VectorType loadVectorType = VectorType::get({maxSLMTileSize}, elementType);
 
   for (int64_t tileIdx = 0; tileIdx < totalNumLoads; tileIdx += loadsPerTile) {
     // Accumulator vector for the current tile (its number of elements equals to
@@ -1049,7 +1050,7 @@ loadScatterDescTiles(PatternRewriter &rewriter, Location loc,
         loc, loadOp.getResult(), accumVector, SmallVector<int64_t>{loadIdx});
     }
 
-    if (tileShape[1] == elementsPerLoad) {
+    if (tileShape[1] == maxSLMTileSize) {
       // No need to reshape the accumulator vector.
       result.push_back(accumVector);
       continue;
@@ -1115,24 +1116,22 @@ static void storeScatterDescTiles(PatternRewriter &rewriter, Location loc,
                                   SmallVector<Value> &results,
                                   ValueRange storeTiles,
                                   xegpu::CachePolicyAttr hint) {
-  int64_t elementsPerStore = 32;
-
   auto tileType = cast<xegpu::TensorDescType>(storeTiles[0].getType());
   assert(llvm::all_of(storeTiles,
                       [&](Value tile) { return tile.getType() == tileType; }) &&
          "All load tiles must have the same type.");
   assert(tileType.getShape().size() == 1 && "Scatter tiles must be 1D");
-  assert(tileType.getShape()[0] == elementsPerStore &&
+  assert(tileType.getShape()[0] == maxSLMTileSize &&
          "Scatter tiles must have 32 elements");
 
-  auto mask = createFullMask(rewriter, loc, elementsPerStore);
+  auto mask = createFullMask(rewriter, loc, maxSLMTileSize);
   int64_t descIdx = 0;
 
   for (auto vec : results) {
     auto vecType = dyn_cast<VectorType>(vec.getType());
     auto vecShape = vecType.getShape();
     assert(vecShape.size() == 2 && "Expected 2D vector");
-    assert(vecShape[0] * vecShape[1] % elementsPerStore == 0 &&
+    assert(vecShape[0] * vecShape[1] % maxSLMTileSize == 0 &&
            "Vector shape must be divisible by load size");
 
     // Flatten the vector to 1D
@@ -1142,10 +1141,10 @@ static void storeScatterDescTiles(PatternRewriter &rewriter, Location loc,
         vec);
     // Extract slices of 32 size from 'flatVec' and store them
     for (int64_t loadChunkIdx = 0; loadChunkIdx < vecShape[0] * vecShape[1];
-         loadChunkIdx += elementsPerStore) {
+         loadChunkIdx += maxSLMTileSize) {
      auto toStore = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, flatVec, /*offsets=*/SmallVector<int64_t>({loadChunkIdx}),
-          /*sizes=*/SmallVector<int64_t>({elementsPerStore}),
+          /*sizes=*/SmallVector<int64_t>({maxSLMTileSize}),
          /*strides=*/SmallVector<int64_t>({1}));
      rewriter.create<xegpu::StoreScatterOp>(loc, toStore, storeTiles[descIdx],
                                             /*mask=*/mask,
@@ -1901,7 +1900,7 @@ LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp,
   }
 
   // Extract SIMD sized sub-tiles
-  int64_t maxSizeSIMD = utils::hasSharedMemSpace(output) ? 32 : 256;
+  int64_t maxSizeSIMD = utils::hasSharedMemSpace(output) ? maxSLMTileSize : 256;
   int64_t subTileCols = std::min(outputShape[1], maxSizeSIMD);
   int64_t subTileRows =
       std::min(outputShape[0], std::max(maxSizeSIMD / subTileCols, 1L));
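The last hunk caps SIMD sub-tiles at maxSLMTileSize elements when the output lives in shared local memory. A self-contained sketch of that sizing arithmetic follows, under an assumed 8x16 SLM output; the surrounding kernel code is not reproduced and the driver is hypothetical.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Mirrors the constant introduced by this commit.
constexpr int64_t maxSLMTileSize = 32;

int main() {
  // Assumed example: an 8x16 output buffer placed in shared local memory.
  int64_t outputShape[2] = {8, 16};
  bool outputInSLM = true; // stands in for utils::hasSharedMemSpace(output)

  // Same sizing logic as the patched createMemoryFillKernel: SLM outputs are
  // limited to maxSLMTileSize elements per sub-tile, global memory to 256.
  int64_t maxSizeSIMD = outputInSLM ? maxSLMTileSize : 256;
  int64_t subTileCols = std::min(outputShape[1], maxSizeSIMD);
  int64_t subTileRows =
      std::min(outputShape[0], std::max(maxSizeSIMD / subTileCols, int64_t{1}));

  // For the assumed 8x16 SLM output: subTileCols = 16, subTileRows = 2.
  std::printf("sub-tile: %lld x %lld\n", (long long)subTileRows,
              (long long)subTileCols);
  return 0;
}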
