diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index 6db51f4b84d11..2a1445fb92fdc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -78,9 +78,12 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
       for (auto idx : indices)
         indicesValues.emplace_back(
             rewriter.create<arith::ConstantIndexOp>(loc, idx));
-      Value extractedValue = rewriter.create<tensor::ExtractOp>(
-          loc, opOperand->get(), indicesValues);
-      body->getArgument(idx).replaceAllUsesWith(extractedValue);
+      Value scalarValue = opOperand->get();
+      if (isa<RankedTensorType>(scalarValue.getType())) {
+        scalarValue =
+            rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues);
+      }
+      body->getArgument(idx).replaceAllUsesWith(scalarValue);
       body->eraseArgument(idx);
     }
 
diff --git a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
index 93d5b8779c746..8384b307d2dfb 100644
--- a/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
+++ b/mlir/test/Dialect/Linalg/inline-scalar-operands.mlir
@@ -46,3 +46,27 @@ func.func @inline_oned(%arg0: tensor<4xf32>, %scalar: tensor<1xf32>) -> tensor<4
     } -> tensor<4xf32>
     return %1 : tensor<4xf32>
 }
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+#map3 = affine_map<(d0) -> ()>
+
+// CHECK: func @inline_scalar(%[[ARG:.*]]: tensor<4xf32>, %[[SCALAR:.*]]: f32)
+func.func @inline_scalar(%arg0: tensor<4xf32>, %scalar: f32) -> tensor<4xf32> {
+    %0 = tensor.empty() : tensor<4xf32>
+    // CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]],
+    // CHECK-SAME: iterator_types = ["parallel"]} ins(%[[ARG]] : tensor<4xf32>)
+    %1 = linalg.generic {indexing_maps = [#map2, #map3, #map2],
+                         iterator_types = ["parallel"]}
+                         ins(%arg0, %scalar : tensor<4xf32>, f32)
+                         outs(%0 : tensor<4xf32>) {
+    // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32)
+    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+      // CHECK: arith.divf %[[IN]], %[[SCALAR]] : f32
+      %2 = arith.divf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    } -> tensor<4xf32>
+    return %1 : tensor<4xf32>
+}
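
For reference, a sketch of the IR that the new @inline_scalar test expects after the inline-scalar-operands rewrite runs, assembled from the CHECK directives above; SSA names and formatting are illustrative, not actual pass output. The f32 operand is dropped from ins, its block argument is erased, and the body uses the scalar value directly with no tensor.extract.

#map = affine_map<(d0) -> (d0)>
func.func @inline_scalar(%arg0: tensor<4xf32>, %scalar: f32) -> tensor<4xf32> {
  %0 = tensor.empty() : tensor<4xf32>
  // The scalar operand is gone from ins(); only the tensor input remains.
  %1 = linalg.generic {indexing_maps = [#map, #map],
                       iterator_types = ["parallel"]}
      ins(%arg0 : tensor<4xf32>) outs(%0 : tensor<4xf32>) {
  ^bb0(%in: f32, %out: f32):
    // Uses of the erased block argument now refer to %scalar directly.
    %2 = arith.divf %in, %scalar : f32
    linalg.yield %2 : f32
  } -> tensor<4xf32>
  return %1 : tensor<4xf32>
}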