@@ -79,9 +79,20 @@ getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
7979 return laneLayout;
8080}
8181
82+ static bool canBeOptimized (ArrayRef<int64_t > laneLayout,
83+ ArrayRef<int64_t > laneData) {
84+ if (laneLayout.size () != 2 || laneData.size () != 2 )
85+ return false ;
86+ if (laneLayout[0 ] == 1 || laneLayout[1 ] != 1 )
87+ return false ;
88+ if (laneData[0 ] != 1 || laneData[1 ] == 1 )
89+ return false ;
90+ return true ;
91+ }
92+
8293// A layout can be optimized if the lane layout is transposed (lane[0] != 1
8394// && lane[1] == 1) but the inner lane data is not equal to [1, 1].
84- static bool hasInvalidTranposeLayout (xegpu::TensorDescType tdescType) {
95+ static bool canBeOptimized (xegpu::TensorDescType tdescType) {
8596 // If the dtype is greater or equal to 32 bits, layout must be valid.
8697 int elementTyBitwidth = tdescType.getElementType ().getIntOrFloatBitWidth ();
8798 if (elementTyBitwidth >= 32 )
@@ -90,18 +101,12 @@ static bool hasInvalidTranposeLayout(xegpu::TensorDescType tdescType) {
90101 auto maybeLaneData = getMaybeLaneData (tdescType);
91102 if (!maybeLaneData || !maybeLaneLayout)
92103 return false ;
93- auto laneData = maybeLaneData.value ();
94- auto laneLayout = maybeLaneLayout.value ();
95- if (laneLayout[0 ] == 1 || laneLayout[1 ] != 1 )
96- return false ;
97- if (laneData[0 ] != 1 || laneData[1 ] == 1 )
98- return false ;
99- return true ;
104+ return canBeOptimized (*maybeLaneLayout, *maybeLaneData);
100105}
101106
102107static xegpu::TensorDescType
103108tryConvertToTransposable (xegpu::TensorDescType tdescType) {
104- if (!hasInvalidTranposeLayout (tdescType))
109+ if (!canBeOptimized (tdescType))
105110 return tdescType;
106111 auto laneData = getMaybeLaneData (tdescType).value ();
107112 int64_t innerLaneData = laneData[1 ];
@@ -185,11 +190,17 @@ static Value generateLoads(ConversionPatternRewriter &rewriter,
185190 origLoadOp.getPackedAttr (), origLoadOp.getTransposeAttr (),
186191 origLoadOp.getL1HintAttr (), origLoadOp.getL2HintAttr (),
187192 origLoadOp.getL3HintAttr ());
193+ // Set the layout for the loadOp.
194+ auto layoutAttr = newTensorDesc.getType ().getLayoutAttr ();
195+ xegpu::setDistributeLayoutAttr (loadOp->getOpResult (0 ), layoutAttr);
188196 // Insert the loaded block into the right position in data.
189- data = vector::InsertStridedSliceOp::create (
197+ auto insertOp = vector::InsertStridedSliceOp::create (
190198 rewriter, loc, loadOp.getResult (), data,
191199 ArrayRef<int64_t >{localOffsetX, localOffsetY},
192200 ArrayRef<int64_t >{1 , 1 });
201+ // InsertOp must have the same layout as newTensorDesc.
202+ xegpu::setDistributeLayoutAttr (insertOp->getOpResult (0 ), layoutAttr);
203+ data = insertOp.getResult ();
193204 }
194205 }
195206 return data;
@@ -288,28 +299,20 @@ class XeGPULoadNdDescOpPattern final
288299 // Shape ratio is 2D and, it describes how many blocks need to be loaded in
289300 // HW supported shape to cover the original shape.
290301 auto ratio = computeShapeRatio (origDataShape, hwSupportedShape)
291- .value (); // ratio must be defined if we reach here.
302+ .value (); // ` ratio` must be defined if we reach here.
292303 // Create a zero-initialized vector to hold all loaded blocks.
293304 // TypedAttr zeroAttr = rewriter.getZeroAttr(adaptorType.getElementType());
294305 VectorType origVectorType =
295306 VectorType::get (origDataShape, adaptorType.getElementType ());
296307 Value data;
297308 // Orig data shape is 3D for the array length case.
298309 if (origTensorDescType.getArrayLength () > 1 ) {
299- // SmallVector<int64_t> arrayLenDataShape(origDataShape);
300- // arrayLenDataShape.insert(arrayLenDataShape.begin(),
301- // origTensorDescType.getArrayLength());
302- // auto arrayLenVecType =
303- // VectorType::get(arrayLenDataShape, adaptorType.getElementType());
304- // auto = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
305- // arrayLenVecType,
306- // rewriter.getZeroAttr(arrayLenVecType));
307310 SmallVector<Value> arraySlices;
308311 for (int64_t i = 0 ; i < origTensorDescType.getArrayLength (); ++i) {
309312 Value slice = arith::ConstantOp::create (
310313 rewriter, loadNdOp->getLoc (), origVectorType,
311314 rewriter.getZeroAttr (origVectorType));
312- // Increse the Y offset for each array slice.
315+ // Increase the Y offset for each array slice.
313316 Value offsetY = convertToValue (rewriter, loadNdOp->getLoc (),
314317 modifiedOffsets.back ());
315318 modifiedOffsets.back () =
@@ -323,19 +326,16 @@ class XeGPULoadNdDescOpPattern final
323326 modifiedOffsets, hwSupportedShape,
324327 cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc ()),
325328 loadNdOp);
326- // // Insert slice to data.
327- // data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
328- // data, ArrayRef<int64_t>{i});
329- // Bitcast back to original load shape without array length.
329+ // BitCast back to original load shape without array length.
330330 auto bitcastType = VectorType::get (origTensorDescType.getShape (),
331331 origTensorDescType.getElementType ());
332- slice = vector::BitCastOp::create (rewriter, loadNdOp->getLoc (),
333- bitcastType, slice);
334- arraySlices.push_back (slice);
332+ auto bitCastOp = vector::BitCastOp::create (rewriter, loadNdOp->getLoc (),
333+ bitcastType, slice);
334+ // BitCastOp must have the same layout as the original loadNdOp.
335+ xegpu::setDistributeLayoutAttr (bitCastOp->getOpResult (0 ),
336+ origTensorDescType.getLayoutAttr ());
337+ arraySlices.push_back (bitCastOp.getResult ());
335338 }
336- // // Cast back to the original type and replace all uses.
337- // data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
338- // loadNdOp.getType(), data);
339339 rewriter.replaceOpWithMultiple (loadNdOp, {arraySlices});
340340 return success ();
341341 }
@@ -348,13 +348,12 @@ class XeGPULoadNdDescOpPattern final
348348 hwSupportedShape,
349349 cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc ()),
350350 loadNdOp);
351- auto castOp = vector::BitCastOp::create (rewriter, loadNdOp->getLoc (),
352- loadNdOp.getType (), data);
353- // // Cast op must have the same layout as the original LoadNdOp result.
354- // xegpu::setDistributeLayoutAttr(
355- // castOp->getOpResult(0),
356- // xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
357- rewriter.replaceOp (loadNdOp, castOp);
351+ auto bitCastOp = vector::BitCastOp::create (rewriter, loadNdOp->getLoc (),
352+ loadNdOp.getType (), data);
353+ // BitCastOp must have the same layout as the original loadNdOp.
354+ xegpu::setDistributeLayoutAttr (bitCastOp->getOpResult (0 ),
355+ origTensorDescType.getLayoutAttr ());
356+ rewriter.replaceOp (loadNdOp, bitCastOp);
358357 return success ();
359358 }
360359};
@@ -402,15 +401,20 @@ struct XeGPUOptimizeTransposePass final
402401 // converted.
403402 target.addDynamicallyLegalOp <xegpu::CreateNdDescOp>(
404403 [&](xegpu::CreateNdDescOp createNdOp) {
405- return !hasInvalidTranposeLayout (createNdOp.getType ());
404+ return !canBeOptimized (createNdOp.getType ());
406405 });
407406 target.addDynamicallyLegalOp <xegpu::LoadNdOp>(
408407 [&](xegpu::LoadNdOp loadNdOp) {
409- return !hasInvalidTranposeLayout (loadNdOp.getTensorDescType ());
408+ return !canBeOptimized (loadNdOp.getTensorDescType ());
410409 });
411410 target.addDynamicallyLegalOp <vector::ExtractOp>(
412411 [&](vector::ExtractOp extractOp) {
413- return extractOp.getSourceVectorType ().getRank () != 3 ;
412+ auto layout = xegpu::getDistributeLayoutAttr (extractOp.getResult ());
413+ if (!layout)
414+ return true ;
415+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt ();
416+ auto laneData = layout.getEffectiveLaneDataAsInt ();
417+ return !canBeOptimized (laneLayout, laneData);
414418 });
415419 converter.addConversion ([](Type type) { return type; });
416420
0 commit comments