66//
77// ===----------------------------------------------------------------------===//
88
9+ #include " mlir/Dialect/Arith/IR/Arith.h"
910#include " mlir/Dialect/MemRef/IR/MemRef.h"
1011#include " mlir/Dialect/SCF/Transforms/Patterns.h"
12+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
1113#include " mlir/Dialect/Vector/IR/VectorOps.h"
1214#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
1315#include " mlir/Dialect/XeGPU/Transforms/Passes.h"
1416#include " mlir/Dialect/XeGPU/Transforms/Transforms.h"
1517#include " mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
18+ #include " mlir/IR/OpDefinition.h"
19+ #include " mlir/IR/Types.h"
20+ #include " mlir/IR/Value.h"
21+ #include " mlir/Support/LLVM.h"
1622#include " mlir/Transforms/DialectConversion.h"
1723#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
1824#include < optional>
@@ -42,6 +48,46 @@ get2DLaneData(xegpu::TensorDescType tdescType) {
4248 return laneData;
4349}
4450
51+ static xegpu::TensorDescType
52+ getModifiedTensorDescType (xegpu::TensorDescType tdescType) {
53+ auto optionalLaneData = get2DLaneData (tdescType);
54+ if (!optionalLaneData)
55+ return tdescType;
56+ auto laneData = optionalLaneData.value ();
57+ int64_t innerLaneData = laneData[1 ];
58+ if (laneData[0 ] == 1 && innerLaneData != 1 ) {
59+ int elementTyBitwidth = tdescType.getElementType ().getIntOrFloatBitWidth ();
60+ assert (elementTyBitwidth < 32 &&
61+ " Expected element type bitwidth < 32 with laneData[1] != 1" );
62+ SmallVector<int64_t > newShape (tdescType.getShape ());
63+ newShape.back () = newShape.back () / innerLaneData;
64+ Type newElemTy = IntegerType::get (tdescType.getContext (),
65+ elementTyBitwidth * innerLaneData);
66+ xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get (
67+ tdescType.getContext (),
68+ tdescType.getLayoutAttr ().getLaneLayout ().asArrayRef (), {1 , 1 });
69+ return xegpu::TensorDescType::get (
70+ newShape, newElemTy, tdescType.getArrayLength (),
71+ tdescType.getBoundaryCheck (), tdescType.getMemorySpace (), newLayout);
72+ }
73+ return tdescType;
74+ }
75+
76+ static Value convertToValue (ConversionPatternRewriter &rewriter, Location loc,
77+ OpFoldResult ofr) {
78+ std::optional<int64_t > mayBeInt = getConstantIntValue (ofr);
79+ if (mayBeInt)
80+ return arith::ConstantIndexOp::create (rewriter, loc, *mayBeInt);
81+ return llvm::cast<Value>(ofr);
82+ }
83+
84+ static Value divideByConstant (ConversionPatternRewriter &rewriter, Location loc,
85+ Value val, int64_t constant) {
86+ auto constantOp = arith::ConstantIndexOp::create (rewriter, loc, constant);
87+ return arith::DivUIOp::create (rewriter, loc, val, constantOp.getResult ())
88+ .getResult ();
89+ }
90+
4591class XeGPUCreateNdDescOpPattern final
4692 : public OpConversionPattern<xegpu::CreateNdDescOp> {
4793public:
@@ -50,17 +96,106 @@ class XeGPUCreateNdDescOpPattern final
5096 matchAndRewrite (xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
5197 ConversionPatternRewriter &rewriter) const override {
5298 auto tdescTy = createNdOp.getType ();
53- auto convertType = this -> getTypeConverter ()-> convertType (tdescTy);
99+ auto convertType = getModifiedTensorDescType (tdescTy);
54100 if (convertType == tdescTy)
55101 return failure ();
102+ auto strides = createNdOp.getMixedStrides ();
103+ auto maybeConstInnerStride = getConstantIntValue (strides.back ());
104+ // Only row-major memrefs are expected for now.
105+ if (!maybeConstInnerStride || *maybeConstInnerStride != 1 )
106+ return failure ();
107+ Value source = createNdOp.getSource ();
108+ auto optionalLaneData = get2DLaneData (tdescTy);
109+ assert (optionalLaneData && " Expected 2D lane data" );
110+ auto laneData = optionalLaneData.value ();
111+ int64_t innerLaneData = laneData[1 ];
112+ auto memrefType = dyn_cast<MemRefType>(source.getType ());
113+ // Inner dimension of the shape must be adjusted based on innerLaneData.
114+ SmallVector<OpFoldResult> modifiedShape (createNdOp.getMixedSizes ());
115+ modifiedShape.back () = divideByConstant (
116+ rewriter, createNdOp.getLoc (),
117+ convertToValue (rewriter, createNdOp.getLoc (), modifiedShape.back ()),
118+ innerLaneData);
119+ // Similarly, second to last stride must be adjusted.
120+ assert (strides.size () >= 2 &&
121+ " Expected at least 2 strides for CreateNdDescOp" );
122+ SmallVector<OpFoldResult> modifiedStrides (strides);
123+ modifiedStrides[modifiedStrides.size () - 2 ] = divideByConstant (
124+ rewriter, createNdOp.getLoc (),
125+ convertToValue (rewriter, createNdOp.getLoc (),
126+ modifiedStrides[modifiedStrides.size () - 2 ]),
127+ innerLaneData);
128+
129+ // If the source is a static memref, we need to extract the pointer to
130+ // base address.
131+ if (memrefType && memrefType.hasStaticShape ()) {
132+ auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create (
133+ rewriter, createNdOp.getLoc (), source);
134+ source = arith::IndexCastOp::create (
135+ rewriter, createNdOp.getLoc (),
136+ IntegerType::get (rewriter.getContext (), 64 ),
137+ extractOp.getResult ())
138+ .getResult ();
139+ }
140+ // Create a new CreateNdDescOp with the modified shape and converted type.
141+ auto newCreateNdDescOp = xegpu::CreateNdDescOp::create (
142+ rewriter, createNdOp.getLoc (), convertType, source, modifiedShape,
143+ modifiedStrides);
144+ rewriter.replaceOp (createNdOp, newCreateNdDescOp.getResult ());
56145 return success ();
57146 }
58147};
59148} // namespace
60149
150+ class XeGPULoadNdDescOpPattern final
151+ : public OpConversionPattern<xegpu::LoadNdOp> {
152+ public:
153+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
154+ LogicalResult
155+ matchAndRewrite (xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
156+ ConversionPatternRewriter &rewriter) const override {
157+ auto origTensorDescType = loadNdOp.getTensorDescType ();
158+ auto adaptorType =
159+ cast<xegpu::TensorDescType>(adaptor.getTensorDesc ().getType ());
160+ if (adaptorType == origTensorDescType)
161+ return failure ();
162+ // Offsets must be adjusted based on innerLaneData.
163+ auto optionalLaneData = get2DLaneData (loadNdOp.getTensorDescType ());
164+ assert (optionalLaneData && " Expected 2D lane data" );
165+ int64_t innerLaneData = optionalLaneData.value ()[1 ];
166+ auto offsets = loadNdOp.getMixedOffsets ();
167+ if (offsets.empty ())
168+ return rewriter.notifyMatchFailure (loadNdOp,
169+ " Expecting offsets in LoadNd" );
170+ SmallVector<OpFoldResult> modifiedOffsets (offsets);
171+ modifiedOffsets.back () = divideByConstant (
172+ rewriter, loadNdOp.getLoc (),
173+ convertToValue (rewriter, loadNdOp.getLoc (), modifiedOffsets.back ()),
174+ innerLaneData);
175+ VectorType modifiedType =
176+ VectorType::get (adaptorType.getShape (), adaptorType.getElementType ());
177+ // Create a new LoadNdOp with modified offsets and type.
178+ auto newLoadNdOp = xegpu::LoadNdOp::create (
179+ rewriter, loadNdOp->getLoc (), modifiedType, adaptor.getTensorDesc (),
180+ modifiedOffsets, loadNdOp.getPackedAttr (), loadNdOp.getTransposeAttr (),
181+ loadNdOp.getL1HintAttr (), loadNdOp.getL2HintAttr (),
182+ loadNdOp.getL3HintAttr ());
183+ // Bitcast back to the original type.
184+ auto castOp = vector::BitCastOp::create (rewriter, loadNdOp->getLoc (),
185+ loadNdOp.getType (), newLoadNdOp);
186+ // Cast op must have the same layout as the original LoadNdOp result.
187+ xegpu::setDistributeLayoutAttr (
188+ castOp->getOpResult (0 ),
189+ xegpu::getDistributeLayoutAttr (loadNdOp.getResult ()));
190+ rewriter.replaceOp (loadNdOp, castOp.getResult ());
191+ return success ();
192+ }
193+ };
194+
61195void xegpu::populateXeGPUOptimizeTransposePatterns (
62196 RewritePatternSet &patterns) {
63- patterns.add <XeGPUCreateNdDescOpPattern>(patterns.getContext ());
197+ patterns.add <XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern>(
198+ patterns.getContext ());
64199}
65200
66201namespace {
@@ -74,41 +209,28 @@ struct XeGPUOptimizeTransposePass final
74209 RewritePatternSet patterns (&context);
75210 ConversionTarget target (context);
76211
212+ auto checkValidInnerLaneData =
213+ [](std::optional<SmallVector<int64_t >> optionalLaneData) -> bool {
214+ if (!optionalLaneData)
215+ return true ;
216+ auto laneData = optionalLaneData.value ();
217+ return laneData[0 ] != 1 || laneData[1 ] == 1 ;
218+ };
219+
77220 target.addDynamicallyLegalOp <xegpu::CreateNdDescOp>(
78- [](xegpu::CreateNdDescOp createNdOp) {
221+ [& ](xegpu::CreateNdDescOp createNdOp) {
79222 auto optionalLaneData = get2DLaneData (createNdOp.getType ());
80- if (!optionalLaneData)
81- return true ;
82- auto laneData = optionalLaneData.value ();
83- return laneData[0 ] != 1 || laneData[1 ] == 1 ;
223+ return checkValidInnerLaneData (optionalLaneData);
84224 });
225+ target.addDynamicallyLegalOp <xegpu::LoadNdOp>(
226+ [&](xegpu::LoadNdOp loadNdOp) {
227+ auto optionalLaneData = get2DLaneData (loadNdOp.getTensorDescType ());
228+ return checkValidInnerLaneData (optionalLaneData);
229+ });
230+ converter.addConversion ([](Type type) { return type; });
85231
86- converter.addConversion ([](xegpu::TensorDescType tdescType) {
87- auto optionalLaneData = get2DLaneData (tdescType);
88- if (!optionalLaneData)
89- return tdescType;
90- auto laneData = optionalLaneData.value ();
91- int64_t innerLaneData = laneData[1 ];
92- if (laneData[0 ] == 1 && innerLaneData != 1 ) {
93- int elementTyBitwidth =
94- tdescType.getElementType ().getIntOrFloatBitWidth ();
95- assert (elementTyBitwidth < 32 &&
96- " Expected element type bitwidth < 32 with laneData[1] != 1" );
97- SmallVector<int64_t > newShape (tdescType.getShape ());
98- newShape.back () = newShape.back () / innerLaneData;
99- Type newElemTy = IntegerType::get (tdescType.getContext (),
100- elementTyBitwidth * innerLaneData);
101- xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get (
102- tdescType.getContext (),
103- tdescType.getLayoutAttr ().getLaneLayout ().asArrayRef (), {1 , 1 });
104- return xegpu::TensorDescType::get (
105- newShape, newElemTy, tdescType.getArrayLength (),
106- tdescType.getBoundaryCheck (), tdescType.getMemorySpace (),
107- newLayout);
108- }
109- return tdescType;
110- });
111-
232+ target.addLegalDialect <arith::ArithDialect, memref::MemRefDialect,
233+ vector::VectorDialect>();
112234 scf::populateSCFStructuralTypeConversionsAndLegality (converter, patterns,
113235 target);
114236 xegpu::populateXeGPUOptimizeTransposePatterns (patterns);
0 commit comments