@@ -58,7 +58,7 @@ populateXeTileBlockAligningPatterns(imex::XeTypeConverter &converter,
58
58
mlir::RewritePatternSet &patterns,
59
59
PropagateAnalysis &analysis);
60
60
61
- enum OpType { Prefetch, Load, Store, Elementwise };
61
+ enum OpType { Prefetch, Load, Store, Elementwise, Transpose };
62
62
63
63
// Find the maximum divisible number between minHeight/Width and maxHeight/Width
64
64
// and use that as the inner block sizes.
@@ -170,8 +170,26 @@ getInnerBlockSizes(mlir::Operation *operation, mlir::Type elemTy, int height,
170
170
// TODO: get from uArch?
171
171
int64_t subgroupSize = 16 ;
172
172
173
- return {1 , subgroupSize};
173
+ maxHeight = 1 ;
174
+ minHeight = 1 ;
175
+ maxWidth = subgroupSize;
176
+ minWidth = 1 ;
177
+
178
+ return imex::getInnerBlockHeightWidth (maxHeight, maxWidth, minHeight,
179
+ minWidth, height, width);
180
+ }
181
+
182
+ if (op == OpType::Transpose) {
183
+ // TODO: get from uArch?
184
+ maxHeight = 16 ;
185
+ minHeight = 1 ;
186
+ maxWidth = 16 ;
187
+ minWidth = 1 ;
188
+
189
+ return imex::getInnerBlockHeightWidth (maxHeight, maxWidth, minHeight,
190
+ minWidth, height, width);
174
191
}
192
+
175
193
llvm_unreachable (" Unsupported." );
176
194
return {};
177
195
}
@@ -368,6 +386,70 @@ struct VectorizableOpPattern
368
386
}
369
387
};
370
388
389
+ struct TransposeOpPattern
390
+ : public XeTileConversion<mlir::vector::TransposeOp, TileUsageAnalysis> {
391
+
392
+ using XeTileConversion::XeTileConversion;
393
+
394
+ TransposeOpPattern (mlir::MLIRContext *context,
395
+ imex::XeTypeConverter &converter,
396
+ TileUsageAnalysis &analysis,
397
+ std::shared_ptr<XeuArchInterface> ptruArch)
398
+ : XeTileConversion(context, converter, analysis) {
399
+ this ->uArchInterface = ptruArch;
400
+ }
401
+
402
+ std::shared_ptr<XeuArchInterface> uArchInterface = nullptr ;
403
+
404
+ mlir::LogicalResult
405
+ matchAndRewrite (mlir::vector::TransposeOp op, OpAdaptor adaptor,
406
+ mlir::PatternRewriter &rewriter) const override {
407
+ auto res = op.getResult ();
408
+ auto resType = mlir::cast<mlir::VectorType>(res.getType ());
409
+ if (resType.getRank () != 2 )
410
+ return rewriter.notifyMatchFailure (op, " type is not 2D vector" );
411
+
412
+ auto permutation = op.getPermutation ();
413
+ if (permutation != mlir::ArrayRef<int64_t >({1 , 0 }))
414
+ return rewriter.notifyMatchFailure (op, " Unsupported permutation" );
415
+
416
+ auto shape = resType.getShape ();
417
+ auto blocks = getInnerBlockSizes<Transpose>(
418
+ op, resType.getElementType (), shape[0 ], shape[1 ], this ->uArchInterface );
419
+
420
+ if (blocks.size () != 2 )
421
+ return rewriter.notifyMatchFailure (op, " Invalid inner block sizes" );
422
+
423
+ int64_t inBlocks[2 ] = {blocks[1 ], blocks[0 ]};
424
+
425
+ auto newSrcTy = mlir::VectorType::get (
426
+ {shape[1 ] / blocks[1 ], shape[0 ] / blocks[0 ], blocks[1 ], blocks[0 ]},
427
+ resType.getElementType ());
428
+
429
+ auto newDstTy = mlir::VectorType::get (
430
+ {shape[0 ] / blocks[0 ], shape[1 ] / blocks[1 ], blocks[0 ], blocks[1 ]},
431
+ resType.getElementType ());
432
+
433
+ mlir::Value arg = adaptor.getVector ();
434
+ Location loc = op->getLoc ();
435
+ mlir::Value pack = rewriter.create <xetile::TilePackOp>(
436
+ loc, newSrcTy, arg,
437
+ mlir::DenseI64ArrayAttr::get (getContext (), inBlocks));
438
+
439
+ int64_t newPermutation[4 ] = {1 , 0 , 3 , 2 };
440
+ mlir::Value transpose = rewriter.create <mlir::vector::TransposeOp>(
441
+ loc, newDstTy, pack, newPermutation);
442
+
443
+ mlir::Value unpack = rewriter.create <xetile::TileUnpackOp>(
444
+ loc, resType, transpose,
445
+ mlir::DenseI64ArrayAttr::get (getContext (), blocks));
446
+
447
+ rewriter.replaceOp (op, unpack);
448
+
449
+ return mlir::success ();
450
+ }
451
+ };
452
+
371
453
struct VectorMultiDimReductionOpPattern
372
454
: public XeTileConversion<mlir::vector::MultiDimReductionOp,
373
455
TileUsageAnalysis> {
@@ -873,11 +955,12 @@ struct UpdateTileOffsetOpPattern
873
955
void populateXeTileBlockingPatterns (
874
956
imex::XeTypeConverter &converter, mlir::RewritePatternSet &patterns,
875
957
TileUsageAnalysis &analysis, std::shared_ptr<XeuArchInterface> ptruArch) {
876
- patterns.insert <ArithConstantOpPattern, VectorizableOpPattern,
877
- SCFForOpPattern, SCFYieldOpPattern, InitTileOpPattern,
878
- LoadTileOpPattern, StoreTileOpPattern, TileMMAOpPattern,
879
- UpdateTileOffsetOpPattern, VectorMultiDimReductionOpPattern>(
880
- patterns.getContext (), converter, analysis, ptruArch);
958
+ patterns
959
+ .insert <ArithConstantOpPattern, VectorizableOpPattern, SCFForOpPattern,
960
+ SCFYieldOpPattern, InitTileOpPattern, LoadTileOpPattern,
961
+ StoreTileOpPattern, TileMMAOpPattern, UpdateTileOffsetOpPattern,
962
+ TransposeOpPattern, VectorMultiDimReductionOpPattern>(
963
+ patterns.getContext (), converter, analysis, ptruArch);
881
964
}
882
965
883
966
// Lowers XeTile to blocked layout with high-dim vector
0 commit comments