Skip to content

Commit a5a7ab4

Browse files
silee2hsmahesha
andauthored
Clean up code related to ld/st block restriction checks (#1101)
Clean up code related to ld/st block restriction checks * consistently use `elemTyBitWidth` instead of `elementSize` * move `getInMemoryBitWidth` to the top of the source file, so that it can be conveniently referred to in embargo as well Co-authored-by: Mahesha S <[email protected]>
1 parent 26c711e commit a5a7ab4

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

lib/Utils/XeArch.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717

1818
namespace imex {
1919

20+
/// Returns the number of bits an element occupies in memory, given the bit
/// width reported for its element type.
///
/// For every width except 19 the value is returned unchanged. A width of 19
/// is treated as TF32, which is stored in 32 bits.
///
/// \param elemTyBitWidth bit width of the element type.
/// \returns the in-memory width in bits.
static int getInMemoryBitWidth(int elemTyBitWidth) {
  // Reported and stored widths of TF32, the only loosely packed type handled
  // so far.
  constexpr int kTF32BitWidth = 19;
  constexpr int kTF32InMemoryBitWidth = 32;

  if (elemTyBitWidth == kTF32BitWidth)
    return kTF32InMemoryBitWidth; // TF32 is stored in 32 bits
  // TODO: add support for other loosely packed types
  return elemTyBitWidth;
}
26+
2027
/// Checks Given A,B, C, D Matrix Data types to HW supported configs and
2128
/// verifies HW restrictions for supported combinations.
2229
mlir::LogicalResult XePVCuArch::checkSupportedDpasTypes(mlir::Operation *op,
@@ -256,13 +263,6 @@ mlir::LogicalResult XeuArchInterface::isLegalDpasOp(mlir::Operation *op) {
256263
return mlir::success();
257264
}
258265

259-
static int getInMemoryBitWidth(int elemTyBitWidth) {
260-
if (elemTyBitWidth == 19)
261-
return 32; // TF32 is stored in 32 bits;
262-
// TODO: add support for other loosely packed types
263-
return elemTyBitWidth;
264-
}
265-
266266
mlir::LogicalResult XeuArchInterface::verify2dBlockRestriction(
267267
mlir::Operation *op, int width, int height, int array_len,
268268
int elemTyBitWidth, bool transpose, bool vnni,
@@ -314,13 +314,12 @@ mlir::LogicalResult XeuArchInterface::isLegalLoad2dOp(mlir::Operation *op) {
314314

315315
if (auto loadOp = llvm::dyn_cast<mlir::xegpu::LoadNdOp>(op)) {
316316
auto tdescTy = loadOp.getTensorDescType();
317+
auto elemTyBitWidth = tdescTy.getElementTypeBitWidth();
317318

318319
// TODO: need more thinking on SLM
319320
if (tdescTy.getMemorySpace() == mlir::xegpu::MemorySpace::SLM)
320321
return mlir::success();
321322

322-
int elementSize = loadOp.getTensorDescType().getElementTypeBitWidth();
323-
324323
LoadStore2DConfig loadParams;
325324
bool vnni = loadOp.getPacked().value_or(false);
326325
bool transpose =
@@ -333,7 +332,7 @@ mlir::LogicalResult XeuArchInterface::isLegalLoad2dOp(mlir::Operation *op) {
333332
}
334333

335334
mlir::FailureOr<LoadStore2DConfig> configParams =
336-
this->get2DLoadConfig(op, elementSize, vnni, transpose);
335+
this->get2DLoadConfig(op, elemTyBitWidth, vnni, transpose);
337336
if (mlir::succeeded(configParams)) {
338337

339338
auto width = tdescTy.getShape()[1];
@@ -355,7 +354,7 @@ mlir::LogicalResult XeuArchInterface::isLegalStore2dOp(mlir::Operation *op) {
355354

356355
if (auto storeOp = llvm::dyn_cast<mlir::xegpu::StoreNdOp>(op)) {
357356
auto tdescTy = storeOp.getTensorDescType();
358-
int elementSize = tdescTy.getElementTypeBitWidth();
357+
auto elemTyBitWidth = tdescTy.getElementTypeBitWidth();
359358

360359
// TODO: need more thinking on SLM
361360
if (tdescTy.getMemorySpace() == mlir::xegpu::MemorySpace::SLM)
@@ -366,21 +365,20 @@ mlir::LogicalResult XeuArchInterface::isLegalStore2dOp(mlir::Operation *op) {
366365
bool transpose = false;
367366

368367
mlir::FailureOr<LoadStore2DConfig> configParams =
369-
this->get2DStoreConfig(elementSize);
368+
this->get2DStoreConfig(elemTyBitWidth);
370369
if (mlir::succeeded(configParams)) {
371370

372371
auto width = tdescTy.getShape()[1];
373372
auto height = tdescTy.getShape()[0];
374373
auto array_len = tdescTy.getArrayLength();
375-
auto elemTyBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
376374

377375
return verify2dBlockRestriction(op, width, height, array_len,
378376
elemTyBitWidth, transpose, vnni,
379377
*configParams, false);
380378
} else {
381379
return storeOp->emitOpError()
382380
<< "unsupported data sizes for 2d block store. "
383-
<< "Given element data size: d" << elementSize;
381+
<< "Given element data size: d" << elemTyBitWidth;
384382
}
385383
}
386384

@@ -391,17 +389,15 @@ mlir::LogicalResult XeuArchInterface::isLegalPrefetch2dOp(mlir::Operation *op) {
391389

392390
if (auto prefetchOp = llvm::dyn_cast<mlir::xegpu::PrefetchNdOp>(op)) {
393391
auto tdescTy = prefetchOp.getTensorDescType();
394-
395-
int elementSize = prefetchOp.getTensorDescType().getElementTypeBitWidth();
392+
auto elemTyBitWidth = tdescTy.getElementTypeBitWidth();
396393

397394
mlir::FailureOr<LoadStore2DConfig> configParams =
398-
this->get2DPrefetchConfig(op, elementSize);
395+
this->get2DPrefetchConfig(op, elemTyBitWidth);
399396
if (mlir::succeeded(configParams)) {
400397

401398
auto width = tdescTy.getShape()[1];
402399
auto height = tdescTy.getShape()[0];
403400
auto array_len = tdescTy.getArrayLength();
404-
auto elemTyBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
405401

406402
return verify2dPrefetchRestriction(op, width, height, array_len,
407403
elemTyBitWidth, *configParams);

0 commit comments

Comments
 (0)