1212#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
1313#include " mlir/Dialect/XeGPU/Transforms/Passes.h"
1414#include " mlir/Dialect/XeGPU/Transforms/Transforms.h"
15+ #include " mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
1516#include " mlir/Transforms/DialectConversion.h"
1617#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
18+ #include < optional>
1719
1820namespace mlir {
1921namespace xegpu {
@@ -29,20 +31,36 @@ using namespace mlir;
2931
3032namespace {
3133
32- class XeGPULoadNdPattern final : public OpConversionPattern<xegpu::LoadNdOp> {
34+ static std::optional<SmallVector<int64_t >>
35+ get2DLaneData (xegpu::TensorDescType tdescType) {
36+ auto layout = tdescType.getLayoutAttr ();
37+ if (!layout)
38+ return std::nullopt ;
39+ auto laneData = layout.getEffectiveLaneDataAsInt ();
40+ if (laneData.size () != 2 )
41+ return std::nullopt ;
42+ return laneData;
43+ }
44+
45+ class XeGPUCreateNdDescOpPattern final
46+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
3347public:
34- using OpConversionPattern<xegpu::LoadNdOp >::OpConversionPattern;
48+ using OpConversionPattern<xegpu::CreateNdDescOp >::OpConversionPattern;
3549 LogicalResult
36- matchAndRewrite (xegpu::LoadNdOp loadOp , OpAdaptor adaptor,
50+ matchAndRewrite (xegpu::CreateNdDescOp createNdOp , OpAdaptor adaptor,
3751 ConversionPatternRewriter &rewriter) const override {
52+ auto tdescTy = createNdOp.getType ();
53+ auto convertType = this ->getTypeConverter ()->convertType (tdescTy);
54+ if (convertType == tdescTy)
55+ return failure ();
3856 return success ();
3957 }
4058};
4159} // namespace
4260
4361void xegpu::populateXeGPUOptimizeTransposePatterns (
4462 RewritePatternSet &patterns) {
45- patterns.add <XeGPULoadNdPattern >(patterns.getContext ());
63+ patterns.add <XeGPUCreateNdDescOpPattern >(patterns.getContext ());
4664}
4765
4866namespace {
@@ -55,6 +73,42 @@ struct XeGPUOptimizeTransposePass final
5573 TypeConverter converter;
5674 RewritePatternSet patterns (&context);
5775 ConversionTarget target (context);
76+
77+ target.addDynamicallyLegalOp <xegpu::CreateNdDescOp>(
78+ [](xegpu::CreateNdDescOp createNdOp) {
79+ auto optionalLaneData = get2DLaneData (createNdOp.getType ());
80+ if (!optionalLaneData)
81+ return true ;
82+ auto laneData = optionalLaneData.value ();
83+ return laneData[0 ] != 1 || laneData[1 ] == 1 ;
84+ });
85+
86+ converter.addConversion ([](xegpu::TensorDescType tdescType) {
87+ auto optionalLaneData = get2DLaneData (tdescType);
88+ if (!optionalLaneData)
89+ return tdescType;
90+ auto laneData = optionalLaneData.value ();
91+ int64_t innerLaneData = laneData[1 ];
92+ if (laneData[0 ] == 1 && innerLaneData != 1 ) {
93+ int elementTyBitwidth =
94+ tdescType.getElementType ().getIntOrFloatBitWidth ();
95+ assert (elementTyBitwidth < 32 &&
96+ " Expected element type bitwidth < 32 with laneData[1] != 1" );
97+ SmallVector<int64_t > newShape (tdescType.getShape ());
98+ newShape.back () = newShape.back () / innerLaneData;
99+ Type newElemTy = IntegerType::get (tdescType.getContext (),
100+ elementTyBitwidth * innerLaneData);
101+ xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get (
102+ tdescType.getContext (),
103+ tdescType.getLayoutAttr ().getLaneLayout ().asArrayRef (), {1 , 1 });
104+ return xegpu::TensorDescType::get (
105+ newShape, newElemTy, tdescType.getArrayLength (),
106+ tdescType.getBoundaryCheck (), tdescType.getMemorySpace (),
107+ newLayout);
108+ }
109+ return tdescType;
110+ });
111+
58112 scf::populateSCFStructuralTypeConversionsAndLegality (converter, patterns,
59113 target);
60114 xegpu::populateXeGPUOptimizeTransposePatterns (patterns);
0 commit comments