@@ -296,18 +296,18 @@ class XeGPULoadNdDescOpPattern final
296296 Value data;
297297 // Orig data shape is 3D for the array length case.
298298 if (origTensorDescType.getArrayLength () > 1 ) {
299- SmallVector<int64_t > arrayLenDataShape (origDataShape);
300- arrayLenDataShape.insert (arrayLenDataShape.begin (),
301- origTensorDescType.getArrayLength ());
302- auto arrayLenVecType =
303- VectorType::get (arrayLenDataShape, adaptorType.getElementType ());
304- data = arith::ConstantOp::create (rewriter, loadNdOp->getLoc (),
305- arrayLenVecType,
306- rewriter.getZeroAttr (arrayLenVecType));
299+ // SmallVector<int64_t> arrayLenDataShape(origDataShape);
300+ // arrayLenDataShape.insert(arrayLenDataShape.begin(),
301+ // origTensorDescType.getArrayLength());
302+ // auto arrayLenVecType =
303+ // VectorType::get(arrayLenDataShape, adaptorType.getElementType());
304+ // auto = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
305+ // arrayLenVecType,
306+ // rewriter.getZeroAttr(arrayLenVecType));
307+ SmallVector<Value> arraySlices;
307308 for (int64_t i = 0 ; i < origTensorDescType.getArrayLength (); ++i) {
308309 Value slice = arith::ConstantOp::create (
309- rewriter, loadNdOp->getLoc (),
310- VectorType::get (origDataShape, adaptorType.getElementType ()),
310+ rewriter, loadNdOp->getLoc (), origVectorType,
311311 rewriter.getZeroAttr (origVectorType));
312312 // Increase the Y offset for each array slice.
313313 Value offsetY = convertToValue (rewriter, loadNdOp->getLoc (),
@@ -323,14 +323,20 @@ class XeGPULoadNdDescOpPattern final
323323 modifiedOffsets, hwSupportedShape,
324324 cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc ()),
325325 loadNdOp);
326- // Insert slice to data.
327- data = vector::InsertOp::create (rewriter, loadNdOp->getLoc (), slice,
328- data, ArrayRef<int64_t >{i});
326+ // // Insert slice to data.
327+ // data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
328+ // data, ArrayRef<int64_t>{i});
329+ // Bitcast back to original load shape without array length.
330+ auto bitcastType = VectorType::get (origTensorDescType.getShape (),
331+ origTensorDescType.getElementType ());
332+ slice = vector::BitCastOp::create (rewriter, loadNdOp->getLoc (),
333+ bitcastType, slice);
334+ arraySlices.push_back (slice);
329335 }
330- // Cast back to the original type and replace all uses.
331- data = vector::BitCastOp::create (rewriter, loadNdOp->getLoc (),
332- loadNdOp.getType (), data);
333- rewriter.replaceOp (loadNdOp, data );
336+ // // Cast back to the original type and replace all uses.
337+ // data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
338+ // loadNdOp.getType(), data);
339+ rewriter.replaceOpWithMultiple (loadNdOp, {arraySlices} );
334340 return success ();
335341 }
336342 data = arith::ConstantOp::create (
@@ -352,12 +358,33 @@ class XeGPULoadNdDescOpPattern final
352358 return success ();
353359 }
354360};
361+
362+ class VectorExtractOpPattern final
363+ : public OpConversionPattern<vector::ExtractOp> {
364+ public:
365+ using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
366+ LogicalResult
367+ matchAndRewrite (vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
368+ ConversionPatternRewriter &rewriter) const override {
369+ if (adaptor.getSource ().size () == 1 )
370+ return failure ();
371+ auto mixedPos = extractOp.getMixedPosition ();
372+ if (mixedPos.size () != 1 )
373+ return failure ();
374+ auto mayBeInt = getConstantIntValue (mixedPos[0 ]);
375+ if (!mayBeInt)
376+ return failure ();
377+ rewriter.replaceOp (extractOp, adaptor.getSource ()[*mayBeInt]);
378+ return success ();
379+ }
380+ };
381+
355382} // namespace
356383
/// Adds the optimize-transpose conversion patterns to `patterns`, constructing
/// each with the pattern set's own context. The added (`+`) lines of this hunk
/// register VectorExtractOpPattern in addition to the previously registered
/// XeGPUCreateNdDescOpPattern and XeGPULoadNdDescOpPattern.
357384void xegpu::populateXeGPUOptimizeTransposePatterns (
358385 RewritePatternSet &patterns) {
359- patterns.add <XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern>(
360- patterns.getContext ());
386+ patterns.add <XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
387+ VectorExtractOpPattern>( patterns.getContext ());
361388}
362389
363390namespace {
@@ -381,6 +408,10 @@ struct XeGPUOptimizeTransposePass final
381408 [&](xegpu::LoadNdOp loadNdOp) {
382409 return !hasInvalidTranposeLayout (loadNdOp.getTensorDescType ());
383410 });
411+ target.addDynamicallyLegalOp <vector::ExtractOp>(
412+ [&](vector::ExtractOp extractOp) {
413+ return extractOp.getSourceVectorType ().getRank () != 3 ;
414+ });
384415 converter.addConversion ([](Type type) { return type; });
385416
386417 target.addLegalDialect <arith::ArithDialect, memref::MemRefDialect,
0 commit comments