Skip to content
Merged
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,11 @@ struct UnrollElementwisePattern : public RewritePattern {
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
// Bail-out if rank(source) != rank(target). The main limitation here is the
// fact that `ExtractStridedSlice` requires the rank for the input and
// output to match. If needed, we can relax this later.
if (originalSize.size() != targetShape->size())
return failure();
Location loc = op->getLoc();
// Prepare the result vector.
Value result = rewriter.create<arith::ConstantOp>(
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,15 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {

// -----

func.func @extract_strided_slice(%arg0: vector<3x2x2xf32>) {
// expected-error@+1 {{expected input vector rank to match target shape rank}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at the error in CI:

# .---command stderr------------
# | within split at C:\ws\src\mlir\test\Dialect\Vector\invalid.mlir:769 offset :5:8: error: unexpected error: 'vector.extract_strided_slice' op expected result type to be 'vector<2x2x2xf32>'
# |   %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
# |        ^
# | within split at C:\ws\src\mlir\test\Dialect\Vector\invalid.mlir:769 offset :4:6: error: expected error "expected input vector rank to match target shape rank" was not produced
# |   // expected-error@+1 {{expected input vector rank to match target shape rank}}
# |      ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `-----------------------------

I can reproduce this locally on non-Windows machine. @Prakhar-Dixit , does this test pass for you?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It failed on a non-Windows machine. Trying to figure it out.
😢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked at the error in CI:

# .---command stderr------------
# | within split at C:\ws\src\mlir\test\Dialect\Vector\invalid.mlir:769 offset :5:8: error: unexpected error: 'vector.extract_strided_slice' op expected result type to be 'vector<2x2x2xf32>'
# |   %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
# |        ^
# | within split at C:\ws\src\mlir\test\Dialect\Vector\invalid.mlir:769 offset :4:6: error: expected error "expected input vector rank to match target shape rank" was not produced
# |   // expected-error@+1 {{expected input vector rank to match target shape rank}}
# |      ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# `-----------------------------

I can reproduce this locally on non-Windows machine. @Prakhar-Dixit , does this test pass for you?

I have no clue how to handle this. I would be very grateful if you could help 🙏

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to double-check - did it use to work? I am wondering whether it's some recent change to MLIR.

Copy link
Contributor Author

@Prakhar-Dixit Prakhar-Dixit Feb 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I guess not, because extracting a 2D slice from a 3D vector is not supported. As a result, an error is thrown indicating that the resulting vector type must be a 3D vector. This behavior is already verified by the following test:

func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
  // expected-error@+1 {{op expected result type to be 'vector<2x8x16xf32>'}}
  %1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
} 

And i guess this test is not required since it's been taken care of.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, so you've made this error up:

// expected-error@+1 {{expected input vector rank to match target shape rank}}

? :) As in, it wasn't an actual error that you saw? That's what I was going by when reviewing 😅

And i guess this test is not required since it's been taken care of.

Indeed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Newbie struglles 🤧

%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

vector<3x2x2xf32> to vector<2x2xf32>
return
}

// -----

#contraction_accesses = [
affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>,
affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>,
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Vector/vector-unroll-options.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>

// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) {
%0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
return
}
// CHECK-LABEL: func @negative_vector_fma_3d
// CHECK-NOT: vector.extract_strided_slice
// CHECK: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<3x2x2xf32>
// CHECK: return

func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
return %0 : vector<4xf32>
Expand Down