Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4122,6 +4122,10 @@ VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
assert(invPermMap && "Inversed permutation map couldn't be computed");
SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());

// Turn a 0-D mask into a single-element 1-D mask.
Copy link
Contributor

@banach-space banach-space Nov 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Personally, I'd appreciate a note saying that ATM vector.mask does not support 0-D masks (enforced here). And that's basically the reason to "upgrade" this mask (which to me is a very good reason).

As in, document "why" rather than "what" :)

if (maskShape.empty())
maskShape.push_back(1);

SmallVector<bool> scalableDims =
applyPermutationMap(invPermMap, vecType.getScalableDims());

Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,22 @@ func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32>

// -----

// We can support 0-D masks if eventually needed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Why not add this comment where this restriction is enforced?

let arguments = (ins VectorOf<[I1]>:$mask,
Optional<AnyType>:$passthru);

And to the Op docs :)

func.func @vector_mask_0d_mask(%arg0: tensor<2x4xi32>,
%idx0: index, %idx1: index,
%m0: vector<i1>) -> vector<1x1x4xi32> {
%cst = arith.constant 0 : i32
// expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
%res = vector.mask %m0 {
%0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
: tensor<2x4xi32>, vector<1x1x4xi32>
vector.yield %0 : vector<1x1x4xi32>
} : vector<i1> -> vector<1x1x4xi32>
return %res : vector<1x1x4xi32>
}

// -----

func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
// expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
%0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,20 @@ func.func @vector_mask_empty_return(%m0: vector<16xi1>, %arg0: vector<16xf32>) -
return %0 : vector<16xf32>
}

// CHECK-LABEL: func @vector_mask_scalar_broadcast_transfer
func.func @vector_mask_scalar_broadcast_transfer(%arg0: tensor<2x4xi32>,
%idx0: index, %idx1: index,
%m0: vector<1xi1>) -> vector<1x1x4xi32> {
%cst = arith.constant 0 : i32
// CHECK: vector.mask %{{.*}} { vector.transfer_read {{.*}} } : vector<1xi1> -> vector<1x1x4xi32>
%res = vector.mask %m0 {
%0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
: tensor<2x4xi32>, vector<1x1x4xi32>
vector.yield %0 : vector<1x1x4xi32>
} : vector<1xi1> -> vector<1x1x4xi32>
return %res : vector<1x1x4xi32>
}

// CHECK-LABEL: func @vector_scalable_insert(
// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
Expand Down
Loading