Skip to content

Commit 6ec8419

Browse files
authored
uArch API: Use element bit width instead of byte width (#1099)
uArch API: Use element bit width instead of byte width to allow sub-byte data types.
1 parent bf020c3 commit 6ec8419

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

include/imex/Utils/XeArch.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,14 @@ class XeuArchInterface {
9999

100100
mlir::LogicalResult verify2dBlockRestriction(mlir::Operation *op, int width,
101101
int height, int array_len,
102-
int elemTyByteWidth,
102+
int elemTyBitWidth,
103103
bool transpose, bool vnni,
104104
LoadStore2DConfig configParams,
105105
bool isLoad = true);
106106

107107
virtual mlir::LogicalResult
108108
verify2dPrefetchRestriction(mlir::Operation *op, int width, int height,
109-
int array_len, int elemTyByteWidth,
109+
int array_len, int elemTyBitWidth,
110110
LoadStore2DConfig configParams) = 0;
111111
mlir::LogicalResult isLegalDpasOp(mlir::Operation *op);
112112

@@ -192,10 +192,10 @@ class XePVCuArch : public XeuArchInterface {
192192

193193
// XePVCuArch override of the 2D prefetch verifier.
// Forwards to verify2dBlockRestriction with the same op, width, height,
// array_len, element width and configParams, passing `false` for both the
// transpose and vnni arguments and `true` for isLoad, i.e. a prefetch is
// checked like a plain (non-transposed, non-VNNI) 2D block load.
// Returns whatever verify2dBlockRestriction returns.
mlir::LogicalResult
194194
verify2dPrefetchRestriction(mlir::Operation *op, int width, int height,
195-
int array_len, int elemTyByteWidth,
195+
int array_len, int elemTyBitWidth,
196196
LoadStore2DConfig configParams) override {
197197
return verify2dBlockRestriction(op, width, height, array_len,
198-
elemTyByteWidth, false, false, configParams,
198+
elemTyBitWidth, false, false, configParams,
199199
true);
200200
}
201201
virtual mlir::FailureOr<LoadStore2DConfig>

lib/Utils/XeArch.cpp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,16 @@ mlir::LogicalResult XeuArchInterface::isLegalDpasOp(mlir::Operation *op) {
256256
return mlir::success();
257257
}
258258

259+
// Maps an element type's nominal bit width to the bit width used for size
// computations in this file.
//   elemTyBitWidth: bit width of the element type as reported by the caller.
//   Returns 32 when elemTyBitWidth is 19 (the TF32 case noted below);
//   every other value is returned unchanged.
// NOTE(review): the callers in this file compute `... * result / 8` with
// integer division, so a total that is not a multiple of 8 bits is truncated
// (e.g. width 1 with a 4-bit element yields 0 bytes) — confirm this is the
// intended behaviour for sub-byte element types.
static int getInMemoryBitWidth(int elemTyBitWidth) {
260+
if (elemTyBitWidth == 19)
261+
return 32; // TF32 is stored in 32 bits;
262+
// TODO: add support for other loosely packed types
263+
return elemTyBitWidth;
264+
}
265+
259266
mlir::LogicalResult XeuArchInterface::verify2dBlockRestriction(
260267
mlir::Operation *op, int width, int height, int array_len,
261-
int elemTyByteWidth, bool transpose, bool vnni,
268+
int elemTyBitWidth, bool transpose, bool vnni,
262269
LoadStore2DConfig configParams, bool isLoad) {
263270

264271
if (!llvm::isPowerOf2_32(array_len))
@@ -271,15 +278,15 @@ mlir::LogicalResult XeuArchInterface::verify2dBlockRestriction(
271278

272279
if ((width < configParams.blockWidth.min ||
273280
width > configParams.blockWidth.max ||
274-
(width * elemTyByteWidth) % 4 != 0))
281+
(width * getInMemoryBitWidth(elemTyBitWidth) / 8) % 4 != 0))
275282
return op->emitOpError()
276283
<< "Invalid width size for 2D block load. "
277284
<< "The specification expects the value to "
278285
<< "be in range [" << configParams.blockWidth.min << ", "
279286
<< configParams.blockWidth.max << "], and "
280287
<< "the total data size (width * elemTyBytes) to be multiple of 4. "
281-
<< "Given width: " << width
282-
<< " and data size: " << width * elemTyByteWidth;
288+
<< "Given width: " << width << " and data size: "
289+
<< width * getInMemoryBitWidth(elemTyBitWidth) / 8;
283290

284291
if (height < configParams.blockHeight.min ||
285292
height > configParams.blockHeight.max)
@@ -288,7 +295,8 @@ mlir::LogicalResult XeuArchInterface::verify2dBlockRestriction(
288295
<< "be in range [" << configParams.blockHeight.min
289296
<< ", " << configParams.blockHeight.max << "].";
290297

291-
int GRFSize = width * height * array_len * elemTyByteWidth;
298+
int GRFSize =
299+
width * height * array_len * getInMemoryBitWidth(elemTyBitWidth) / 8;
292300
int supportedSize =
293301
isLoad ? configParams.GRFDataSize.load : configParams.GRFDataSize.store;
294302

@@ -331,11 +339,10 @@ mlir::LogicalResult XeuArchInterface::isLegalLoad2dOp(mlir::Operation *op) {
331339
auto width = tdescTy.getShape()[1];
332340
auto height = tdescTy.getShape()[0];
333341
auto array_len = tdescTy.getArrayLength();
334-
auto elemTyByteWidth =
335-
tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
342+
auto elemTyBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
336343

337344
return verify2dBlockRestriction(op, width, height, array_len,
338-
elemTyByteWidth, transpose, vnni,
345+
elemTyBitWidth, transpose, vnni,
339346
*configParams);
340347
} else {
341348
return loadOp->emitOpError("Invalid 2d block load parameters!\n");
@@ -365,11 +372,10 @@ mlir::LogicalResult XeuArchInterface::isLegalStore2dOp(mlir::Operation *op) {
365372
auto width = tdescTy.getShape()[1];
366373
auto height = tdescTy.getShape()[0];
367374
auto array_len = tdescTy.getArrayLength();
368-
auto elemTyByteWidth =
369-
tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
375+
auto elemTyBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
370376

371377
return verify2dBlockRestriction(op, width, height, array_len,
372-
elemTyByteWidth, transpose, vnni,
378+
elemTyBitWidth, transpose, vnni,
373379
*configParams, false);
374380
} else {
375381
return storeOp->emitOpError()
@@ -395,11 +401,10 @@ mlir::LogicalResult XeuArchInterface::isLegalPrefetch2dOp(mlir::Operation *op) {
395401
auto width = tdescTy.getShape()[1];
396402
auto height = tdescTy.getShape()[0];
397403
auto array_len = tdescTy.getArrayLength();
398-
auto elemTyByteWidth =
399-
tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
404+
auto elemTyBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
400405

401406
return verify2dPrefetchRestriction(op, width, height, array_len,
402-
elemTyByteWidth, *configParams);
407+
elemTyBitWidth, *configParams);
403408
} else {
404409
return prefetchOp->emitOpError()
405410
<< "Invalid 2d block load parameters for prefetch operation!\n";

0 commit comments

Comments
 (0)