17
17
18
18
namespace imex {
19
19
20
+ static int getInMemoryBitWidth (int elemTyBitWidth) {
21
+ if (elemTyBitWidth == 19 )
22
+ return 32 ; // TF32 is stored in 32 bits;
23
+ // TODO: add support for other loosely packed types
24
+ return elemTyBitWidth;
25
+ }
26
+
20
27
// / Checks Given A,B, C, D Matrix Data types to HW supported configs and
21
28
// / verifies HW restrictions for supported combinations.
22
29
mlir::LogicalResult XePVCuArch::checkSupportedDpasTypes (mlir::Operation *op,
@@ -256,13 +263,6 @@ mlir::LogicalResult XeuArchInterface::isLegalDpasOp(mlir::Operation *op) {
256
263
return mlir::success ();
257
264
}
258
265
259
- static int getInMemoryBitWidth (int elemTyBitWidth) {
260
- if (elemTyBitWidth == 19 )
261
- return 32 ; // TF32 is stored in 32 bits;
262
- // TODO: add support for other loosely packed types
263
- return elemTyBitWidth;
264
- }
265
-
266
266
mlir::LogicalResult XeuArchInterface::verify2dBlockRestriction (
267
267
mlir::Operation *op, int width, int height, int array_len,
268
268
int elemTyBitWidth, bool transpose, bool vnni,
@@ -314,13 +314,12 @@ mlir::LogicalResult XeuArchInterface::isLegalLoad2dOp(mlir::Operation *op) {
314
314
315
315
if (auto loadOp = llvm::dyn_cast<mlir::xegpu::LoadNdOp>(op)) {
316
316
auto tdescTy = loadOp.getTensorDescType ();
317
+ auto elemTyBitWidth = tdescTy.getElementTypeBitWidth ();
317
318
318
319
// TODO: need more thinking on SLM
319
320
if (tdescTy.getMemorySpace () == mlir::xegpu::MemorySpace::SLM)
320
321
return mlir::success ();
321
322
322
- int elementSize = loadOp.getTensorDescType ().getElementTypeBitWidth ();
323
-
324
323
LoadStore2DConfig loadParams;
325
324
bool vnni = loadOp.getPacked ().value_or (false );
326
325
bool transpose =
@@ -333,7 +332,7 @@ mlir::LogicalResult XeuArchInterface::isLegalLoad2dOp(mlir::Operation *op) {
333
332
}
334
333
335
334
mlir::FailureOr<LoadStore2DConfig> configParams =
336
- this ->get2DLoadConfig (op, elementSize , vnni, transpose);
335
+ this ->get2DLoadConfig (op, elemTyBitWidth , vnni, transpose);
337
336
if (mlir::succeeded (configParams)) {
338
337
339
338
auto width = tdescTy.getShape ()[1 ];
@@ -355,7 +354,7 @@ mlir::LogicalResult XeuArchInterface::isLegalStore2dOp(mlir::Operation *op) {
355
354
356
355
if (auto storeOp = llvm::dyn_cast<mlir::xegpu::StoreNdOp>(op)) {
357
356
auto tdescTy = storeOp.getTensorDescType ();
358
- int elementSize = tdescTy.getElementTypeBitWidth ();
357
+ auto elemTyBitWidth = tdescTy.getElementTypeBitWidth ();
359
358
360
359
// TODO: need more thinking on SLM
361
360
if (tdescTy.getMemorySpace () == mlir::xegpu::MemorySpace::SLM)
@@ -366,21 +365,20 @@ mlir::LogicalResult XeuArchInterface::isLegalStore2dOp(mlir::Operation *op) {
366
365
bool transpose = false ;
367
366
368
367
mlir::FailureOr<LoadStore2DConfig> configParams =
369
- this ->get2DStoreConfig (elementSize );
368
+ this ->get2DStoreConfig (elemTyBitWidth );
370
369
if (mlir::succeeded (configParams)) {
371
370
372
371
auto width = tdescTy.getShape ()[1 ];
373
372
auto height = tdescTy.getShape ()[0 ];
374
373
auto array_len = tdescTy.getArrayLength ();
375
- auto elemTyBitWidth = tdescTy.getElementType ().getIntOrFloatBitWidth ();
376
374
377
375
return verify2dBlockRestriction (op, width, height, array_len,
378
376
elemTyBitWidth, transpose, vnni,
379
377
*configParams, false );
380
378
} else {
381
379
return storeOp->emitOpError ()
382
380
<< " unsupported data sizes for 2d block store. "
383
- << " Given element data size: d" << elementSize ;
381
+ << " Given element data size: d" << elemTyBitWidth ;
384
382
}
385
383
}
386
384
@@ -391,17 +389,15 @@ mlir::LogicalResult XeuArchInterface::isLegalPrefetch2dOp(mlir::Operation *op) {
391
389
392
390
if (auto prefetchOp = llvm::dyn_cast<mlir::xegpu::PrefetchNdOp>(op)) {
393
391
auto tdescTy = prefetchOp.getTensorDescType ();
394
-
395
- int elementSize = prefetchOp.getTensorDescType ().getElementTypeBitWidth ();
392
+ auto elemTyBitWidth = tdescTy.getElementTypeBitWidth ();
396
393
397
394
mlir::FailureOr<LoadStore2DConfig> configParams =
398
- this ->get2DPrefetchConfig (op, elementSize );
395
+ this ->get2DPrefetchConfig (op, elemTyBitWidth );
399
396
if (mlir::succeeded (configParams)) {
400
397
401
398
auto width = tdescTy.getShape ()[1 ];
402
399
auto height = tdescTy.getShape ()[0 ];
403
400
auto array_len = tdescTy.getArrayLength ();
404
- auto elemTyBitWidth = tdescTy.getElementType ().getIntOrFloatBitWidth ();
405
401
406
402
return verify2dPrefetchRestriction (op, width, height, array_len,
407
403
elemTyBitWidth, *configParams);
0 commit comments