Skip to content

Commit 34f89f7

Browse files
committed
[mlir][vector] Fix parser of vector.transfer_read
1 parent 8658896 commit 34f89f7

File tree

2 files changed

+21
-17
lines changed

2 files changed

+21
-17
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,27 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
151151
return false;
152152
}
153153

154+
/// Returns the number of dimensions of the `shapedType` that participate in the
155+
/// vector transfer, effectively the rank of the vector dimensions within the
156+
/// `shapedType`. This is calculated by taking the rank of the `vectorType`
157+
/// being transferred and subtracting the rank of the `shapedType`'s element
158+
/// type if it's also a vector.
159+
///
160+
/// This is used to determine the number of minor dimensions for identity maps
161+
/// in vector transfers.
162+
///
163+
/// For example, given a transfer operation involving `shapedType` and
164+
/// `vectorType`:
165+
///
166+
/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
167+
/// - shapedType.getElementType() = f32 (rank 0)
168+
/// - vectorType.getRank() = 2
169+
/// - Result = 2 - 0 = 2
170+
///
171+
/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
172+
/// - shapedType.getElementType() = vector<20xf32> (rank 1)
173+
/// - vectorType.getRank() = 1
174+
/// - Result = 1 - 1 = 0
154175
static unsigned getRealVectorRank(ShapedType shapedType,
155176
VectorType vectorType) {
156177
unsigned elementVectorRank = 0;

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -525,15 +525,6 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
525525

526526
// -----
527527

528-
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
529-
%c3 = arith.constant 3 : index
530-
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
531-
%0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xi32>
532-
return %0 : vector<3x4xi32>
533-
}
534-
535-
// -----
536-
537528
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xindex> {
538529
%c3 = arith.constant 3 : index
539530
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
@@ -664,14 +655,6 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
664655

665656
// -----
666657

667-
func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
668-
%c3 = arith.constant 3 : index
669-
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
670-
vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xi32>, memref<?xindex>
671-
}
672-
673-
// -----
674-
675658
func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xindex>, %output_memref: memref<?xindex>) {
676659
%c3 = arith.constant 3 : index
677660
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}

0 commit comments

Comments
 (0)