Skip to content

Commit 99ba2a8

Browse files
committed
vector.step support
1 parent 2192f7a commit 99ba2a8

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5190,6 +5190,9 @@ static LogicalResult isContiguousIndices(Value indexVec) {
51905190
if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
51915191
return failure();
51925192

5193+
if (indexVec.getDefiningOp<StepOp>())
5194+
return success();
5195+
51935196
DenseIntElementsAttr elements;
51945197
if (!matchPattern(indexVec, m_Constant(&elements)))
51955198
return failure();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2874,6 +2874,22 @@ func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
28742874

28752875
// -----
28762876

2877+
// CHECK-LABEL: @contiguous_gather_step
2878+
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
2879+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
2880+
// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
2881+
// CHECK: return %[[R]]
2882+
func.func @contiguous_gather_step(%base: memref<?xf32>,
2883+
%mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
2884+
%c0 = arith.constant 0 : index
2885+
%indices = vector.step : vector<16xindex>
2886+
%1 = vector.gather %base[%c0][%indices], %mask, %passthru :
2887+
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
2888+
return %1 : vector<16xf32>
2889+
}
2890+
2891+
// -----
2892+
28772893
// CHECK-LABEL: @contiguous_scatter
28782894
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
28792895
// CHECK: %[[C0:.*]] = arith.constant 0 : index
@@ -2902,3 +2918,18 @@ func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
29022918
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
29032919
return
29042920
}
2921+
2922+
// -----
2923+
2924+
// CHECK-LABEL: @contiguous_scatter_step
2925+
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
2926+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
2927+
// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
2928+
func.func @contiguous_scatter_step(%base: memref<?xf32>,
2929+
%mask: vector<16xi1>, %value: vector<16xf32>) {
2930+
%c0 = arith.constant 0 : index
2931+
%indices = vector.step : vector<16xindex>
2932+
vector.scatter %base[%c0][%indices], %mask, %value :
2933+
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
2934+
return
2935+
}

0 commit comments

Comments
 (0)