Skip to content

Commit b225b34

Browse files
committed
Address comments and add negative scalar unit dim test
1 parent 0b1a2e8 commit b225b34

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCF.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
111111
/// `replacement` value with a different type. A callback is used to insert
112112
/// cast ops inside the block to account for type differences.
113113
using ValueTypeCastFnTy =
114-
std::function<Value(OpBuilder &, Location loc, Type, Value)>;
114+
llvm::function_ref<Value(OpBuilder &, Location loc, Type, Value)>;
115115
SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter,
116116
scf::ForOp forOp,
117117
OpOperand &operand,

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,8 +1837,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
18371837
continue;
18381838

18391839
// Create a new ForOp with that iter operand replaced.
1840-
mlir::scf::ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc,
1841-
Type type, Value source) {
1840+
auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
18421841
return b.create<vector::ShapeCastOp>(loc, type, source);
18431842
};
18441843

mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,25 @@ func.func @scf_for_with_multiple_operands(%idx: index, %vec0: vector<1x4xf32>, %
282282
// CHECK: scf.yield %{{.*}}, %[[ADD]], %[[ADD]]
283283
// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]]#1 : vector<4xf32> to vector<1x4xf32>
284284
// CHECK: return %[[CASTBACK]]
285+
286+
// -----
287+
288+
func.func @scf_for_with_scalable_unit_dims(%vec: vector<1x[1]xf32>) -> vector<1x[1]xf32> {
289+
%c0 = arith.constant 0 : index
290+
%c1 = arith.constant 1 : index
291+
%c4 = arith.constant 4 : index
292+
%res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<1x[1]xf32> {
293+
%s = math.sqrt %iter : vector<1x[1]xf32>
294+
scf.yield %s : vector<1x[1]xf32>
295+
}
296+
return %res : vector<1x[1]xf32>
297+
}
298+
299+
// CHECK-LABEL: func.func @scf_for_with_scalable_unit_dims
300+
// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<1x[1]xf32>
301+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x[1]xf32> to vector<[1]xf32>
302+
// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
303+
// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<[1]xf32>
304+
// CHECK: scf.yield %[[SQRT]]
305+
// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<[1]xf32> to vector<1x[1]xf32>
306+
// CHECK: return %[[CASTBACK]]

0 commit comments

Comments
 (0)