diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index d6fc4ed07bfab..827adfe892969 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -129,8 +129,15 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
   auto *parentBlock = forOp->getBlock();
   if (!iv.use_empty()) {
     if (forOp.hasConstantLowerBound()) {
-      OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
-      auto constOp = topBuilder.create<arith::ConstantIndexOp>(
+      // Create the constant at the start of the enclosing function, if any;
+      // otherwise, create it right before the loop.
+      auto func = forOp->getParentOfType<FunctionOpInterface>();
+      OpBuilder builder(forOp->getContext());
+      if (func)
+        builder.setInsertionPointToStart(&func.getFunctionBody().front());
+      else
+        builder.setInsertionPoint(forOp);
+      auto constOp = builder.create<arith::ConstantIndexOp>(
           forOp.getLoc(), forOp.getConstantLowerBound());
       iv.replaceAllUsesWith(constOp);
     } else {
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 3fc31ad0d77b8..f46ad0f5e4c23 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV
 
 // Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
 // Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir
@@ -226,3 +227,61 @@ func.func @fuse_higher_dim_nest_into_lower_dim_nest() {
   // PRODUCER-CONSUMER: return
   return
 }
+
+// -----
+
+// Basic test to ensure fusion works inside other func ops like spirv.func.
+
+#map = affine_map<(d0, d1) -> (d0 + d1)>
+module {
+  // SPIRV-LABEL: func @test_avgpool2d_pad_right
+  spirv.func @test_avgpool2d_pad_right(%arg0: !spirv.array<8192 x f32>) -> !spirv.array<8192 x f32> "None" {
+    %cst_f32 = spirv.Constant 0.000000e+00 : f32
+    %0 = builtin.unrealized_conversion_cast %arg0 : !spirv.array<8192 x f32> to tensor<1x32x32x8xf32>
+    %padded = tensor.pad %0 low[0, 4, 4, 0] high[0, 4, 8193, 0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+      tensor.yield %cst_f32 : f32
+    } : tensor<1x32x32x8xf32> to tensor<1x40x8229x8xf32>
+    %1 = bufferization.to_memref %padded : memref<1x40x8229x8xf32>
+    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x32x32x8xf32>
+    affine.for %arg1 = 0 to 1 {
+      affine.for %arg2 = 0 to 32 {
+        affine.for %arg3 = 0 to 32 {
+          affine.for %arg4 = 0 to 8 {
+            affine.for %arg5 = 0 to 1 {
+              affine.for %arg6 = 0 to 1 {
+                %4 = affine.apply #map(%arg2, %arg5)
+                %5 = affine.apply #map(%arg3, %arg6)
+                %6 = affine.load %1[%arg1, %4, %5, %arg4] : memref<1x40x8229x8xf32>
+                %7 = affine.load %alloc_0[%arg1, %arg2, %arg3, %arg4] : memref<1x32x32x8xf32>
+                %8 = arith.addf %7, %6 : f32
+                affine.store %8, %alloc_0[%arg1, %arg2, %arg3, %arg4] : memref<1x32x32x8xf32>
+              }
+            }
+          }
+        }
+      }
+    }
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x32x32x8xf32>
+    affine.for %arg1 = 0 to 1 {
+      affine.for %arg2 = 0 to 32 {
+        affine.for %arg3 = 0 to 32 {
+          affine.for %arg4 = 0 to 8 {
+            %4 = affine.load %alloc_0[%arg1, %arg2, %arg3, %arg4] : memref<1x32x32x8xf32>
+          }
+        }
+      }
+    }
+    // Test fusion.
+    // SPIRV: affine.for %{{.*}} = 0 to 1 {
+    // SPIRV-NEXT: affine.for %{{.*}} = 0 to 32 {
+    // SPIRV-NEXT: affine.for %{{.*}} = 0 to 32 {
+    // SPIRV-NEXT: affine.for %{{.*}} = 0 to 8 {
+    // SPIRV-NOT: affine.for %{{.*}}
+
+    // SPIRV: ReturnValue
+    %2 = bufferization.to_tensor %alloc_1 : memref<1x32x32x8xf32>
+    %3 = builtin.unrealized_conversion_cast %2 : tensor<1x32x32x8xf32> to !spirv.array<8192 x f32>
+    spirv.ReturnValue %3 : !spirv.array<8192 x f32>
+  }
+}