@@ -167,8 +167,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
                   mlir::ValueRange inputs, mlir::DenseI64ArrayAttr inBlkSizes,
                   mlir::DenseI64ArrayAttr outBlkSizes,
                   llvm::ArrayRef<int64_t> inGrids,
-                  llvm::ArrayRef<int64_t> outGrids, bool isVnniFormat = false,
-                  bool isForDPASB = false) {
+                  llvm::ArrayRef<int64_t> outGrids) {

  // handle based on the dim0, and save results into intermediates
  llvm::SmallVector<mlir::Value> intermediates(outGrids[0] * inGrids[1]);
@@ -190,18 +189,11 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
        }
      }
    }
- } else { // do extract on dim0 using vector::ExtractStridedSliceOp
+ } else {
+   // do extract on dim0 using vector::ExtractStridedSliceOp
    // intermediates.resize(outGrids[0] * inGrids[1]);
    llvm::SmallVector<int64_t> blkSizes({outBlkSizes[0], inBlkSizes[1]});
-   // if the vnni transform applied, vector shape
-   // and offset need to be adjusted accordingly.
-   if (isVnniFormat) {
-     auto vnniAxis = isForDPASB ? 0 : 1;
-     auto factor = mlir::cast<mlir::VectorType>(inputs.front().getType())
-                       .getShape()
-                       .back();
-     blkSizes[vnniAxis] /= factor;
-   }
+
    // each vector will be horizontally cut into `nums` subvectors
    auto nums = outGrids[0] / inGrids[0];
    llvm::SmallVector<int64_t> strides({1, 1});
@@ -244,15 +236,6 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op,
    }
  } else { // doing extract on dim 1
    llvm::SmallVector<int64_t> blkSizes({outBlkSizes[0], outBlkSizes[1]});
-   // if vnni transform applied, vector shape
-   // and offset needs to be adjusted accordingly.
-   if (isVnniFormat) {
-     auto vnniAxis = isForDPASB ? 0 : 1;
-     auto factor = mlir::cast<mlir::VectorType>(inputs.front().getType())
-                       .getShape()
-                       .back();
-     blkSizes[vnniAxis] /= factor;
-   }
    llvm::SmallVector<int64_t> strides({1, 1});
    auto nums = outGrids[1] / interGrids[1];
    for (auto i = 0; i < interGrids[0]; i++) {
@@ -289,14 +272,6 @@ class SgTileUnpackOpPattern : public XeOneToNConversion<xetile::TileUnpackOp> {
    auto inGrids = inTy.getShape().take_front(2);
    auto inBlkSizes = op.getInnerBlocksAttr();

-   // specific attention needed for vectors in vnni format,
-   // which is applied to load for dpas.
-   auto loadOp = op.getInVec().getDefiningOp<xetile::LoadTileOp>();
-   auto elemTy = op.getInVec().getType().getElementType();
-   bool isDpasB = loadOp && isForDPASB(loadOp);
-   bool isVnniFormat =
-       isDpasB && elemTy.isIntOrFloat() && elemTy.getIntOrFloatBitWidth() < 32;
-
    // the default grids used as outGrids when unpack is not paired with a pack
    int64_t defautlOutGrids[2] = {1, 1};
    llvm::ArrayRef<int64_t> outGrids;
@@ -313,20 +288,9 @@ class SgTileUnpackOpPattern : public XeOneToNConversion<xetile::TileUnpackOp> {
      outBlkSizes = mlir::DenseI64ArrayAttr::get(ctx, outTy.getShape());
    }

-   // TODO: logically it is to do concat, but the data is in vnni format
-   // which breaks the concat logic, it transforms concat into stack.
-   if (isVnniFormat && (inBlkSizes[1] < outBlkSizes[1])) {
-     return op->emitOpError("[Unexpected rare case]: ")
-            << "It rarely happens that we need to do concat on vnni "
-            << "transformed vectors (which is 3D instead of 2D). "
-            << "It is essentially a stack on the 2nd dim, and is "
-            << "not implemented yet.\n";
-   }
-
    rewriter.setInsertionPoint(op);
-   auto newOps =
-       lowerUnpackOrPack(rewriter, op, inputs, inBlkSizes, outBlkSizes,
-                         inGrids, outGrids, isVnniFormat, isDpasB);
+   auto newOps = lowerUnpackOrPack(rewriter, op, inputs, inBlkSizes,
+                                   outBlkSizes, inGrids, outGrids);

    if (op->hasOneUse() && packOp) { // lowered Unpack and Pack as pair
      rewriter.replaceOp(packOp, newOps);
@@ -418,19 +382,8 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
    if (!innerBlocks || innerBlocks.size() != 2)
      return op.emitOpError("Missing valid innerBlock for the tile in op.");

-   bool hasColMajorTraversal =
-       tileTy.getOrder().asArrayRef() == mlir::ArrayRef({0, 1});
    // Need to make a copy, so we can swap values.
    auto innerBlk = llvm::to_vector(innerBlocks.asArrayRef());
-   // If order is [1, 0] (source memref is col-major), we need to swap the
-   // shape and innerBlocks because XeGPU ops only support row-major tile
-   // creation.
-   if (hasColMajorTraversal) {
-     assert(op.isSourceMemRef() && op.sourceMemRefHasStaticShape() &&
-            "Expecting a static shape source memref.");
-     std::swap(innerBlk[0], innerBlk[1]);
-     std::swap(shape[0], shape[1]);
-   }

    // using array_length for load if dim1 of innerBlocks is smaller than
    // dim1 of shape.
@@ -440,22 +393,6 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
            ? getBlockArrayLength(op, elemTy, innerBlk[1], shape[1])
            : 1;

-   // If the source memref is col-major we need to convert into a row-major
-   // because at XeGPU level we only support row-major memrefs. Also
-   // array_length must be 1 if col-major memrefs are used as the source.
-   if (hasColMajorTraversal) {
-     array_length = 1;
-     // create a memref.reinterpret_cast to convert col-major to row-major
-     auto rowMajorSourceShape =
-         swapLastTwoElements(op.getSourceMemrefStaticShape());
-     auto rowMajorSourceStrides = defaultStrides(rowMajorSourceShape);
-     int64_t rowMajorSourceOffset = 0;
-     auto newMemRefTy = mlir::MemRefType::get(rowMajorSourceShape, elemTy);
-     source = rewriter.create<mlir::memref::ReinterpretCastOp>(
-         loc, newMemRefTy, source, rowMajorSourceOffset, rowMajorSourceShape,
-         rowMajorSourceStrides);
-   }
-
    auto width = array_length * innerBlk[1];

    llvm::SmallVector<int64_t, 2> blocks(
@@ -476,8 +413,6 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
    // For col-major memref initial offsets need to be swapped.
    auto offsetsY = offsets.pop_back_val();
    auto offsetsX = offsets.pop_back_val();
-   if (hasColMajorTraversal)
-     std::swap(offsetsX, offsetsY);

    auto tDescTy = mlir::xegpu::TensorDescType::get(
        innerBlk, elemTy, array_length, true /*boundary_check*/, memoryScope);
@@ -513,12 +448,8 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
        auto createNdOp = rewriter.create<mlir::xegpu::CreateNdDescOp>(
            op.getLoc(), tDescTy /*resultTy*/, MemRefTypedSource /*source*/,
            tDescOffsets /*offsets*/);
-       // col-major source memref requires creating the tiles in transposed
-       // order
-       if (hasColMajorTraversal)
-         xegpuOps[blocks[0] * j + i] = createNdOp;
-       else
-         xegpuOps[blocks[1] * i + j] = createNdOp;
+
+       xegpuOps[blocks[1] * i + j] = createNdOp;
      } else {
        return mlir::failure();
      }
@@ -620,55 +551,21 @@ struct SgLoadTileOpPattern : public XeOneToNConversion<xetile::LoadTileOp> {
    auto sources = adaptor.getSource();

    auto ctx = op.getContext();
-   auto L1 = mlir::xegpu::CachePolicyAttr::get(
-       ctx, mlir::xegpu::CachePolicy::CACHED);
-   auto L2 = mlir::xegpu::CachePolicyAttr::get(
-       ctx, mlir::xegpu::CachePolicy::CACHED);
-   auto L3 = mlir::xegpu::CachePolicyAttr::get(
-       ctx, mlir::xegpu::CachePolicy::CACHED);
-
-   mlir::UnitAttr vnniAttr = nullptr;
-   mlir::IntegerAttr transposeBitWidthAttr;
-   // TODO: move these two into the architecture abstraction in the future.
-   const int SIMD_WIDTH_IN_BITS = 32;
-   int factor = SIMD_WIDTH_IN_BITS / elemTy.getIntOrFloatBitWidth();
-   // TODO: use uArch for this?
-   auto isLowPrecision = [](unsigned int width) -> bool {
-     bool isPowerOf2 = (width & (width - 1)) == 0;
-     return isPowerOf2 & (width < 32) & (width > 1);
+
+   auto getDefaultCachePolicy = [&]() {
+     return mlir::xegpu::CachePolicyAttr::get(
+         ctx, mlir::xegpu::CachePolicy::CACHED);
    };
-   // vnni can only be applied when blockSZ[0] >= factor
-   // (for shapes such as 1xN, vnni cannot be applied, since no
-   // vnni transform is available)
-   if (isForDPASB(op) && factor > 1 && blockSZ[0] >= factor)
-     vnniAttr = mlir::UnitAttr::get(ctx);

-   mlir::DenseI64ArrayAttr transposeAttr;
-   auto srcOrder = tileTy.getOrder();
-   if (srcOrder.asArrayRef() == mlir::ArrayRef({1, 0})) {
-     // Nothing to do
-   } else if (srcOrder.asArrayRef() == mlir::ArrayRef({0, 1})) {
-     auto elemWidth = elemTy.getIntOrFloatBitWidth();
-     if (elemWidth == 32) {
-       transposeAttr = rewriter.getDenseI64ArrayAttr({1, 0});
-     } else if (isLowPrecision(elemWidth) && vnniAttr) {
-       transposeAttr = rewriter.getDenseI64ArrayAttr({1, 0});
-       transposeBitWidthAttr = rewriter.getI32IntegerAttr(32);
-       vnniAttr = nullptr;
-     } else {
-       return ((mlir::PatternRewriter &)rewriter)
-           .notifyMatchFailure(op, "Unsupported element type for transpose");
-     }
-   } else {
-     return ((mlir::PatternRewriter &)rewriter)
-         .notifyMatchFailure(op, "Unsupported order");
-   }
+   auto L1 = getDefaultCachePolicy();
+   auto L2 = getDefaultCachePolicy();
+   auto L3 = getDefaultCachePolicy();

-   // vnni and transpose are not available for SLM memory scope.
-   if (tileTy.getMemoryScopeAsInt() == 3) {
-     vnniAttr = nullptr;
-     transposeBitWidthAttr = nullptr;
-   }
+   // A col-major tile should have been canonicalized to row-major by the
+   // canonicalization pass; fail on any order that is not row-major.
+   auto srcOrder = tileTy.getOrder();
+   if (srcOrder.asArrayRef() != mlir::ArrayRef({1, 0}))
+     return mlir::failure();

    rewriter.setInsertionPoint(op);
    llvm::SmallVector<::mlir::Value> xegpuOps;
@@ -678,22 +575,12 @@ struct SgLoadTileOpPattern : public XeOneToNConversion<xetile::LoadTileOp> {
      auto shape = tdescTy.getShape().vec();
      auto array_length = tdescTy.getArrayLength();

-     if (transposeAttr)
-       std::swap(shape[0], shape[1]);
-
-     if (vnniAttr || transposeBitWidthAttr) {
-       const int axis = 0;
-       shape[axis] /= factor;
-       shape.push_back(factor);
-     }
-
      if (array_length != 1)
        shape.insert(shape.begin(), array_length);

      auto vectorTy = mlir::VectorType::get(shape, elemTy);
      auto ldOp = rewriter.create<mlir::xegpu::LoadNdOp>(
-         op.getLoc(), vectorTy, src, vnniAttr, transposeAttr,
-         transposeBitWidthAttr, L1, L2, L3);
+         op.getLoc(), vectorTy, src, nullptr, nullptr, nullptr, L1, L2, L3);
      if (array_length == 1) {
        xegpuOps.push_back(ldOp);
      } else {