Skip to content

Commit b3909f4

Browse files
authored
[MLIR] Drop assumption of a surrounding builtin.func in promoteIfSingleIteration (#116323)
Drop assumption of a surrounding builtin.func in promoteIfSingleIteration. Fixes #116042
1 parent 7a56dc7 commit b3909f4

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,13 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
129129
auto *parentBlock = forOp->getBlock();
130130
if (!iv.use_empty()) {
131131
if (forOp.hasConstantLowerBound()) {
132-
OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
133-
auto constOp = topBuilder.create<arith::ConstantIndexOp>(
132+
auto func = forOp->getParentOfType<FunctionOpInterface>();
133+
OpBuilder builder(forOp->getContext());
134+
if (func)
135+
builder.setInsertionPointToStart(&func.getFunctionBody().front());
136+
else
137+
builder.setInsertionPoint(forOp);
138+
auto constOp = builder.create<arith::ConstantIndexOp>(
134139
forOp.getLoc(), forOp.getConstantLowerBound());
135140
iv.replaceAllUsesWith(constOp);
136141
} else {

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
22
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
3+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV
34

45
// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
56
// Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir
@@ -226,3 +227,61 @@ func.func @fuse_higher_dim_nest_into_lower_dim_nest() {
226227
// PRODUCER-CONSUMER: return
227228
return
228229
}
230+
231+
// -----
232+
233+
// Basic test to ensure fusion works inside other func ops like spirv.func.
234+
235+
#map = affine_map<(d0, d1) -> (d0 + d1)>
236+
module {
237+
// SPIRV-LABEL: func @test_avgpool2d_pad_right
238+
spirv.func @test_avgpool2d_pad_right(%arg0: !spirv.array<8192 x f32>) -> !spirv.array<8192 x f32> "None" {
239+
%cst_f32 = spirv.Constant 0.000000e+00 : f32
240+
%0 = builtin.unrealized_conversion_cast %arg0 : !spirv.array<8192 x f32> to tensor<1x32x32x8xf32>
241+
%padded = tensor.pad %0 low[0, 4, 4, 0] high[0, 4, 8193, 0] {
242+
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
243+
tensor.yield %cst_f32 : f32
244+
} : tensor<1x32x32x8xf32> to tensor<1x40x8229x8xf32>
245+
%1 = bufferization.to_memref %padded : memref<1x40x8229x8xf32>
246+
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x32x32x8xf32>
247+
affine.for %arg1 = 0 to 1 {
248+
affine.for %arg2 = 0 to 32 {
249+
affine.for %arg3 = 0 to 32 {
250+
affine.for %arg4 = 0 to 8 {
251+
affine.for %arg5 = 0 to 1 {
252+
affine.for %arg6 = 0 to 1 {
253+
%4 = affine.apply #map(%arg2, %arg5)
254+
%5 = affine.apply #map(%arg3, %arg6)
255+
%6 = affine.load %1[%arg1, %4, %5, %arg4] : memref<1x40x8229x8xf32>
256+
%7 = affine.load %alloc_0[%arg1, %arg2, %arg3, %arg4] : memref<1x32x32x8xf32>
257+
%8 = arith.addf %7, %6 : f32
258+
affine.store %8, %alloc_0[%arg1, %arg2, %arg3, %arg4] : memref<1x32x32x8xf32>
259+
}
260+
}
261+
}
262+
}
263+
}
264+
}
265+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x32x32x8xf32>
266+
affine.for %arg1 = 0 to 1 {
267+
affine.for %arg2 = 0 to 32 {
268+
affine.for %arg3 = 0 to 32 {
269+
affine.for %arg4 = 0 to 8 {
270+
%4 = affine.load %alloc_0[%arg1, %arg2, %arg3, %arg4] : memref<1x32x32x8xf32>
271+
}
272+
}
273+
}
274+
}
275+
// Test fusion.
276+
// SPIRV: affine.for %{{.*}} = 0 to 1 {
277+
// SPIRV-NEXT: affine.for %{{.*}} = 0 to 32 {
278+
// SPIRV-NEXT: affine.for %{{.*}} = 0 to 32 {
279+
// SPIRV-NEXT: affine.for %{{.*}} = 0 to 8 {
280+
// SPIRV-NOT: affine.for %{{.*}}
281+
282+
// SPIRV: ReturnValue
283+
%2 = bufferization.to_tensor %alloc_1 : memref<1x32x32x8xf32>
284+
%3 = builtin.unrealized_conversion_cast %2 : tensor<1x32x32x8xf32> to !spirv.array<8192 x f32>
285+
spirv.ReturnValue %3 : !spirv.array<8192 x f32>
286+
}
287+
}

0 commit comments

Comments
 (0)