Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2408,6 +2408,7 @@ def Vector_CompressStoreOp :

def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]>,
Arguments<(ins AnyVectorOfAnyRank:$source)>,
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6233,6 +6233,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}

std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultVectorType().getShape());
}

LogicalResult ShapeCastOp::verify() {

VectorType sourceType = getSourceVectorType();
Expand Down
151 changes: 149 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,153 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};

/// This pattern unrolls `vector.shape_cast` operations according to the
/// provided target unroll shape. It decomposes a large shape_cast operation
/// into smaller tiles and reconstructs each tile by extracting individual
/// elements from the source vector and placing them at the correct positions.
///
/// Since shape_cast performs linear element reindexing, the pattern uses
/// linear indexing as a bridge to map between source and result coordinates.
/// For each element in a result tile, it calculates the corresponding source
/// position and extracts that element.
///
/// Example:
/// Given a shape_cast operation:
/// %0 = vector.shape_cast %src : vector<2x8xf32> to vector<4x4xf32>
///
/// and a target unroll shape of <2x2>, the pattern produces:
///
/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
/// %tile_zero = arith.constant dense<0.0> : vector<2x2xf32>
///
/// // First tile [0,0]: elements at result positions
/// (0,0),(0,1),(1,0),(1,1)
/// %e0 = vector.extract %src[0, 0] : f32 from vector<2x8xf32>
/// %t0 = vector.insert %e0, %tile_zero [0, 0] : f32 into vector<2x2xf32>
/// %e1 = vector.extract %src[0, 1] : f32 from vector<2x8xf32>
/// %t1 = vector.insert %e1, %t0 [0, 1] : f32 into vector<2x2xf32>
/// %e2 = vector.extract %src[0, 4] : f32 from vector<2x8xf32>
/// %t2 = vector.insert %e2, %t1 [1, 0] : f32 into vector<2x2xf32>
/// %e3 = vector.extract %src[0, 5] : f32 from vector<2x8xf32>
/// %t3 = vector.insert %e3, %t2 [1, 1] : f32 into vector<2x2xf32>
/// %r0 = vector.insert_strided_slice %t3, %zero
/// {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into
/// vector<4x4xf32>
///
struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
UnrollShapeCastPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::ShapeCastOp>(context, benefit),
options(options) {}

LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
auto targetShape = getTargetShape(options, shapeCastOp);
if (!targetShape)
return failure();

Location loc = shapeCastOp.getLoc();
VectorType sourceType = shapeCastOp.getSourceVectorType();
VectorType resultType = shapeCastOp.getResultVectorType();

ArrayRef<int64_t> resultShape = resultType.getShape();
ArrayRef<int64_t> sourceShape = sourceType.getShape();

SmallVector<int64_t> strides(targetShape->size(), 1);
Value result = rewriter.create<arith::ConstantOp>(
loc, resultType, rewriter.getZeroAttr(resultType));

// For each unrolled tile in the result
for (SmallVector<int64_t> tileOffsets :
StaticTileOffsetRange(resultShape, *targetShape)) {

// Create the target tile type
VectorType tileType =
VectorType::get(*targetShape, resultType.getElementType());

// Build the tile by extracting individual elements
Value tile = createTileFromElements(
rewriter, loc, shapeCastOp.getSource(), sourceShape, resultShape,
tileOffsets, *targetShape, tileType);

// Insert the tile into the result
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, tile, result, tileOffsets, strides);
}

rewriter.replaceOp(shapeCastOp, result);
return success();
}

private:
/// Creates a result tile by extracting individual elements from the source
/// and inserting them at the correct positions in the tile.
Value createTileFromElements(PatternRewriter &rewriter, Location loc,
Value source, ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape,
ArrayRef<int64_t> tileOffsets,
ArrayRef<int64_t> tileShape,
VectorType tileType) const {

// Initialize tile with zeros
Value tile = rewriter.create<arith::ConstantOp>(
loc, tileType, rewriter.getZeroAttr(tileType));

// Calculate strides for both source and result shapes
SmallVector<int64_t> sourceStrides = computeStrides(sourceShape);
SmallVector<int64_t> resultStrides = computeStrides(resultShape);

// Iterate over all positions in the tile using linear indexing
for (int64_t linearTileIdx = 0; linearTileIdx < computeProduct(tileShape);
++linearTileIdx) {
// Convert linear tile index to multi-dimensional tile position
SmallVector<int64_t> tilePosition =
linearIndexToMultiDim(linearTileIdx, tileShape);

// Calculate the global position in the result
SmallVector<int64_t> globalResultPos;
globalResultPos.reserve(tileOffsets.size());
for (auto [offset, pos] : llvm::zip(tileOffsets, tilePosition)) {
globalResultPos.push_back(offset + pos);
}

// Convert result position to linear index
int64_t linearIndex = linearize(globalResultPos, resultStrides);

// Convert linear index to source position
SmallVector<int64_t> sourcePos =
linearIndexToMultiDim(linearIndex, sourceShape);

// Extract element from source
Value element =
rewriter.create<vector::ExtractOp>(loc, source, sourcePos);

// Insert element into tile
tile =
rewriter.create<vector::InsertOp>(loc, element, tile, tilePosition);
}

return tile;
}

/// Converts a linear index to multi-dimensional position within a given
/// shape.
SmallVector<int64_t> linearIndexToMultiDim(int64_t linearIndex,
ArrayRef<int64_t> shape) const {
SmallVector<int64_t> position(shape.size());

for (int64_t i = shape.size() - 1; i >= 0; --i) {
position[i] = linearIndex % shape[i];
linearIndex /= shape[i];
}

return position;
}

vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
Expand All @@ -1013,8 +1160,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
UnrollToElements, UnrollStepPattern>(patterns.getContext(),
options, benefit);
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
patterns.getContext(), options, benefit);
}

void mlir::vector::populateVectorToElementsUnrollPatterns(
Expand Down
92 changes: 92 additions & 0 deletions mlir/test/Dialect/Vector/vector-unroll-options.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,95 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
// CHECK-NOT: arith.addf
// CHECK: return

//CHECK-LABEL: func @shape_cast_1D_to_2D
// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<4x4xf32>
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<16xf32>
// CHECK: %[[INS0:.*]] = vector.insert %[[E0]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<16xf32>
// CHECK: %[[INS1:.*]] = vector.insert %[[E1]], %[[INS0]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[E2:.*]] = vector.extract %[[ARG0]][4] : f32 from vector<16xf32>
// CHECK: %[[INS2:.*]] = vector.insert %[[E2]], %[[INS1]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E3:.*]] = vector.extract %[[ARG0]][5] : f32 from vector<16xf32>
// CHECK: %[[V0:.*]] = vector.insert %[[E3]], %[[INS2]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: %[[E4:.*]] = vector.extract %[[ARG0]][2] : f32 from vector<16xf32>
// CHECK: %[[INS3:.*]] = vector.insert %[[E4]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E5:.*]] = vector.extract %[[ARG0]][3] : f32 from vector<16xf32>
// CHECK: %[[INS4:.*]] = vector.insert %[[E5]], %[[INS3]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[E6:.*]] = vector.extract %[[ARG0]][6] : f32 from vector<16xf32>
// CHECK: %[[INS5:.*]] = vector.insert %[[E6]], %[[INS4]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E7:.*]] = vector.extract %[[ARG0]][7] : f32 from vector<16xf32>
// CHECK: %[[V1:.*]] = vector.insert %[[E7]], %[[INS5]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[V1]], %[[I0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: %[[E8:.*]] = vector.extract %[[ARG0]][8] : f32 from vector<16xf32>
// CHECK: %[[INS6:.*]] = vector.insert %[[E8]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E9:.*]] = vector.extract %[[ARG0]][9] : f32 from vector<16xf32>
// CHECK: %[[INS7:.*]] = vector.insert %[[E9]], %[[INS6]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[E10:.*]] = vector.extract %[[ARG0]][12] : f32 from vector<16xf32>
// CHECK: %[[INS8:.*]] = vector.insert %[[E10]], %[[INS7]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E11:.*]] = vector.extract %[[ARG0]][13] : f32 from vector<16xf32>
// CHECK: %[[V2:.*]] = vector.insert %[[E11]], %[[INS8]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[V2]], %[[I1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: %[[E12:.*]] = vector.extract %[[ARG0]][10] : f32 from vector<16xf32>
// CHECK: %[[INS9:.*]] = vector.insert %[[E12]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E13:.*]] = vector.extract %[[ARG0]][11] : f32 from vector<16xf32>
// CHECK: %[[INS10:.*]] = vector.insert %[[E13]], %[[INS9]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[E14:.*]] = vector.extract %[[ARG0]][14] : f32 from vector<16xf32>
// CHECK: %[[INS11:.*]] = vector.insert %[[E14]], %[[INS10]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E15:.*]] = vector.extract %[[ARG0]][15] : f32 from vector<16xf32>
// CHECK: %[[V3:.*]] = vector.insert %[[E15]], %[[INS11]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[V3]], %[[I2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: return %[[I3]] : vector<4x4xf32>
func.func @shape_cast_1D_to_2D(%v: vector<16xf32>) -> vector<4x4xf32> {
%0 = vector.shape_cast %v : vector<16xf32> to vector<4x4xf32>
return %0 : vector<4x4xf32>
}

//CHECK-LABEL: func @shape_cast_2D
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x8xf32>) -> vector<4x4xf32>
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : f32 from vector<2x8xf32>
// CHECK: %[[INS0:.*]] = vector.insert %[[E0]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][0, 1] : f32 from vector<2x8xf32>
// CHECK: %[[INS1:.*]] = vector.insert %[[E1]], %[[INS0]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[E2:.*]] = vector.extract %[[ARG0]][0, 4] : f32 from vector<2x8xf32>
// CHECK: %[[INS2:.*]] = vector.insert %[[E2]], %[[INS1]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E3:.*]] = vector.extract %[[ARG0]][0, 5] : f32 from vector<2x8xf32>
// CHECK: %[[V0:.*]] = vector.insert %[[E3]], %[[INS2]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: %[[E4:.*]] = vector.extract %[[ARG0]][0, 2] : f32 from vector<2x8xf32>
// CHECK: %[[INS3:.*]] = vector.insert %[[E4]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E5:.*]] = vector.extract %[[ARG0]][0, 3] : f32 from vector<2x8xf32>
// CHECK: %[[INS4:.*]] = vector.insert %[[E5]], %[[INS3]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[E6:.*]] = vector.extract %[[ARG0]][0, 6] : f32 from vector<2x8xf32>
// CHECK: %[[INS5:.*]] = vector.insert %[[E6]], %[[INS4]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E7:.*]] = vector.extract %[[ARG0]][0, 7] : f32 from vector<2x8xf32>
// CHECK: %[[V1:.*]] = vector.insert %[[E7]], %[[INS5]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[V1]], %[[I0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: %[[E8:.*]] = vector.extract %[[ARG0]][1, 0] : f32 from vector<2x8xf32>
// CHECK: %[[INS6:.*]] = vector.insert %[[E8]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E9:.*]] = vector.extract %[[ARG0]][1, 1] : f32 from vector<2x8xf32>
// CHECK: %[[INS7:.*]] = vector.insert %[[E9]], %[[INS6]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[E10:.*]] = vector.extract %[[ARG0]][1, 4] : f32 from vector<2x8xf32>
// CHECK: %[[INS8:.*]] = vector.insert %[[E10]], %[[INS7]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E11:.*]] = vector.extract %[[ARG0]][1, 5] : f32 from vector<2x8xf32>
// CHECK: %[[V2:.*]] = vector.insert %[[E11]], %[[INS8]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[V2]], %[[I1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: %[[E12:.*]] = vector.extract %[[ARG0]][1, 2] : f32 from vector<2x8xf32>
// CHECK: %[[INS9:.*]] = vector.insert %[[E12]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E13:.*]] = vector.extract %[[ARG0]][1, 3] : f32 from vector<2x8xf32>
// CHECK: %[[INS10:.*]] = vector.insert %[[E13]], %[[INS9]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[E14:.*]] = vector.extract %[[ARG0]][1, 6] : f32 from vector<2x8xf32>
// CHECK: %[[INS11:.*]] = vector.insert %[[E14]], %[[INS10]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[E15:.*]] = vector.extract %[[ARG0]][1, 7] : f32 from vector<2x8xf32>
// CHECK: %[[V3:.*]] = vector.insert %[[E15]], %[[INS11]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[V3]], %[[I2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
// CHECK: return %[[I3]] : vector<4x4xf32>
func.func @shape_cast_2D(%v: vector<2x8xf32>) -> vector<4x4xf32> {
%0 = vector.shape_cast %v : vector<2x8xf32> to vector<4x4xf32>
return %0 : vector<4x4xf32>
}
4 changes: 2 additions & 2 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(
isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
vector::BroadcastOp, vector::LoadOp, vector::StoreOp>(
op));
vector::BroadcastOp, vector::LoadOp, vector::StoreOp,
vector::ShapeCastOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
Expand Down