99#include " mlir/Dialect/Arith/IR/Arith.h"
1010#include " mlir/Dialect/MemRef/IR/MemRef.h"
1111#include " mlir/Dialect/SCF/Transforms/Patterns.h"
12+ #include " mlir/Dialect/Utils/IndexingUtils.h"
1213#include " mlir/Dialect/Utils/StaticValueUtils.h"
1314#include " mlir/Dialect/Vector/IR/VectorOps.h"
1415#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
1516#include " mlir/Dialect/XeGPU/Transforms/Passes.h"
1617#include " mlir/Dialect/XeGPU/Transforms/Transforms.h"
1718#include " mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
19+ #include " mlir/IR/BuiltinAttributeInterfaces.h"
20+ #include " mlir/IR/BuiltinTypes.h"
1821#include " mlir/IR/OpDefinition.h"
1922#include " mlir/IR/Types.h"
2023#include " mlir/IR/Value.h"
2124#include " mlir/Support/LLVM.h"
2225#include " mlir/Transforms/DialectConversion.h"
2326#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
27+ #include " llvm/ADT/SmallVector.h"
28+ #include < algorithm>
2429#include < optional>
2530
2631namespace mlir {
@@ -35,12 +40,14 @@ namespace xegpu {
3540
3641using namespace mlir ;
3742
38- struct TransposableBlockRange {
39- int minWidth, maxWidth, minHeight, maxHeight;
43+ namespace {
44+
/// Inclusive range of 2D block shapes (in elements) that the hardware can
/// handle for a transposed block load: width in [minWidth, maxWidth] and
/// height in [minHeight, maxHeight].
struct Allowed2DShapeRange {
  int64_t minWidth, maxWidth, minHeight, maxHeight;
};
4148
4249// TODO: Use uArch to get supported block ranges.
43- static TransposableBlockRange getBlockRange (int bitWidth) {
50+ static Allowed2DShapeRange getTransposableBlockRange (int bitWidth) {
4451 switch (bitWidth) {
4552 case 32 :
4653 return {/* *min width**/ 1 , /* *max width**/ 8 , /* *min height**/ 1 ,
@@ -50,10 +57,8 @@ static TransposableBlockRange getBlockRange(int bitWidth) {
5057 }
5158}
5259
53- namespace {
54-
5560static std::optional<SmallVector<int64_t >>
56- get2DLaneData (xegpu::TensorDescType tdescType) {
61+ getMaybeLaneData (xegpu::TensorDescType tdescType) {
5762 auto layout = tdescType.getLayoutAttr ();
5863 if (!layout)
5964 return std::nullopt ;
@@ -63,44 +68,131 @@ get2DLaneData(xegpu::TensorDescType tdescType) {
6368 return laneData;
6469}
6570
71+ static std::optional<SmallVector<int64_t >>
72+ getMaybeLaneLayout (xegpu::TensorDescType tdescType) {
73+ auto layout = tdescType.getLayoutAttr ();
74+ if (!layout)
75+ return std::nullopt ;
76+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt ();
77+ if (laneLayout.size () != 2 )
78+ return std::nullopt ;
79+ return laneLayout;
80+ }
81+
82+ // A transpose layout is invalid if lane layout is transposed (lane[0] != 1 &&
83+ // lane[1] == 1), but inner lane data is not equal to [1, 1].
84+ static bool hasInvalidTranposeLayout (xegpu::TensorDescType tdescType) {
85+ // If the dtype is greater or equal to 32 bits, layout must be valid.
86+ int elementTyBitwidth = tdescType.getElementType ().getIntOrFloatBitWidth ();
87+ if (elementTyBitwidth >= 32 )
88+ return false ;
89+ auto maybeLaneLayout = getMaybeLaneLayout (tdescType);
90+ auto maybeLaneData = getMaybeLaneData (tdescType);
91+ if (!maybeLaneData || !maybeLaneLayout)
92+ return false ;
93+ auto laneData = maybeLaneData.value ();
94+ auto laneLayout = maybeLaneLayout.value ();
95+ if (laneLayout[0 ] == 1 || laneLayout[1 ] != 1 )
96+ return false ;
97+ if (laneData[0 ] != 1 || laneData[1 ] == 1 )
98+ return false ;
99+ return true ;
100+ }
101+
66102static xegpu::TensorDescType
67- getModifiedTensorDescType (xegpu::TensorDescType tdescType) {
68- auto optionalLaneData = get2DLaneData (tdescType);
69- if (!optionalLaneData)
103+ tryConvertToTransposable (xegpu::TensorDescType tdescType) {
104+ if (!hasInvalidTranposeLayout (tdescType))
70105 return tdescType;
71- auto laneData = optionalLaneData .value ();
106+ auto laneData = getMaybeLaneData (tdescType) .value ();
72107 int64_t innerLaneData = laneData[1 ];
73- if (laneData[0 ] == 1 && innerLaneData != 1 ) {
74- int elementTyBitwidth = tdescType.getElementType ().getIntOrFloatBitWidth ();
75- assert (elementTyBitwidth < 32 &&
76- " Expected element type bitwidth < 32 with laneData[1] != 1" );
77- SmallVector<int64_t > newShape (tdescType.getShape ());
78- newShape.back () = newShape.back () / innerLaneData;
79- Type newElemTy = IntegerType::get (tdescType.getContext (),
80- elementTyBitwidth * innerLaneData);
81- xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get (
82- tdescType.getContext (),
83- tdescType.getLayoutAttr ().getLaneLayout ().asArrayRef (), {1 , 1 });
84- return xegpu::TensorDescType::get (
85- newShape, newElemTy, tdescType.getArrayLength (),
86- tdescType.getBoundaryCheck (), tdescType.getMemorySpace (), newLayout);
87- }
88- return tdescType;
108+ int elementTyBitwidth = tdescType.getElementType ().getIntOrFloatBitWidth ();
109+ // Required shape is total shape of the vector result that this tensor desc
110+ // must eventually load after adjusting for the new bitwidth and array
111+ // length.
112+ SmallVector<int64_t > requiredShape (tdescType.getShape ());
113+ requiredShape.back () =
114+ requiredShape.back () * tdescType.getArrayLength () / innerLaneData;
115+ int newBitWidth = elementTyBitwidth * innerLaneData;
116+ Type newElemTy = IntegerType::get (tdescType.getContext (), newBitWidth);
117+ // Supported shape is the max transpose shape that can be supported by
118+ // hardware that is less than or equal to required shape.
119+ auto supportedHeight = std::min (
120+ requiredShape[0 ], getTransposableBlockRange (newBitWidth).maxHeight );
121+ auto supportedWidth = std::min (
122+ requiredShape[1 ], getTransposableBlockRange (newBitWidth).maxWidth );
123+ SmallVector<int64_t > supportedShape = {supportedHeight, supportedWidth};
124+
125+ // Required shape must be multiple of supported shape. Otherwise, we can not
126+ // optimize it.
127+ // TODO: Supported shape can be adjusted to handle non-multiple cases.
128+ if (requiredShape[0 ] % supportedShape[0 ] != 0 ||
129+ requiredShape[1 ] % supportedShape[1 ] != 0 )
130+ return tdescType;
131+
132+ xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get (
133+ tdescType.getContext (),
134+ tdescType.getLayoutAttr ().getLaneLayout ().asArrayRef (), {1 , 1 });
135+ // Array length can not be larger than 1 for transpose case.
136+ return xegpu::TensorDescType::get (
137+ supportedShape, newElemTy, /* *array length**/ 1 ,
138+ tdescType.getBoundaryCheck (), tdescType.getMemorySpace (), newLayout);
139+ }
140+
141+ static Value createConstantIndex (ConversionPatternRewriter &rewriter,
142+ Location loc, int64_t value) {
143+ return arith::ConstantIndexOp::create (rewriter, loc, value).getResult ();
89144}
90145
91146static Value convertToValue (ConversionPatternRewriter &rewriter, Location loc,
92147 OpFoldResult ofr) {
93148 std::optional<int64_t > mayBeInt = getConstantIntValue (ofr);
94149 if (mayBeInt)
95- return arith::ConstantIndexOp::create (rewriter, loc, *mayBeInt);
150+ return createConstantIndex (rewriter, loc, *mayBeInt);
96151 return llvm::cast<Value>(ofr);
97152}
98153
99154static Value divideByConstant (ConversionPatternRewriter &rewriter, Location loc,
100155 Value val, int64_t constant) {
101- auto constantOp = arith::ConstantIndexOp::create (rewriter, loc, constant);
102- return arith::DivUIOp::create (rewriter, loc, val, constantOp.getResult ())
103- .getResult ();
156+ auto constantOp = createConstantIndex (rewriter, loc, constant);
157+ return arith::DivUIOp::create (rewriter, loc, val, constantOp).getResult ();
158+ }
159+
160+ static Value generateLoads (ConversionPatternRewriter &rewriter,
161+ TypedValue<VectorType> data,
162+ SmallVector<int64_t > &shapeRatio,
163+ SmallVector<OpFoldResult> offsets,
164+ SmallVector<int64_t > &supportedShape,
165+ TypedValue<xegpu::TensorDescType> newTensorDesc,
166+ xegpu::LoadNdOp origLoadOp) {
167+ Location loc = data.getLoc ();
168+ assert (offsets.size () >= 2 && " Expecting at least 2 offsets for 2D LoadNdOp" );
169+ Value offsetX = convertToValue (rewriter, loc, offsets[offsets.size () - 2 ]);
170+ Value offsetY = convertToValue (rewriter, loc, offsets[offsets.size () - 1 ]);
171+ for (int64_t h = 0 ; h < shapeRatio[0 ]; ++h) {
172+ for (int64_t w = 0 ; w < shapeRatio[1 ]; ++w) {
173+ int64_t localOffsetX = h * supportedShape[0 ];
174+ int64_t localOffsetY = w * supportedShape[1 ];
175+ Value loadOffsetX = arith::AddIOp::create (
176+ rewriter, loc, offsetX,
177+ createConstantIndex (rewriter, loc, localOffsetX));
178+ Value loadOffsetY = arith::AddIOp::create (
179+ rewriter, loc, offsetY,
180+ createConstantIndex (rewriter, loc, localOffsetY));
181+ auto loadOp = xegpu::LoadNdOp::create (
182+ rewriter, loc,
183+ VectorType::get (supportedShape, data.getType ().getElementType ()),
184+ newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
185+ origLoadOp.getPackedAttr (), origLoadOp.getTransposeAttr (),
186+ origLoadOp.getL1HintAttr (), origLoadOp.getL2HintAttr (),
187+ origLoadOp.getL3HintAttr ());
188+ // Insert the loaded block into the right position in data.
189+ data = vector::InsertStridedSliceOp::create (
190+ rewriter, loc, loadOp.getResult (), data,
191+ ArrayRef<int64_t >{localOffsetX, localOffsetY},
192+ ArrayRef<int64_t >{1 , 1 });
193+ }
194+ }
195+ return data;
104196}
105197
106198class XeGPUCreateNdDescOpPattern final
@@ -111,7 +203,7 @@ class XeGPUCreateNdDescOpPattern final
111203 matchAndRewrite (xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
112204 ConversionPatternRewriter &rewriter) const override {
113205 auto tdescTy = createNdOp.getType ();
114- auto convertType = getModifiedTensorDescType (tdescTy);
206+ auto convertType = tryConvertToTransposable (tdescTy);
115207 if (convertType == tdescTy)
116208 return failure ();
117209 auto strides = createNdOp.getMixedStrides ();
@@ -120,7 +212,7 @@ class XeGPUCreateNdDescOpPattern final
120212 if (!maybeConstInnerStride || *maybeConstInnerStride != 1 )
121213 return failure ();
122214 Value source = createNdOp.getSource ();
123- auto optionalLaneData = get2DLaneData (tdescTy);
215+ auto optionalLaneData = getMaybeLaneData (tdescTy);
124216 assert (optionalLaneData && " Expected 2D lane data" );
125217 auto laneData = optionalLaneData.value ();
126218 int64_t innerLaneData = laneData[1 ];
@@ -160,7 +252,6 @@ class XeGPUCreateNdDescOpPattern final
160252 return success ();
161253 }
162254};
163- } // namespace
164255
165256class XeGPULoadNdDescOpPattern final
166257 : public OpConversionPattern<xegpu::LoadNdOp> {
@@ -175,9 +266,8 @@ class XeGPULoadNdDescOpPattern final
175266 if (adaptorType == origTensorDescType)
176267 return failure ();
177268 // Offsets must be adjusted based on innerLaneData.
178- auto optionalLaneData = get2DLaneData (loadNdOp.getTensorDescType ());
179- assert (optionalLaneData && " Expected 2D lane data" );
180- int64_t innerLaneData = optionalLaneData.value ()[1 ];
269+ auto laneData = getMaybeLaneData (loadNdOp.getTensorDescType ()).value ();
270+ int64_t innerLaneData = laneData[1 ];
181271 auto offsets = loadNdOp.getMixedOffsets ();
182272 if (offsets.empty ())
183273 return rewriter.notifyMatchFailure (loadNdOp,
@@ -187,25 +277,82 @@ class XeGPULoadNdDescOpPattern final
187277 rewriter, loadNdOp.getLoc (),
188278 convertToValue (rewriter, loadNdOp.getLoc (), modifiedOffsets.back ()),
189279 innerLaneData);
190- VectorType modifiedType =
191- VectorType::get (adaptorType.getShape (), adaptorType.getElementType ());
192- // Create a new LoadNdOp with modified offsets and type.
193- auto newLoadNdOp = xegpu::LoadNdOp::create (
194- rewriter, loadNdOp->getLoc (), modifiedType, adaptor.getTensorDesc (),
195- modifiedOffsets, loadNdOp.getPackedAttr (), loadNdOp.getTransposeAttr (),
196- loadNdOp.getL1HintAttr (), loadNdOp.getL2HintAttr (),
197- loadNdOp.getL3HintAttr ());
198- // Bitcast back to the original type.
280+ // Get the 2D data shape of this loadNdOp in its original type including
281+ // array length.
282+ SmallVector<int64_t > origDataShape (origTensorDescType.getShape ());
283+ // origDataShape.back() *= origTensorDescType.getArrayLength();
284+ // Adjust the data shape based on innerLaneData.
285+ origDataShape.back () /= innerLaneData;
286+ // HW supported shape is the new tensor desc shape after conversion.
287+ SmallVector<int64_t > hwSupportedShape (adaptorType.getShape ());
288+ // Shape ratio is 2D and, it describes how many blocks need to be loaded in
289+ // HW supported shape to cover the original shape.
290+ auto ratio = computeShapeRatio (origDataShape, hwSupportedShape)
291+ .value (); // ratio must be defined if we reach here.
292+ // Create a zero-initialized vector to hold all loaded blocks.
293+ // TypedAttr zeroAttr = rewriter.getZeroAttr(adaptorType.getElementType());
294+ VectorType origVectorType =
295+ VectorType::get (origDataShape, adaptorType.getElementType ());
296+ Value data;
297+ // Orig data shape is 3D for the array length case.
298+ if (origTensorDescType.getArrayLength () > 1 ) {
299+ SmallVector<int64_t > arrayLenDataShape (origDataShape);
300+ arrayLenDataShape.insert (arrayLenDataShape.begin (),
301+ origTensorDescType.getArrayLength ());
302+ auto arrayLenVecType =
303+ VectorType::get (arrayLenDataShape, adaptorType.getElementType ());
304+ data = arith::ConstantOp::create (rewriter, loadNdOp->getLoc (),
305+ arrayLenVecType,
306+ rewriter.getZeroAttr (arrayLenVecType));
307+ for (int64_t i = 0 ; i < origTensorDescType.getArrayLength (); ++i) {
308+ Value slice = arith::ConstantOp::create (
309+ rewriter, loadNdOp->getLoc (),
310+ VectorType::get (origDataShape, adaptorType.getElementType ()),
311+ rewriter.getZeroAttr (origVectorType));
312+ // Increse the Y offset for each array slice.
313+ Value offsetY = convertToValue (rewriter, loadNdOp->getLoc (),
314+ modifiedOffsets.back ());
315+ modifiedOffsets.back () =
316+ arith::AddIOp::create (rewriter, loadNdOp->getLoc (), offsetY,
317+ createConstantIndex (rewriter,
318+ loadNdOp->getLoc (),
319+ i * origDataShape[1 ]))
320+ .getResult ();
321+ slice = generateLoads (
322+ rewriter, cast<TypedValue<VectorType>>(slice), ratio,
323+ modifiedOffsets, hwSupportedShape,
324+ cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc ()),
325+ loadNdOp);
326+ // Insert slice to data.
327+ data = vector::InsertOp::create (rewriter, loadNdOp->getLoc (), slice,
328+ data, ArrayRef<int64_t >{i});
329+ }
330+ // Cast back to the original type and replace all uses.
331+ data = vector::BitCastOp::create (rewriter, loadNdOp->getLoc (),
332+ loadNdOp.getType (), data);
333+ rewriter.replaceOp (loadNdOp, data);
334+ return success ();
335+ }
336+ data = arith::ConstantOp::create (
337+ rewriter, loadNdOp->getLoc (),
338+ VectorType::get (origDataShape, adaptorType.getElementType ()),
339+ rewriter.getZeroAttr (origVectorType));
340+ data = generateLoads (
341+ rewriter, cast<TypedValue<VectorType>>(data), ratio, modifiedOffsets,
342+ hwSupportedShape,
343+ cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc ()),
344+ loadNdOp);
199345 auto castOp = vector::BitCastOp::create (rewriter, loadNdOp->getLoc (),
200- loadNdOp.getType (), newLoadNdOp );
201- // Cast op must have the same layout as the original LoadNdOp result.
202- xegpu::setDistributeLayoutAttr (
203- castOp->getOpResult (0 ),
204- xegpu::getDistributeLayoutAttr (loadNdOp.getResult ()));
205- rewriter.replaceOp (loadNdOp, castOp. getResult () );
346+ loadNdOp.getType (), data );
347+ // // Cast op must have the same layout as the original LoadNdOp result.
348+ // xegpu::setDistributeLayoutAttr(
349+ // castOp->getOpResult(0),
350+ // xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
351+ rewriter.replaceOp (loadNdOp, castOp);
206352 return success ();
207353 }
208354};
355+ } // namespace
209356
210357void xegpu::populateXeGPUOptimizeTransposePatterns (
211358 RewritePatternSet &patterns) {
@@ -224,23 +371,15 @@ struct XeGPUOptimizeTransposePass final
224371 RewritePatternSet patterns (&context);
225372 ConversionTarget target (context);
226373
227- auto checkValidInnerLaneData =
228- [](std::optional<SmallVector<int64_t >> optionalLaneData) -> bool {
229- if (!optionalLaneData)
230- return true ;
231- auto laneData = optionalLaneData.value ();
232- return laneData[0 ] != 1 || laneData[1 ] == 1 ;
233- };
234-
374+ // CreateNdDescOp and LoadNdOp with invalid transpose layout must be
375+ // converted.
235376 target.addDynamicallyLegalOp <xegpu::CreateNdDescOp>(
236377 [&](xegpu::CreateNdDescOp createNdOp) {
237- auto optionalLaneData = get2DLaneData (createNdOp.getType ());
238- return checkValidInnerLaneData (optionalLaneData);
378+ return !hasInvalidTranposeLayout (createNdOp.getType ());
239379 });
240380 target.addDynamicallyLegalOp <xegpu::LoadNdOp>(
241381 [&](xegpu::LoadNdOp loadNdOp) {
242- auto optionalLaneData = get2DLaneData (loadNdOp.getTensorDescType ());
243- return checkValidInnerLaneData (optionalLaneData);
382+ return !hasInvalidTranposeLayout (loadNdOp.getTensorDescType ());
244383 });
245384 converter.addConversion ([](Type type) { return type; });
246385
0 commit comments