Skip to content

Commit 62bf14d

Browse files
committed
add const mask tests
1 parent 9641898 commit 62bf14d

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2857,14 +2857,47 @@ func.func @contiguous_gather(%base: memref<?xf32>,
28572857

28582858
// -----
28592859

2860+
// CHECK-LABEL: @contiguous_gather_const_mask
2861+
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[PASSTHRU:.*]]: vector<16xf32>)
2862+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
2863+
// CHECK: %[[R:.*]] = vector.load %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
2864+
// CHECK: return %[[R]]
2865+
func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
2866+
%passthru: vector<16xf32>) -> vector<16xf32> {
2867+
%c0 = arith.constant 0 : index
2868+
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
2869+
%mask = arith.constant dense<true> : vector<16xi1>
2870+
%1 = vector.gather %base[%c0][%indices], %mask, %passthru :
2871+
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
2872+
return %1 : vector<16xf32>
2873+
}
2874+
2875+
// -----
2876+
28602877
// CHECK-LABEL: @contiguous_scatter
28612878
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
28622879
// CHECK: %[[C0:.*]] = arith.constant 0 : index
28632880
// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
28642881
func.func @contiguous_scatter(%base: memref<?xf32>,
2865-
%mask: vector<16xi1>, %value: vector<16xf32>){
2882+
%mask: vector<16xi1>, %value: vector<16xf32>) {
2883+
%c0 = arith.constant 0 : index
2884+
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
2885+
vector.scatter %base[%c0][%indices], %mask, %value :
2886+
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
2887+
return
2888+
}
2889+
2890+
// -----
2891+
2892+
// CHECK-LABEL: @contiguous_scatter_const_mask
2893+
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[VALUE:.*]]: vector<16xf32>)
2894+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
2895+
// CHECK: vector.store %[[VALUE]], %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
2896+
func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
2897+
%value: vector<16xf32>) {
28662898
%c0 = arith.constant 0 : index
28672899
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
2900+
%mask = vector.constant_mask [16] : vector<16xi1>
28682901
vector.scatter %base[%c0][%indices], %mask, %value :
28692902
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
28702903
return

0 commit comments

Comments
 (0)