Skip to content

Commit 5127337

Browse files
committed
[mlir][Vector] Infer mask and pass_thru types for maskedload/store
1 parent 3cdf6e7 commit 5127337

18 files changed, with 157 additions and 154 deletions.

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 25 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1840,7 +1840,16 @@ def Vector_StoreOp : Vector_Op<"store"> {
18401840
}
18411841

18421842
def Vector_MaskedLoadOp :
1843-
Vector_Op<"maskedload">,
1843+
Vector_Op<"maskedload", [
1844+
AllTypesMatch<["result", "pass_thru"]>,
1845+
TypesMatchWith<"mask shape should match result shape",
1846+
"result",
1847+
"mask",
1848+
"VectorType::get(::llvm::cast<VectorType>($_self).getShape(),"
1849+
"IntegerType::get($_ctxt, 1),"
1850+
"::llvm::cast<VectorType>($_self).getScalableDims())">,
1851+
AllElementTypesMatch<["result", "base"]>
1852+
]>,
18441853
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
18451854
Variadic<Index>:$indices,
18461855
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1875,10 +1884,10 @@ def Vector_MaskedLoadOp :
18751884

18761885
```mlir
18771886
%0 = vector.maskedload %base[%i], %mask, %pass_thru
1878-
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
1887+
: memref<?xf32>, vector<8xf32>
18791888

18801889
%1 = vector.maskedload %base[%i, %j], %mask, %pass_thru
1881-
: memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1890+
: memref<?x?xf32>, vector<16xf32>
18821891
```
18831892
}];
18841893
let extraClassDeclaration = [{
@@ -1896,14 +1905,22 @@ def Vector_MaskedLoadOp :
18961905
}
18971906
}];
18981907
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
1899-
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
1908+
"type($base) `,` type($result)";
19001909
let hasCanonicalizer = 1;
19011910
let hasFolder = 1;
19021911
let hasVerifier = 1;
19031912
}
19041913

19051914
def Vector_MaskedStoreOp :
1906-
Vector_Op<"maskedstore">,
1915+
Vector_Op<"maskedstore", [
1916+
TypesMatchWith<"mask shape should match result shape",
1917+
"valueToStore",
1918+
"mask",
1919+
"VectorType::get(::llvm::cast<VectorType>($_self).getShape(),"
1920+
"IntegerType::get($_ctxt, 1),"
1921+
"::llvm::cast<VectorType>($_self).getScalableDims())">,
1922+
AllElementTypesMatch<["valueToStore", "base"]>
1923+
]>,
19071924
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
19081925
Variadic<Index>:$indices,
19091926
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1937,10 +1954,10 @@ def Vector_MaskedStoreOp :
19371954

19381955
```mlir
19391956
vector.maskedstore %base[%i], %mask, %value
1940-
: memref<?xf32>, vector<8xi1>, vector<8xf32>
1957+
: memref<?xf32>, vector<8xf32>
19411958

19421959
vector.maskedstore %base[%i, %j], %mask, %value
1943-
: memref<?x?xf32>, vector<16xi1>, vector<16xf32>
1960+
: memref<?x?xf32>, vector<16xf32>
19441961
```
19451962
}];
19461963
let extraClassDeclaration = [{
@@ -1956,7 +1973,7 @@ def Vector_MaskedStoreOp :
19561973
}];
19571974
let assemblyFormat =
19581975
"$base `[` $indices `]` `,` $mask `,` $valueToStore "
1959-
"attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
1976+
"attr-dict `:` type($base) `,` type($valueToStore)";
19601977
let hasCanonicalizer = 1;
19611978
let hasFolder = 1;
19621979
let hasVerifier = 1;

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 6 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -5127,19 +5127,9 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
51275127
//===----------------------------------------------------------------------===//
51285128

51295129
LogicalResult MaskedLoadOp::verify() {
5130-
VectorType maskVType = getMaskVectorType();
5131-
VectorType passVType = getPassThruVectorType();
5132-
VectorType resVType = getVectorType();
5133-
MemRefType memType = getMemRefType();
5134-
5135-
if (resVType.getElementType() != memType.getElementType())
5136-
return emitOpError("base and result element type should match");
5137-
if (llvm::size(getIndices()) != memType.getRank())
5138-
return emitOpError("requires ") << memType.getRank() << " indices";
5139-
if (resVType.getShape() != maskVType.getShape())
5140-
return emitOpError("expected result shape to match mask shape");
5141-
if (resVType != passVType)
5142-
return emitOpError("expected pass_thru of same type as result type");
5130+
int64_t memRank = getMemRefType().getRank();
5131+
if (llvm::size(getIndices()) != memRank)
5132+
return emitOpError("requires ") << memRank << " indices";
51435133
return success();
51445134
}
51455135

@@ -5181,16 +5171,9 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
51815171
//===----------------------------------------------------------------------===//
51825172

51835173
LogicalResult MaskedStoreOp::verify() {
5184-
VectorType maskVType = getMaskVectorType();
5185-
VectorType valueVType = getVectorType();
5186-
MemRefType memType = getMemRefType();
5187-
5188-
if (valueVType.getElementType() != memType.getElementType())
5189-
return emitOpError("base and valueToStore element type should match");
5190-
if (llvm::size(getIndices()) != memType.getRank())
5191-
return emitOpError("requires ") << memType.getRank() << " indices";
5192-
if (valueVType.getShape() != maskVType.getShape())
5193-
return emitOpError("expected valueToStore shape to match mask shape");
5174+
int64_t memRank = getMemRefType().getRank();
5175+
if (llvm::size(getIndices()) != memRank)
5176+
return emitOpError("requires ") << memRank << " indices";
51945177
return success();
51955178
}
51965179

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
8888
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
8989
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
9090
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
91-
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
91+
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi32>
9292
// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
9393
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
9494
func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1891,7 +1891,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
18911891

18921892
func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
18931893
%c0 = arith.constant 0: index
1894-
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1894+
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xf32>
18951895
return %0 : vector<16xf32>
18961896
}
18971897

@@ -1906,7 +1906,7 @@ func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector
19061906

19071907
func.func @masked_load_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) -> vector<[16]xf32> {
19081908
%c0 = arith.constant 0: index
1909-
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
1909+
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xf32>
19101910
return %0 : vector<[16]xf32>
19111911
}
19121912

@@ -1921,7 +1921,7 @@ func.func @masked_load_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %a
19211921

19221922
func.func @masked_load_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> {
19231923
%c0 = arith.constant 0: index
1924-
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
1924+
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xindex>
19251925
return %0 : vector<16xindex>
19261926
}
19271927
// CHECK-LABEL: func @masked_load_index
@@ -1931,7 +1931,7 @@ func.func @masked_load_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2
19311931

19321932
func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) -> vector<[16]xindex> {
19331933
%c0 = arith.constant 0: index
1934-
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xi1>, vector<[16]xindex> into vector<[16]xindex>
1934+
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xindex>
19351935
return %0 : vector<[16]xindex>
19361936
}
19371937
// CHECK-LABEL: func @masked_load_index_scalable
@@ -1945,7 +1945,7 @@ func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]
19451945

19461946
func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
19471947
%c0 = arith.constant 0: index
1948-
vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
1948+
vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xf32>
19491949
return
19501950
}
19511951

@@ -1959,7 +1959,7 @@ func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vecto
19591959

19601960
func.func @masked_store_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) {
19611961
%c0 = arith.constant 0: index
1962-
vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32>
1962+
vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xf32>
19631963
return
19641964
}
19651965

@@ -1973,7 +1973,7 @@ func.func @masked_store_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %
19731973

19741974
func.func @masked_store_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) {
19751975
%c0 = arith.constant 0: index
1976-
vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex>
1976+
vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xindex>
19771977
return
19781978
}
19791979
// CHECK-LABEL: func @masked_store_index
@@ -1983,7 +1983,7 @@ func.func @masked_store_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg
19831983

19841984
func.func @masked_store_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) {
19851985
%c0 = arith.constant 0: index
1986-
vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xi1>, vector<[16]xindex>
1986+
vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xindex>
19871987
return
19881988
}
19891989
// CHECK-LABEL: func @masked_store_index_scalable

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -837,7 +837,7 @@ func.func @fold_vector_load_subview(
837837
func.func @fold_vector_maskedload_subview(
838838
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
839839
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
840-
%1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32> into vector<32xf32>
840+
%1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xf32>
841841
return %1 : vector<32xf32>
842842
}
843843

@@ -847,7 +847,7 @@ func.func @fold_vector_maskedload_subview(
847847
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
848848
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
849849
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
850-
// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
850+
// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32>
851851

852852
// -----
853853

@@ -871,7 +871,7 @@ func.func @fold_vector_store_subview(
871871
func.func @fold_vector_maskedstore_subview(
872872
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
873873
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
874-
vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xi1>, vector<32xf32>
874+
vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xf32>
875875
return
876876
}
877877

@@ -881,7 +881,7 @@ func.func @fold_vector_maskedstore_subview(
881881
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
882882
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
883883
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
884-
// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
884+
// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32>
885885
// CHECK: return
886886

887887
// -----
@@ -907,7 +907,7 @@ func.func @fold_vector_maskedload_expand_shape(
907907
%arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
908908
%c0 = arith.constant 0 : index
909909
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
910-
%1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
910+
%1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32>
911911
return %1 : vector<8xf32>
912912
}
913913

@@ -943,7 +943,7 @@ func.func @fold_vector_maskedstore_expand_shape(
943943
%arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
944944
%c0 = arith.constant 0 : index
945945
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
946-
vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
946+
vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32>
947947
return
948948
}
949949

@@ -979,7 +979,7 @@ func.func @fold_vector_load_collapse_shape(
979979
func.func @fold_vector_maskedload_collapse_shape(
980980
%arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
981981
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
982-
%1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
982+
%1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32>
983983
return %1 : vector<8xf32>
984984
}
985985

@@ -1017,7 +1017,7 @@ func.func @fold_vector_store_collapse_shape(
10171017
func.func @fold_vector_maskedstore_collapse_shape(
10181018
%arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
10191019
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
1020-
vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32>
1020+
vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32>
10211021
return
10221022
}
10231023

0 commit comments

Comments
 (0)