@@ -395,15 +395,32 @@ struct LinearizeVectorShuffle final
395395 }
396396};
397397
398- // / This pattern converts the ExtractOp to a ShuffleOp that works on a
399- // / linearized vector.
400- // / Following,
401- // / vector.extract %source [ position ]
402- // / is converted to :
403- // / %source_1d = vector.shape_cast %source
404- // / %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
405- // / %out_nd = vector.shape_cast %out_1d
406- // / `shuffle_indices_1d` is computed using the position of the original extract.
398+ // / This pattern linearizes `vector.extract` operations. It generates a 1-D
399+ // / version of the `vector.extract` operation when extracting a scalar from a
400+ // / vector. It generates a 1-D `vector.shuffle` operation when extracting a
401+ // / subvector from a larger vector.
402+ // /
403+ // / Example #1:
404+ // /
405+ // / %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
406+ // /
407+ // / is converted to:
408+ // /
409+ // / %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
410+ // / %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
411+ // / 24, 25, 26, 27, 28, 29, 30, 31] :
412+ // / vector<32xf32>, vector<32xf32>
413+ // / %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
414+ // /
415+ // / Example #2:
416+ // /
417+ // / %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
418+ // /
419+ // / is converted to:
420+ // /
421+ // / %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
422+ // / %1 = vector.extract %0[6] : i32 from vector<8xi32>
423+ // /
407424struct LinearizeVectorExtract final
408425 : public OpConversionPattern<vector::ExtractOp> {
409426 using OpConversionPattern::OpConversionPattern;
@@ -413,10 +430,6 @@ struct LinearizeVectorExtract final
413430 LogicalResult
414431 matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
415432 ConversionPatternRewriter &rewriter) const override {
416- // Skip if result is not a vector type
417- if (!isa<VectorType>(extractOp.getType ()))
418- return rewriter.notifyMatchFailure (extractOp,
419- " scalar extract not supported" );
420433 Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
421434 assert (dstTy && " expected 1-D vector type" );
422435
@@ -436,6 +449,16 @@ struct LinearizeVectorExtract final
436449 linearizedOffset += offsets[i] * size;
437450 }
438451
452+ if (!isa<VectorType>(extractOp.getType ())) {
453+ // Scalar case: generate a 1-D extract.
454+ Value result = rewriter.createOrFold <vector::ExtractOp>(
455+ extractOp.getLoc (), adaptor.getVector (), linearizedOffset);
456+ rewriter.replaceOp (extractOp, result);
457+ return success ();
458+ }
459+
460+ // Vector case: generate a shuffle.
461+
439462 llvm::SmallVector<int64_t , 2 > indices (size);
440463 std::iota (indices.begin (), indices.end (), linearizedOffset);
441464 rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
0 commit comments