1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2427,6 +2427,7 @@ def Vector_CompressStoreOp :

def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]>,
Arguments<(ins AnyVectorOfAnyRank:$source)>,
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6241,6 +6241,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}

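// Unrolling is driven by the result shape: the unroll pattern tiles the
// result vector and maps each tile back to a contiguous slice of the source.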
std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultVectorType().getShape());
}

LogicalResult ShapeCastOp::verify() {

VectorType sourceType = getSourceVectorType();
193 changes: 191 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,195 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};

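// Returns true if a tile with `targetShape`'s element count can be taken as
// a contiguous run of innermost elements of `resultShape`.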
static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
ArrayRef<int64_t> resultShape) {
if (targetShape.size() > resultShape.size())
return false;

int64_t targetElements = ShapedType::getNumElements(targetShape);
int64_t resultElements = ShapedType::getNumElements(resultShape);

// Result must be evenly divisible by target.
if (resultElements % targetElements != 0)
return false;

// For a contiguous extraction, targetElements must form a contiguous run
// within the result shape. We check this by "consuming" result dimensions
// from the innermost outward until exactly targetElements are covered.
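// For example, with targetShape = [2, 4] (8 elements) and
// resultShape = [4, 4]: the innermost result dim (4) matches the innermost
// target dim, and the remaining 2 elements are a prefix of the next result
// dim outward, so the extraction is contiguous.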

int64_t remainingElements = targetElements;
int targetDimIdx = targetShape.size() - 1;

// Work backwards through result dimensions.
for (int resultDimIdx = resultShape.size() - 1;
resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0;
--resultDimIdx) {

int64_t resultDimSize = resultShape[resultDimIdx];
int64_t targetDimSize = targetShape[targetDimIdx];

if (targetDimSize > resultDimSize)
return false;

if (targetDimSize == resultDimSize) {
if (remainingElements % targetDimSize != 0)
return false;
remainingElements /= targetDimSize;
--targetDimIdx;
} else {
if (remainingElements != targetDimSize)
return false;
remainingElements = 1;
--targetDimIdx;
}
}

// Check that all remaining target dimensions are 1 and that all elements
// were consumed.
return remainingElements == 1 &&
(targetDimIdx < 0 || llvm::all_of(
targetShape.take_front(targetDimIdx + 1),
[](int64_t d) { return d == 1; }));
}

// Calculate the shape to extract from source.
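// For example, with sourceShape = [8, 2] and targetElements = 8, this
// returns [4, 2]: all 2 elements of the innermost dim, then 4 of the 8
// elements of the next dim outward.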
static std::optional<SmallVector<int64_t>>
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
int64_t targetElements) {
SmallVector<int64_t> extractShape;
int64_t remainingElements = targetElements;

// Build extract shape from innermost dimension outward to ensure contiguity.
for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
extractShape.insert(extractShape.begin(), takeFromDim);

if (remainingElements % takeFromDim != 0)
return std::nullopt; // Not evenly divisible.
remainingElements /= takeFromDim;
}

// Fill remaining dimensions with 1.
while (extractShape.size() < sourceShape.size())
extractShape.insert(extractShape.begin(), 1);

if (ShapedType::getNumElements(extractShape) != targetElements)
return std::nullopt;

return extractShape;
}

// Convert result offsets to source offsets via linear position.
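// For example, result offsets [2, 0] in a 4x4 result (strides [4, 1])
// linearize to index 8, which delinearizes in an 8x2 source (strides
// [2, 1]) to source offsets [4, 0].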
static SmallVector<int64_t>
calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
ArrayRef<int64_t> sourceStrides,
ArrayRef<int64_t> resultStrides) {
// Convert result offsets to linear position.
int64_t linearIndex = linearize(resultOffsets, resultStrides);
// Convert linear position to source offsets.
return delinearize(linearIndex, sourceStrides);
}

/// This pattern unrolls `vector.shape_cast` operations according to the
/// provided target unroll shape. It unrolls a large shape cast into smaller
/// shape casts by extracting contiguous slices from the source vector, casting
/// each slice to the target shape, and assembling the result by inserting each
/// computed segment into the appropriate offset of the result vector.
///
/// This pattern only applies when contiguous slices can be extracted from the
/// source vector and inserted into the result vector such that each slice
/// remains a valid vector (rather than decomposing to scalars). In these
/// cases, the unrolling proceeds as:
/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
/// vector.insert_strided_slice.
///
/// Example:
/// Given a shape cast operation:
/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
///
/// and a target unroll shape of <2x4>, the pattern produces:
///
/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
/// : vector<8x2xf32> to vector<4x2xf32>
/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
/// : vector<2x4xf32> into vector<4x4xf32>
/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
/// : vector<8x2xf32> to vector<4x2xf32>
/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
/// : vector<2x4xf32> into vector<4x4xf32>
///
struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
UnrollShapeCastPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::ShapeCastOp>(context, benefit),
options(options) {}

LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(options, shapeCastOp);
if (!targetShape)
return failure();

VectorType sourceType = shapeCastOp.getSourceVectorType();
VectorType resultType = shapeCastOp.getResultVectorType();
ArrayRef<int64_t> sourceShape = sourceType.getShape();
ArrayRef<int64_t> resultShape = resultType.getShape();

if (!isContiguousExtract(*targetShape, resultShape))
return rewriter.notifyMatchFailure(shapeCastOp,
"only supports cases where contiguous "
"extraction is possible");

int64_t targetElements = ShapedType::getNumElements(*targetShape);

// Calculate the shape to extract from source.
std::optional<SmallVector<int64_t>> extractShape =
calculateSourceExtractShape(sourceShape, targetElements);
if (!extractShape)
return rewriter.notifyMatchFailure(
shapeCastOp,
"cannot extract target number of elements contiguously from source");

Location loc = shapeCastOp.getLoc();

// Create result vector initialized to zero.
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
rewriter.getZeroAttr(resultType));

VectorType targetType =
VectorType::get(*targetShape, sourceType.getElementType());

SmallVector<int64_t> extractStrides(extractShape->size(), 1);
SmallVector<int64_t> insertStrides(targetShape->size(), 1);
SmallVector<int64_t> sourceStrides = computeStrides(sourceShape);
SmallVector<int64_t> resultStrides = computeStrides(resultShape);

for (SmallVector<int64_t> resultOffsets :
StaticTileOffsetRange(resultShape, *targetShape)) {
SmallVector<int64_t> sourceOffsets =
calculateSourceOffsets(resultOffsets, sourceStrides, resultStrides);
Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
extractStrides);
Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
loc, targetType, sourceChunk);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, targetChunk, result, resultOffsets, insertStrides);
}

rewriter.replaceOp(shapeCastOp, result);
return success();
}

private:
vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1202,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
UnrollToElements, UnrollStepPattern>(patterns.getContext(),
options, benefit);
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
patterns.getContext(), options, benefit);
}

void mlir::vector::populateVectorToElementsUnrollPatterns(
34 changes: 34 additions & 0 deletions mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,37 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
// CHECK-NOT: arith.addf
// CHECK: return


func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
return %0 : vector<2x2x4xf32>
}

// CHECK-LABEL: func @shape_cast_1D
// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
// CHECK: return %[[I1]] : vector<2x2x4xf32>


func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> {
%0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32>
return %0 : vector<4x4xf32>
}

// CHECK-LABEL: func @shape_cast_2D
// CHECK-SAME: (%[[ARG0:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
// CHECK: return %[[I1]] : vector<4x4xf32>
6 changes: 6 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::StepOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 4})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ShapeCastOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})