diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 6e79085afac9f..39097368b1e71 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2408,6 +2408,7 @@ def Vector_CompressStoreOp : def Vector_ShapeCastOp : Vector_Op<"shape_cast", [Pure, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, Arguments<(ins AnyVectorOfAnyRank:$source)>, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 58256b0ade9f6..dff66a6e829a9 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6233,6 +6233,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef argRanges, setResultRanges(getResult(), argRanges.front()); } +std::optional> ShapeCastOp::getShapeForUnroll() { + return llvm::to_vector<4>(getResultVectorType().getShape()); +} + LogicalResult ShapeCastOp::verify() { VectorType sourceType = getSourceVectorType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fbae0989bed26..f3d659218b0a8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -75,6 +75,45 @@ static SmallVector sliceLoadStoreIndices(PatternRewriter &rewriter, return indices; } +/// Creates a result tile by extracting individual elements from the source +/// and inserting them at the correct positions in the tile. +static Value createTileFromElements(PatternRewriter &rewriter, Location loc, + Value source, ArrayRef sourceShape, + ArrayRef resultShape, + ArrayRef tileOffsets, + ArrayRef tileShape, + VectorType tileType) { + // Initialize tile with zeros. + Value tile = arith::ConstantOp::create(rewriter, loc, tileType, + rewriter.getZeroAttr(tileType)); + + // Calculate strides for source, result, and tile shapes. + SmallVector sourceStrides = computeStrides(sourceShape); + SmallVector resultStrides = computeStrides(resultShape); + SmallVector tileStrides = computeStrides(tileShape); + int64_t numElementsInTile = computeProduct(tileShape); + + // Iterate over all positions in the tile using linear indexing. + for (int64_t linearTileIdx = 0; linearTileIdx < numElementsInTile; + ++linearTileIdx) { + // Convert linear tile index to multi-dimensional tile position. + SmallVector tilePosition = delinearize(linearTileIdx, tileStrides); + + // Calculate the global position in the result. + SmallVector globalResultPos; + globalResultPos.reserve(tileOffsets.size()); + for (auto [offset, pos] : llvm::zip_equal(tileOffsets, tilePosition)) { + globalResultPos.push_back(offset + pos); + } + + int64_t linearIndex = linearize(globalResultPos, resultStrides); + SmallVector sourcePos = delinearize(linearIndex, sourceStrides); + Value element = vector::ExtractOp::create(rewriter, loc, source, sourcePos); + tile = vector::InsertOp::create(rewriter, loc, element, tile, tilePosition); + } + return tile; +} + // Clones `op` into a new operations that takes `operands` and returns // `resultTypes`. static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, @@ -1003,6 +1042,85 @@ struct UnrollFromElements : OpRewritePattern { vector::UnrollVectorOptions options; }; +/// This pattern unrolls `vector.shape_cast` operations according to the +/// provided target unroll shape. It decomposes a large shape_cast operation +/// into smaller tiles and reconstructs each tile by extracting individual +/// elements from the source vector and placing them at the correct positions. +/// +/// Since shape_cast performs linear element reindexing, the pattern uses +/// linear indexing as a bridge to map between source and result coordinates. +/// For each element in a result tile, it calculates the corresponding source +/// position and extracts that element. +/// +/// Example: +/// Given a shape_cast operation: +/// %0 = vector.shape_cast %src : vector<2x8xf32> to vector<4x4xf32> +/// +/// and a target unroll shape of <2x2>, the pattern produces: +/// +/// %zero = arith.constant dense<0.0> : vector<4x4xf32> +/// %tile_zero = arith.constant dense<0.0> : vector<2x2xf32> +/// +/// // First tile [0,0]: elements at result positions +/// (0,0),(0,1),(1,0),(1,1) +/// %e0 = vector.extract %src[0, 0] : f32 from vector<2x8xf32> +/// %t0 = vector.insert %e0, %tile_zero [0, 0] : f32 into vector<2x2xf32> +/// %e1 = vector.extract %src[0, 1] : f32 from vector<2x8xf32> +/// %t1 = vector.insert %e1, %t0 [0, 1] : f32 into vector<2x2xf32> +/// %e2 = vector.extract %src[0, 4] : f32 from vector<2x8xf32> +/// %t2 = vector.insert %e2, %t1 [1, 0] : f32 into vector<2x2xf32> +/// %e3 = vector.extract %src[0, 5] : f32 from vector<2x8xf32> +/// %t3 = vector.insert %e3, %t2 [1, 1] : f32 into vector<2x2xf32> +/// %r0 = vector.insert_strided_slice %t3, %zero +/// {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into +/// vector<4x4xf32> +/// +struct UnrollShapeCastPattern : public OpRewritePattern { + UnrollShapeCastPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, + PatternRewriter &rewriter) const override { + auto targetShape = getTargetShape(options, shapeCastOp); + if (!targetShape) + return failure(); + + Location loc = shapeCastOp.getLoc(); + VectorType sourceType = shapeCastOp.getSourceVectorType(); + VectorType resultType = shapeCastOp.getResultVectorType(); + + ArrayRef resultShape = resultType.getShape(); + ArrayRef sourceShape = sourceType.getShape(); + + SmallVector strides(targetShape->size(), 1); + Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + + // For each unrolled tile in the result. + for (SmallVector tileOffsets : + StaticTileOffsetRange(resultShape, *targetShape)) { + // Create the target tile type. + auto tileType = + VectorType::get(*targetShape, resultType.getElementType()); + // Build the tile by extracting individual elements. + Value tile = createTileFromElements( + rewriter, loc, shapeCastOp.getSource(), sourceShape, resultShape, + tileOffsets, *targetShape, tileType); + // Insert the tile into the result. + result = rewriter.createOrFold( + loc, tile, result, tileOffsets, strides); + } + rewriter.replaceOp(shapeCastOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ -1013,8 +1131,8 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, - UnrollToElements, UnrollStepPattern>(patterns.getContext(), - options, benefit); + UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>( + patterns.getContext(), options, benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index e5a98b5c67f33..a238cc85fc2f3 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -496,3 +496,61 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3 // CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32> // CHECK-NOT: arith.addf // CHECK: return + +func.func @shape_cast_1D_to_2D(%v: vector<8xf32>) -> vector<4x2xf32> { + %0 = vector.shape_cast %v : vector<8xf32> to vector<4x2xf32> + return %0 : vector<4x2xf32> +} + +// CHECK-LABEL: func @shape_cast_1D_to_2D +// CHECK-SAME: (%[[V:.*]]: vector<8xf32>) -> vector<4x2xf32> +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x2xf32> +// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> +// CHECK: %[[E0:.*]] = vector.extract %[[V]][0] : f32 from vector<8xf32> +// CHECK: %[[INS0:.*]] = vector.insert %[[E0]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E1:.*]] = vector.extract %[[V]][1] : f32 from vector<8xf32> +// CHECK: %[[INS1:.*]] = vector.insert %[[E1]], %[[INS0]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[E2:.*]] = vector.extract %[[V]][2] : f32 from vector<8xf32> +// CHECK: %[[INS2:.*]] = vector.insert %[[E2]], %[[INS1]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E3:.*]] = vector.extract %[[V]][3] : f32 from vector<8xf32> +// CHECK: %[[V0:.*]] = vector.insert %[[E3]], %[[INS2]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32> +// CHECK: %[[E4:.*]] = vector.extract %[[V]][4] : f32 from vector<8xf32> +// CHECK: %[[INS3:.*]] = vector.insert %[[E4]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E5:.*]] = vector.extract %[[V]][5] : f32 from vector<8xf32> +// CHECK: %[[INS4:.*]] = vector.insert %[[E5]], %[[INS3]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[E6:.*]] = vector.extract %[[V]][6] : f32 from vector<8xf32> +// CHECK: %[[INS5:.*]] = vector.insert %[[E6]], %[[INS4]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E7:.*]] = vector.extract %[[V]][7] : f32 from vector<8xf32> +// CHECK: %[[V1:.*]] = vector.insert %[[E7]], %[[INS5]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[V1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32> +// CHECK: return %[[I1]] : vector<4x2xf32> + +func.func @shape_cast_2D(%v: vector<2x4xf32>) -> vector<4x2xf32> { + %0 = vector.shape_cast %v : vector<2x4xf32> to vector<4x2xf32> + return %0 : vector<4x2xf32> +} + +// CHECK-LABEL: func @shape_cast_2D +// CHECK-SAME: (%[[V:.*]]: vector<2x4xf32>) -> vector<4x2xf32> +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x2xf32> +// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> +// CHECK: %[[E0:.*]] = vector.extract %[[V]][0, 0] : f32 from vector<2x4xf32> +// CHECK: %[[INS0:.*]] = vector.insert %[[E0]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E1:.*]] = vector.extract %[[V]][0, 1] : f32 from vector<2x4xf32> +// CHECK: %[[INS1:.*]] = vector.insert %[[E1]], %[[INS0]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[E2:.*]] = vector.extract %[[V]][0, 2] : f32 from vector<2x4xf32> +// CHECK: %[[INS2:.*]] = vector.insert %[[E2]], %[[INS1]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E3:.*]] = vector.extract %[[V]][0, 3] : f32 from vector<2x4xf32> +// CHECK: %[[V0:.*]] = vector.insert %[[E3]], %[[INS2]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32> +// CHECK: %[[E4:.*]] = vector.extract %[[V]][1, 0] : f32 from vector<2x4xf32> +// CHECK: %[[INS3:.*]] = vector.insert %[[E4]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E5:.*]] = vector.extract %[[V]][1, 1] : f32 from vector<2x4xf32> +// CHECK: %[[INS4:.*]] = vector.insert %[[E5]], %[[INS3]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[E6:.*]] = vector.extract %[[V]][1, 2] : f32 from vector<2x4xf32> +// CHECK: %[[INS5:.*]] = vector.insert %[[E6]], %[[INS4]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E7:.*]] = vector.extract %[[V]][1, 3] : f32 from vector<2x4xf32> +// CHECK: %[[V1:.*]] = vector.insert %[[E7]], %[[INS5]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[V1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x2xf32> +// CHECK: return %[[I1]] : vector<4x2xf32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 79bfc9bbcda71..0a54f06f5d6b6 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -163,8 +163,8 @@ struct TestVectorUnrollingPatterns .setFilterConstraint([](Operation *op) { return success( isa( - op)); + vector::BroadcastOp, vector::LoadOp, vector::StoreOp, + vector::ShapeCastOp>(op)); })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions()