Skip to content

Commit ef51465

Browse files
committed
AffineApplyOp index support
1 parent 0bd3295 commit ef51465

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/SLPVectorizer.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "mlir/Analysis/DataLayoutAnalysis.h"
15+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1516
#include "mlir/Dialect/Arith/IR/Arith.h"
1617
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1718
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -142,6 +143,27 @@ static Type getElementType(Operation *op) {
142143
return {};
143144
}
144145

146+
static bool isAdjacentAffineMapIndices(Value idx1, Value idx2) {
147+
auto applyOp1 = idx1.getDefiningOp<affine::AffineApplyOp>();
148+
if (!applyOp1)
149+
return false;
150+
151+
auto applyOp2 = idx2.getDefiningOp<affine::AffineApplyOp>();
152+
if (!applyOp2)
153+
return false;
154+
155+
if (applyOp1.getOperands() != applyOp2.getOperands())
156+
return false;
157+
158+
AffineExpr expr1 = applyOp1.getAffineMap().getResult(0);
159+
AffineExpr expr2 = applyOp2.getAffineMap().getResult(0);
160+
auto diff =
161+
simplifyAffineExpr(expr2 - expr1, 0, applyOp1.getOperands().size());
162+
163+
auto diffConst = dyn_cast<AffineConstantExpr>(diff);
164+
return diffConst && diffConst.getValue() == 1;
165+
}
166+
145167
/// Check if two indices are consecutive, i.e index1 + 1 == index2.
146168
static bool isAdjacentIndices(Value idx1, Value idx2) {
147169
if (auto c1 = getConstantIntValue(idx1)) {
@@ -160,7 +182,9 @@ static bool isAdjacentIndices(Value idx1, Value idx2) {
160182
}
161183
}
162184

163-
// TODO: Handle affine.apply, etc
185+
if (isAdjacentAffineMapIndices(idx1, idx2))
186+
return true;
187+
164188
return false;
165189
}
166190

mlir/test/Dialect/Vector/slp-vectorize.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,37 @@ func.func @read_write_add_index_interleaved(%arg0: memref<8xi32>, %arg1: memref<
143143
}
144144

145145

146+
#map0 = affine_map<()[s0, s1] -> (s1 * s0)>
147+
#map1 = affine_map<()[s0, s1] -> (s1 * s0 + 1)>
148+
#map2 = affine_map<()[s0, s1] -> (s1 * s0 + 2)>
149+
#map3 = affine_map<()[s0, s1] -> (s1 * s0 + 3)>
150+
151+
// CHECK-LABEL: func @read_write_affine_apply
152+
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
153+
func.func @read_write_affine_apply(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2: index, %arg3: index) {
154+
// CHECK: %[[IDX:.*]] = affine.apply #{{.*}}()[%[[ARG2]], %[[ARG3]]]
155+
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[IDX]]] : memref<8xi32>, vector<4xi32>
156+
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[IDX]]] : memref<8xi32>, vector<4xi32>
157+
158+
%ind0 = affine.apply #map0()[%arg2, %arg3]
159+
%ind1 = affine.apply #map1()[%arg2, %arg3]
160+
%ind2 = affine.apply #map2()[%arg2, %arg3]
161+
%ind3 = affine.apply #map3()[%arg2, %arg3]
162+
163+
%0 = memref.load %arg0[%ind0] : memref<8xi32>
164+
%1 = memref.load %arg0[%ind1] : memref<8xi32>
165+
%2 = memref.load %arg0[%ind2] : memref<8xi32>
166+
%3 = memref.load %arg0[%ind3] : memref<8xi32>
167+
168+
memref.store %0, %arg0[%ind0] : memref<8xi32>
169+
memref.store %1, %arg0[%ind1] : memref<8xi32>
170+
memref.store %2, %arg0[%ind2] : memref<8xi32>
171+
memref.store %3, %arg0[%ind3] : memref<8xi32>
172+
173+
return
174+
}
175+
176+
146177
// CHECK-LABEL: func @read_read_add
147178
// CHECK-SAME: (%[[ARG0:.*]]: memref<8xi32>, %[[ARG1:.*]]: memref<8xi32>)
148179
func.func @read_read_add(%arg0: memref<8xi32>, %arg1: memref<8xi32>) -> (i32, i32, i32, i32) {

0 commit comments

Comments
 (0)