@@ -256,9 +256,16 @@ mlir::LogicalResult XeuArchInterface::isLegalDpasOp(mlir::Operation *op) {
256
256
return mlir::success ();
257
257
}
258
258
259
/// Returns the number of bits a single element occupies in memory, given the
/// logical bit width of its element type.
///
/// Element types are assumed to be tightly packed, so the logical width is
/// returned unchanged, with one exception handled here: TF32, whose logical
/// width is 19 bits but which is stored in 32 bits.
///
/// \param elemTyBitWidth logical bit width of the element type.
/// \return the in-memory bit width of one element of that type.
static constexpr int getInMemoryBitWidth(int elemTyBitWidth) {
  constexpr int kTF32LogicalBitWidth = 19;
  constexpr int kTF32StorageBitWidth = 32;

  if (elemTyBitWidth == kTF32LogicalBitWidth)
    return kTF32StorageBitWidth; // TF32 is stored in 32 bits.
  // TODO: add support for other loosely packed types
  return elemTyBitWidth;
}
265
+
259
266
mlir::LogicalResult XeuArchInterface::verify2dBlockRestriction (
260
267
mlir::Operation *op, int width, int height, int array_len,
261
- int elemTyByteWidth , bool transpose, bool vnni,
268
+ int elemTyBitWidth , bool transpose, bool vnni,
262
269
LoadStore2DConfig configParams, bool isLoad) {
263
270
264
271
if (!llvm::isPowerOf2_32 (array_len))
@@ -271,15 +278,15 @@ mlir::LogicalResult XeuArchInterface::verify2dBlockRestriction(
271
278
272
279
if ((width < configParams.blockWidth .min ||
273
280
width > configParams.blockWidth .max ||
274
- (width * elemTyByteWidth ) % 4 != 0 ))
281
+ (width * getInMemoryBitWidth (elemTyBitWidth) / 8 ) % 4 != 0 ))
275
282
return op->emitOpError ()
276
283
<< " Invalid width size for 2D block load. "
277
284
<< " The specification expects the value to "
278
285
<< " be in range [" << configParams.blockWidth .min << " , "
279
286
<< configParams.blockWidth .max << " ], and "
280
287
<< " the total data size (width * elemTyBytes) to be multiple of 4. "
281
- << " Given width: " << width
282
- << " and data size: " << width * elemTyByteWidth ;
288
+ << " Given width: " << width << " and data size: "
289
+ << width * getInMemoryBitWidth (elemTyBitWidth) / 8 ;
283
290
284
291
if (height < configParams.blockHeight .min ||
285
292
height > configParams.blockHeight .max )
@@ -288,7 +295,8 @@ mlir::LogicalResult XeuArchInterface::verify2dBlockRestriction(
288
295
<< " be in range [" << configParams.blockHeight .min
289
296
<< " , " << configParams.blockHeight .max << " ]." ;
290
297
291
- int GRFSize = width * height * array_len * elemTyByteWidth;
298
+ int GRFSize =
299
+ width * height * array_len * getInMemoryBitWidth (elemTyBitWidth) / 8 ;
292
300
int supportedSize =
293
301
isLoad ? configParams.GRFDataSize .load : configParams.GRFDataSize .store ;
294
302
@@ -331,11 +339,10 @@ mlir::LogicalResult XeuArchInterface::isLegalLoad2dOp(mlir::Operation *op) {
331
339
auto width = tdescTy.getShape ()[1 ];
332
340
auto height = tdescTy.getShape ()[0 ];
333
341
auto array_len = tdescTy.getArrayLength ();
334
- auto elemTyByteWidth =
335
- tdescTy.getElementType ().getIntOrFloatBitWidth () / 8 ;
342
+ auto elemTyBitWidth = tdescTy.getElementType ().getIntOrFloatBitWidth ();
336
343
337
344
return verify2dBlockRestriction (op, width, height, array_len,
338
- elemTyByteWidth , transpose, vnni,
345
+ elemTyBitWidth , transpose, vnni,
339
346
*configParams);
340
347
} else {
341
348
return loadOp->emitOpError (" Invalid 2d block load parameters!\n " );
@@ -365,11 +372,10 @@ mlir::LogicalResult XeuArchInterface::isLegalStore2dOp(mlir::Operation *op) {
365
372
auto width = tdescTy.getShape ()[1 ];
366
373
auto height = tdescTy.getShape ()[0 ];
367
374
auto array_len = tdescTy.getArrayLength ();
368
- auto elemTyByteWidth =
369
- tdescTy.getElementType ().getIntOrFloatBitWidth () / 8 ;
375
+ auto elemTyBitWidth = tdescTy.getElementType ().getIntOrFloatBitWidth ();
370
376
371
377
return verify2dBlockRestriction (op, width, height, array_len,
372
- elemTyByteWidth , transpose, vnni,
378
+ elemTyBitWidth , transpose, vnni,
373
379
*configParams, false );
374
380
} else {
375
381
return storeOp->emitOpError ()
@@ -395,11 +401,10 @@ mlir::LogicalResult XeuArchInterface::isLegalPrefetch2dOp(mlir::Operation *op) {
395
401
auto width = tdescTy.getShape ()[1 ];
396
402
auto height = tdescTy.getShape ()[0 ];
397
403
auto array_len = tdescTy.getArrayLength ();
398
- auto elemTyByteWidth =
399
- tdescTy.getElementType ().getIntOrFloatBitWidth () / 8 ;
404
+ auto elemTyBitWidth = tdescTy.getElementType ().getIntOrFloatBitWidth ();
400
405
401
406
return verify2dPrefetchRestriction (op, width, height, array_len,
402
- elemTyByteWidth , *configParams);
407
+ elemTyBitWidth , *configParams);
403
408
} else {
404
409
return prefetchOp->emitOpError ()
405
410
<< " Invalid 2d block load parameters for prefetch operation!\n " ;
0 commit comments