Commit 02e5678

[CodeGen] Improve scf.for bufferization and make hoisting allocation work
The revision enables `allowReturnAllocsFromLoops` in bufferization, which matches the upstream behavior. Otherwise, it can trigger an error like:

```
error: Yield operand #1 is not equivalent to the corresponding iter bbArg
```

In this context, a `memref.alloca` can be created inside the loop and its dynamic size can be queried from the iter_arg. The ValueBoundsConstraintSet check does not support this analysis, because the runtime dimension values can still differ. E.g.:

```mlir
%result = scf.for ... iter_args(%iter = %init) -> (memref<?xf32>) {
  %new_buf = memref.alloca(%some_other_size) : memref<?xf32>
  scf.yield %new_buf : memref<?xf32>  // same type, different runtime size
}
```

It is weird, but it is allowed. Thus, we need to handle such cases in `hoistOneStaticallyBoundAllocation`. The revision verifies that the dimension is preserved via one of:

1. The yield operand (after walking through cast/subview) is the iter_arg.
2. The yield operand traces to an alloca whose shape matches the iter_arg and
   whose dynamic size at `dimIndex` is `memref.dim` of the iter_arg.
3. The yield operand is a scf.for result whose init arg is the iter_arg and the
   inner loop also preserves the dimension (recursive).

Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent 5ee0652 commit 02e5678
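To make the effect concrete, the following is a condensed before/after sketch of the hoisting this enables, hand-simplified from the new `@hoist_alloca_yield_iter_arg` test added below; `%init`, `%n`, `%c0`, `%c1`, and `%cst` are assumed to be defined as in that test, with `%init` a cast of a subview of a static `memref<1x4xf32>`:

```mlir
// Before hoisting: the alloca inside the loop is sized by a dimension of the
// iter_arg, which plain value-bounds analysis cannot bound across iterations.
%result = scf.for %i = %c0 to %n step %c1
    iter_args(%iter = %init) -> (memref<1x?xf32, strided<[?, ?], offset: ?>>) {
  %dim = memref.dim %iter, %c1 : memref<1x?xf32, strided<[?, ?], offset: ?>>
  %alloca = memref.alloca(%dim) : memref<1x?xf32>
  linalg.fill ins(%cst : f32) outs(%alloca : memref<1x?xf32>)
  scf.yield %iter : memref<1x?xf32, strided<[?, ?], offset: ?>>
}

// After hoisting: once the yield is proven to preserve the dimension, the
// bound is computed from the init value (a memref<1x4xf32> here), a static
// alloca is hoisted to the function entry, and the loop body takes a subview.
%hoisted = memref.alloca() : memref<1x4xf32>
%result = scf.for %i = %c0 to %n step %c1
    iter_args(%iter = %init) -> (memref<1x?xf32, strided<[?, ?], offset: ?>>) {
  %dim = memref.dim %iter, %c1 : memref<1x?xf32, strided<[?, ?], offset: ?>>
  %sv = memref.subview %hoisted[0, 0] [1, %dim] [1, 1]
      : memref<1x4xf32> to memref<1x?xf32, strided<[4, 1]>>
  linalg.fill ins(%cst : f32) outs(%sv : memref<1x?xf32, strided<[4, 1]>>)
  scf.yield %iter : memref<1x?xf32, strided<[?, ?], offset: ?>>
}
```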

File tree

4 files changed: +328 additions, -15 deletions


compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp

Lines changed: 7 additions & 0 deletions
```diff
@@ -151,6 +151,13 @@ static IREEOneShotBufferizationOptions getBufferizationOptions() {
   // as is and insert bufferization.to_buffer to convert the tensor to memref.
   options.opFilter.denyOperation<arith::ConstantOp>();
 
+  // Allow returning allocs from loops. This is needed for patterns like online
+  // attention where scf.for yield operands cannot be buffer-equivalent to their
+  // corresponding iter bbArgs (e.g., the new max value is computed from both
+  // the old max and new data). This matches MLIR upstream's
+  // EmptyTensorElimination pass behavior.
+  options.allowReturnAllocsFromLoops = true;
+
   // This type converter converts tensor types to memref types when no exact
   // memref type can be inferred from the context.
   options.unknownTypeConverterFn = [](TensorType tensorType,
```
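For context, the shape of IR this option affects looks roughly like the following hand-written sketch (not copied from this commit); the `"compute"`/`"consume"` ops are placeholders standing in for the `linalg.generic` ops used in the regression test added to `iree_comprehensive_bufferize.mlir` below, and the constants, `%fill`, and `%init` are assumed to be defined as in that test:

```mlir
// Tensor form: the yielded %new is not buffer-equivalent to the iter bbArg
// %arg, because %arg is still read after %new has been produced.
%0 = scf.for %iv = %c0 to %c4 step %c1 iter_args(%arg = %fill) -> (tensor<4xf32>) {
  %new = "compute"(%arg) : (tensor<4xf32>) -> tensor<4xf32>
  %use = "consume"(%new, %arg) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  scf.yield %new : tensor<4xf32>
}

// Buffer form permitted by allowReturnAllocsFromLoops: instead of failing with
// the "not equivalent to the corresponding iter bbArg" error, bufferization
// yields a fresh allocation from the loop body.
%init = memref.alloc() : memref<4xf32>
%res = scf.for %iv = %c0 to %c4 step %c1 iter_args(%arg = %init) -> (memref<4xf32>) {
  %new = memref.alloc() : memref<4xf32>
  "compute_into"(%arg, %new) : (memref<4xf32>, memref<4xf32>) -> ()
  "consume_into"(%new, %arg) : (memref<4xf32>, memref<4xf32>) -> ()
  scf.yield %new : memref<4xf32>
}
```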

compiler/src/iree/compiler/Codegen/Common/test/hoist_statically_bound_allocations.mlir

Lines changed: 137 additions & 0 deletions
```diff
@@ -249,3 +249,140 @@ func.func @nested_op_scalable_alloc_linalg_use(%arg0 : index) {
 // CHECK-UNBOUNDED-VSCALE-LABEL: func @nested_op_scalable_alloc_linalg_use(
 // CHECK-UNBOUNDED-VSCALE: scf.for
 // CHECK-UNBOUNDED-VSCALE: memref.alloc
+
+// -----
+
+// The yield is the iter_arg itself — dimension trivially preserved. The
+// alloca's size comes from memref.dim of the iter_arg, and
+// computeAllocationBound traces through the loop to the init value.
+func.func @hoist_alloca_yield_iter_arg(%arg0 : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %static = memref.alloca() : memref<1x4xf32>
+  %sv = memref.subview %static[0, 0][1, %arg0][1, 1]
+      : memref<1x4xf32> to memref<1x?xf32, strided<[4, 1]>>
+  %init = memref.cast %sv
+      : memref<1x?xf32, strided<[4, 1]>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
+  %result = scf.for %i = %c0 to %arg0 step %c1
+      iter_args(%iter = %init) -> (memref<1x?xf32, strided<[?, ?], offset: ?>>) {
+    %dim = memref.dim %iter, %c1 : memref<1x?xf32, strided<[?, ?], offset: ?>>
+    %alloca = memref.alloca(%dim) : memref<1x?xf32>
+    linalg.fill ins(%cst : f32) outs(%alloca : memref<1x?xf32>)
+    scf.yield %iter : memref<1x?xf32, strided<[?, ?], offset: ?>>
+  }
+  return
+}
+// CHECK-LABEL: func @hoist_alloca_yield_iter_arg(
+// CHECK: %[[HOISTED:.+]] = memref.alloca() : memref<1x4xf32>
+// CHECK: scf.for
+// CHECK-NOT: memref.alloca(
+// CHECK: %[[DIM:.+]] = memref.dim
+// CHECK: %[[SV:.+]] = memref.subview %[[HOISTED]][0, 0] [1, %[[DIM]]] [1, 1]
+// CHECK: linalg.fill
+
+// -----
+
+// The yield traces through cast and subview to an alloca whose dynamic size at
+// dimIndex is memref.dim of the iter_arg (self-referential). This exercises the
+// cast/subview walk in the function.
+func.func @hoist_alloca_yield_self_ref_subview(%arg0 : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %static = memref.alloca() : memref<1x4xf32>
+  %sv = memref.subview %static[0, 0][1, %arg0][1, 1]
+      : memref<1x4xf32> to memref<1x?xf32, strided<[4, 1]>>
+  %init = memref.cast %sv
+      : memref<1x?xf32, strided<[4, 1]>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
+  %result = scf.for %i = %c0 to %arg0 step %c1
+      iter_args(%iter = %init) -> (memref<1x?xf32, strided<[?, ?], offset: ?>>) {
+    %dim = memref.dim %iter, %c1 : memref<1x?xf32, strided<[?, ?], offset: ?>>
+    %alloca = memref.alloca(%dim) : memref<1x?xf32>
+    linalg.fill ins(%cst : f32) outs(%alloca : memref<1x?xf32>)
+    %val = memref.load %alloca[%c0, %c0] : memref<1x?xf32>
+    // Yield traces: cast → subview → alloca (exercises the walk loop).
+    %sub = memref.subview %alloca[0, 0][1, %dim][1, 1]
+        : memref<1x?xf32> to memref<1x?xf32, strided<[?, 1]>>
+    %cast = memref.cast %sub
+        : memref<1x?xf32, strided<[?, 1]>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
+    scf.yield %cast : memref<1x?xf32, strided<[?, ?], offset: ?>>
+  }
+  return
+}
+// CHECK-LABEL: func @hoist_alloca_yield_self_ref_subview(
+// CHECK: %[[HOISTED:.+]] = memref.alloca() : memref<1x4xf32>
+// CHECK: scf.for
+// CHECK-NOT: memref.alloca(
+// CHECK: %[[DIM:.+]] = memref.dim
+// CHECK: %[[SV:.+]] = memref.subview %[[HOISTED]][0, 0] [1, %[[DIM]]] [1, 1]
+// CHECK: linalg.fill
+// CHECK: memref.load
+
+// -----
+
+// The yield is an inner scf.for result. The inner loop preserves the dimension
+// via the case that yield is iter_arg, and the recursive check verifies the
+// inner loop.
+func.func @hoist_alloca_yield_nested_loop(%arg0 : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %static = memref.alloca() : memref<1x4xf32>
+  %sv = memref.subview %static[0, 0][1, %arg0][1, 1]
+      : memref<1x4xf32> to memref<1x?xf32, strided<[4, 1]>>
+  %init = memref.cast %sv
+      : memref<1x?xf32, strided<[4, 1]>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
+  %result = scf.for %i = %c0 to %arg0 step %c1
+      iter_args(%outer_iter = %init) -> (memref<1x?xf32, strided<[?, ?], offset: ?>>) {
+    %inner = scf.for %j = %c0 to %arg0 step %c1
+        iter_args(%inner_iter = %outer_iter) -> (memref<1x?xf32, strided<[?, ?], offset: ?>>) {
+      %dim = memref.dim %inner_iter, %c1 : memref<1x?xf32, strided<[?, ?], offset: ?>>
+      %alloca = memref.alloca(%dim) : memref<1x?xf32>
+      linalg.fill ins(%cst : f32) outs(%alloca : memref<1x?xf32>)
+      scf.yield %inner_iter : memref<1x?xf32, strided<[?, ?], offset: ?>>
+    }
+    scf.yield %inner : memref<1x?xf32, strided<[?, ?], offset: ?>>
+  }
+  return
+}
+// CHECK-LABEL: func @hoist_alloca_yield_nested_loop(
+// CHECK: %[[HOISTED:.+]] = memref.alloca() : memref<1x4xf32>
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK-NOT: memref.alloca(
+// CHECK: %[[DIM:.+]] = memref.dim
+// CHECK: %[[SV:.+]] = memref.subview %[[HOISTED]][0, 0] [1, %[[DIM]]] [1, 1]
+// CHECK: linalg.fill
+
+// -----
+
+// Negative test: the yield uses an alloca sized by a different value (%arg1)
+// rather than the iter_arg's dimension, so the dimension is not preserved
+// across iterations. The alloca should NOT be hoisted.
+func.func @no_hoist_alloca_yield_dim_not_preserved(%arg0 : index, %arg1 : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %static = memref.alloca() : memref<1x4xf32>
+  %sv = memref.subview %static[0, 0][1, %arg0][1, 1]
+      : memref<1x4xf32> to memref<1x?xf32, strided<[4, 1]>>
+  %init = memref.cast %sv
+      : memref<1x?xf32, strided<[4, 1]>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
+  %result = scf.for %iv = %c0 to %arg0 step %c1
+      iter_args(%iter = %init) -> (memref<1x?xf32, strided<[?, ?], offset: ?>>) {
+    %dim = memref.dim %iter, %c1 : memref<1x?xf32, strided<[?, ?], offset: ?>>
+    %inner = memref.alloca(%dim) : memref<1x?xf32>
+    linalg.fill ins(%cst : f32) outs(%inner : memref<1x?xf32>)
+    // Yield an alloca with a different size — dimension not preserved.
+    %other = memref.alloca(%arg1) : memref<1x?xf32>
+    %cast = memref.cast %other
+        : memref<1x?xf32> to memref<1x?xf32, strided<[?, ?], offset: ?>>
+    scf.yield %cast : memref<1x?xf32, strided<[?, ?], offset: ?>>
+  }
+  return
+}
+// CHECK-LABEL: func @no_hoist_alloca_yield_dim_not_preserved(
+// CHECK: scf.for
+// CHECK: %[[DIM:.+]] = memref.dim
+// CHECK: memref.alloca(%[[DIM]]) : memref<1x?xf32>
```

compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir

Lines changed: 42 additions & 0 deletions
```diff
@@ -3220,3 +3220,45 @@ func.func @drop_fusion_barrier() -> memref<6xf32> {
 // CHECK-LABEL: func.func @drop_fusion_barrier
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xf32>
 // CHECK: return %[[ALLOC]]
+
+// -----
+
+// Regression test for https://github.com/iree-org/iree/issues/16956.
+// The yield operand %new is not buffer-equivalent to the iter bbArg %arg
+// because %arg is read after %new is computed. With allowReturnAllocsFromLoops,
+// bufferization allocates a new buffer inside the loop instead of failing.
+func.func @bufferize_non_equivalent_scf_yield() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<4xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4xf32>) -> tensor<4xf32>
+  %0 = scf.for %iv = %c0 to %c4 step %c1
+      iter_args(%arg = %fill) -> (tensor<4xf32>) {
+    %new = linalg.generic {
+        indexing_maps = [affine_map<(d0) -> (d0)>],
+        iterator_types = ["parallel"]}
+        outs(%arg : tensor<4xf32>) {
+    ^bb0(%out: f32):
+      %v = arith.addf %out, %cst : f32
+      linalg.yield %v : f32
+    } -> tensor<4xf32>
+    // Reading %arg after %new forces non-equivalence.
+    %use = linalg.generic {
+        indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+        iterator_types = ["parallel"]}
+        ins(%new : tensor<4xf32>) outs(%arg : tensor<4xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %s = arith.subf %out, %in : f32
+      linalg.yield %s : f32
+    } -> tensor<4xf32>
+    scf.yield %new : tensor<4xf32>
+  }
+  return
+}
+// CHECK-LABEL: func.func @bufferize_non_equivalent_scf_yield
+// CHECK: %[[INIT:.+]] = memref.alloc() : memref<4xf32>
+// CHECK: scf.for {{.*}} iter_args(%[[ARG:.+]] = %[[INIT]])
+// CHECK: %[[NEW:.+]] = memref.alloc() : memref<4xf32>
+// CHECK: scf.yield %[[NEW]]
```

compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp

Lines changed: 142 additions & 15 deletions
```diff
@@ -30,6 +30,7 @@
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
@@ -118,6 +119,81 @@ cloneOffsetsSizesAndStrides(OpBuilder &builder,
       loadOp.getMixedSizes(), loadOp.getMixedStrides(), loadOp.getSourceDims());
 }
 
+/// Returns true if the yield operand for the `argIdx`-th iter_arg of `forOp`
+/// preserves the dimension at `dimIndex`. This is needed to verify that
+/// computing an allocation bound from the init value is sound — if the yield
+/// could produce a larger dimension, the init-derived bound would be too small.
+/// Verify the dimension is preserved, via:
+///
+/// (1) The yield operand (after walking through cast/subview) is the iter_arg.
+/// (2) The yield operand traces to an alloca whose shape matches the iter_arg
+///     and whose dynamic size at `dimIndex` is `memref.dim` of the iter_arg.
+/// (3) The yield operand is a scf.for result whose init arg is the iter_arg
+///     and the inner loop also preserves the dimension (recursive).
+///
+/// Note: This may revisit inner loops when called at each nesting level during
+/// the source walk in computeAllocationBound. Caching would help if the nesting
+/// depth were large, but in practice it is bounded by the tensor rank.
+static bool isYieldDimPreserved(scf::ForOp forOp, unsigned argIdx,
+                                int64_t dimIndex) {
+  BlockArgument iterArg = forOp.getRegionIterArg(argIdx);
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+  Value yieldVal = yieldOp.getOperand(argIdx);
+
+  // Walk through cast/subview to find the underlying source.
+  while (true) {
+    if (auto castOp = yieldVal.getDefiningOp<memref::CastOp>()) {
+      yieldVal = castOp.getSource();
+      continue;
+    }
+    if (auto subviewOp = yieldVal.getDefiningOp<memref::SubViewOp>()) {
+      yieldVal = subviewOp.getSource();
+      continue;
+    }
+    break;
+  }
+
+  // Case 1: Yield is the iter_arg itself — dimension trivially invariant.
+  if (yieldVal == iterArg) {
+    return true;
+  }
+
+  // Case 2: Yield traces to an alloca whose dynamic size at dimIndex comes
+  // from memref.dim of the same iter_arg (self-referential invariance).
+  if (auto allocaOp = yieldVal.getDefiningOp<memref::AllocaOp>()) {
+    MemRefType allocType = allocaOp.getType();
+    auto iterArgType = cast<MemRefType>(iterArg.getType());
+    // Shape comparison ensures same rank and same static/dynamic pattern,
+    // so we can directly index the dynamic sizes by counting dynamic dims
+    // before dimIndex.
+    if (allocType.getShape() != iterArgType.getShape()) {
+      return false;
+    }
+    unsigned dynIdx = llvm::count_if(allocType.getShape().take_front(dimIndex),
+                                     ShapedType::isDynamic);
+    auto dimOp =
+        allocaOp.getDynamicSizes()[dynIdx].getDefiningOp<memref::DimOp>();
+    if (!dimOp || dimOp.getSource() != iterArg) {
+      return false;
+    }
+    auto idx = dimOp.getConstantIndex();
+    return idx && *idx == dimIndex;
+  }
+
+  // Case 3: Yield is a scf.for result whose init arg at the same index is
+  // the iter_arg, and the inner loop also preserves the dimension.
+  if (auto result = dyn_cast<OpResult>(yieldVal)) {
+    if (auto innerFor = dyn_cast<scf::ForOp>(result.getOwner())) {
+      unsigned resultIdx = result.getResultNumber();
+      if (innerFor.getInitArgs()[resultIdx] == iterArg) {
+        return isYieldDimPreserved(innerFor, resultIdx, dimIndex);
+      }
+    }
+  }
+
+  return false;
+}
+
 template <typename AllocLikeOpType>
 std::optional<Value> hoistOneStaticallyBoundAllocation(
     mlir::FunctionOpInterface funcOp, OpBuilder &builder, Location loc,
@@ -156,30 +232,81 @@ std::optional<Value> hoistOneStaticallyBoundAllocation(
           vector::ScalableValueBoundsConstraintSet::computeScalableBound(
               value, std::nullopt, vscaleRange->vscaleMin,
              vscaleRange->vscaleMax, presburger::BoundType::UB);
-      if (failed(ub)) {
-        return failure();
-      }
+      if (succeeded(ub)) {
+        if (ub->map.isSingleConstant()) {
+          auto constantBound = ub->map.getSingleConstantResult();
+          return OpFoldResult(builder.getIndexAttr(constantBound));
+        }
 
-      if (ub->map.isSingleConstant()) {
-        auto constantBound = ub->map.getSingleConstantResult();
-        return OpFoldResult(builder.getIndexAttr(constantBound));
+        if (!vscale) {
+          vscale = vector::VectorScaleOp::create(builder, loc);
+        }
+        return affine::materializeComputedBound(
+            builder, loc, ub->map, {std::make_pair(vscale, std::nullopt)});
       }
+    } else {
+      // Non-scalable target: Assume everything is fixed-size.
+      auto ub = ValueBoundsConstraintSet::computeConstantBound(
+          presburger::BoundType::UB, {value, std::nullopt},
+          /*stopCondition=*/nullptr,
+          /*closedUB=*/true);
+      if (succeeded(ub)) {
+        return OpFoldResult(builder.getIndexAttr(*ub));
+      }
+    }
 
-      if (!vscale) {
-        vscale = vector::VectorScaleOp::create(builder, loc);
+    // Special case for memref.dim. If the value is a memref.dim on a loop
+    // iter_arg, try computing the bound using the init value's dimension. This
+    // handles cases where bufferization creates loop-internal allocas with
+    // sizes derived from iter_arg dimensions (e.g., from issue #16956,
+    // allowReturnAllocsFromLoops, etc). The value bounds analysis cannot trace
+    // through scf.for iter_args, so we walk up to the outermost init value and
+    // compute the bound from there.
+    auto dimOp = value.getDefiningOp<memref::DimOp>();
+    if (!dimOp) {
+      return failure();
+    }
+    std::optional<int64_t> constIndex = dimOp.getConstantIndex();
+    if (!constIndex) {
+      return failure();
+    }
+
+    // Walk up through nested loop iter_args, casts, and subviews to find a
+    // value whose dimension bound can be computed.
+    Value source = dimOp.getSource();
+    while (true) {
+      if (auto blockArg = dyn_cast<BlockArgument>(source)) {
+        auto forOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
+        if (!forOp) {
+          break;
+        }
+        unsigned argIdx = blockArg.getArgNumber() - forOp.getNumInductionVars();
+        if (isYieldDimPreserved(forOp, argIdx, *constIndex)) {
+          source = forOp.getInitArgs()[argIdx];
+          continue;
+        }
       }
-      return affine::materializeComputedBound(
-          builder, loc, ub->map, {std::make_pair(vscale, std::nullopt)});
+      if (auto castOp = source.getDefiningOp<memref::CastOp>()) {
+        source = castOp.getSource();
+        continue;
+      }
+      if (auto subviewOp = source.getDefiningOp<memref::SubViewOp>()) {
+        source = subviewOp.getSource();
+        continue;
+      }
+      break;
+    }
+    if (source == dimOp.getSource()) {
+      return failure();
     }
-    // Non-scalable target: Assume everything is fixed-size.
+
     auto ub = ValueBoundsConstraintSet::computeConstantBound(
-        presburger::BoundType::UB, {value, std::nullopt},
+        presburger::BoundType::UB, {source, *constIndex},
         /*stopCondition=*/nullptr,
         /*closedUB=*/true);
     if (failed(ub)) {
      return failure();
    }
-
    return OpFoldResult(builder.getIndexAttr(*ub));
  };
 
@@ -264,8 +391,8 @@ std::optional<Value> hoistOneStaticallyBoundAllocation(
 /// non-trivial because of compatibility between types of different SSA values.
 static bool isUseReplaceableWithSubview(OpOperand &use) {
   Operation *user = use.getOwner();
-  return isa<linalg::LinalgOp, memref::DeallocOp, memref::StoreOp,
-             memref::SubViewOp>(user);
+  return isa<linalg::LinalgOp, memref::CastOp, memref::DeallocOp,
+             memref::LoadOp, memref::StoreOp, memref::SubViewOp>(user);
 }
 
 template <typename AllocLikeOpType>
```
