diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d00183a1e16a1..baa9856be28e9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -569,8 +569,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
 /// * There is a chance that the implementation of the transformation does not
 ///   agree with the result of this method. This function gives a prediction based
 ///   on an optimized fusion.
-llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
-                                                     GenericOp consumer,
+llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
+                                                     LinalgOp consumer,
                                                      OpOperand *fusedOperand);
 
 /// Try to peel and canonicalize loop `op` and return the new result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cbbb90af..981c38c3a8dec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -77,11 +77,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
 // of the fused producer & consumer after the fusion can still compute the
 // bounds of the op.
 static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
-    GenericOp producer, GenericOp consumer,
+    LinalgOp producer, LinalgOp consumer,
     ArrayRef<OpOperand *> opOperandsToIgnore) {
   SmallVector<AffineMap> indexingMaps;
 
-  SmallVector<GenericOp> ops = {producer, consumer};
+  SmallVector<LinalgOp> ops = {producer, consumer};
   for (auto &op : ops) {
     for (auto &opOperand : op->getOpOperands()) {
       if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -109,8 +109,9 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
 /// * There is a chance that the implementation of the transformation does not
 ///   agree with the result of this method. This function gives a prediction based
 ///   on an optimized fusion.
-llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
-    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
+llvm::SmallDenseSet<int>
+mlir::linalg::getPreservedProducerResults(LinalgOp producer, LinalgOp consumer,
+                                          OpOperand *fusedOperand) {
   llvm::SmallDenseSet<int> preservedProducerResults;
   llvm::SmallVector<OpOperand *> opOperandsToIgnore;
 
@@ -140,8 +141,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
   if (!fusedOperand)
     return false;
 
-  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
-  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
+  auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());
 
   // Check producer and consumer are generic ops.
   if (!producer || !consumer)
@@ -219,13 +220,14 @@ static void generateFusedElementwiseOpRegion(
     RewriterBase &rewriter, GenericOp fusedOp,
     AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
-  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
-  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
+  auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
-  Block &producerBlock = producer->getRegion(0).front();
-  Block &consumerBlock = consumer->getRegion(0).front();
+
+  Block &producerBlock = *producer.getBlock();
+  Block &consumerBlock = *consumer.getBlock();
   OpBuilder::InsertionGuard guard(rewriter);
-  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
+  Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
   IRMapping mapper;
 
   // 2. Add an index operation for every fused loop dimension and use the
@@ -331,7 +333,7 @@ static void generateFusedElementwiseOpRegion(
   YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);
 
   // Sanity checks.
-  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
+  assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
         "Ill-formed GenericOp region");
 }
 
@@ -341,8 +343,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   assert(areElementwiseOpsFusable(fusedOperand) &&
          "expected elementwise operation pre-conditions to pass");
   auto producerResult = cast<OpResult>(fusedOperand->get());
-  auto producer = cast<GenericOp>(producerResult.getOwner());
-  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = cast<LinalgOp>(producerResult.getOwner());
+  auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // TODO: allow fusing the producer of an output operand.
   assert(consumer.isDpsInput(fusedOperand) &&
          "expected producer of input operand");
@@ -419,10 +421,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   // Generate the fused op.
   auto fusedOp = GenericOp::create(
       rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
-      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-      consumer.getIteratorTypes(),
-      /*doc=*/nullptr,
-      /*library_call=*/nullptr);
+      fusedOutputOperands, fusedIndexMaps, consumer.getIteratorTypesArray());
   if (!fusedOp.getShapesToLoopsMap()) {
     // Fused op has invalid indexing maps. Typically this means something is off
     // in the input, but going ahead here would result in verification errors.
@@ -461,14 +460,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
 namespace {
 
 /// Patterns to fuse a generic op, with the producer of its operands.
-class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
+class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
 public:
   FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                      PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
         controlFn(std::move(fun)) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
     for (OpOperand &opOperand : genericOp->getOpOperands()) {
@@ -495,7 +494,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
       rewriter.eraseOp(genericOp);
       return success();
     }
-    return failure();
+    return rewriter.notifyMatchFailure(genericOp, "no fusable operands");
   }
 
 private:
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index bc55c12c02f29..1202a3d198e42 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1017,9 +1017,75 @@ module {
 
 // -----
 
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+  %fill = tensor.empty() : tensor<8xf32>
+  %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+  %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+  return %mapped_65 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT:   linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @map_matmul(%in1: tensor<8x10xf32>, %in2: tensor<10x12xf32>) -> tensor<8x12xf32> {
+  %fill0 = tensor.empty() : tensor<8x10xf32>
+  %exp = linalg.map {math.exp} ins(%in1 : tensor<8x10xf32>) outs(%fill0: tensor<8x10xf32>)
+  %fill1 = tensor.empty() : tensor<8x12xf32>
+  %matmul = linalg.matmul ins(%exp, %in2 : tensor<8x10xf32>, tensor<10x12xf32>) outs(%fill1 : tensor<8x12xf32>) -> tensor<8x12xf32>
+  return %matmul : tensor<8x12xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @map_matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8x10xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<10x12xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x12xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[EXP:.*]] = math.exp %[[IN0]]
+// CHECK-NEXT:   %[[MUL:.*]] = arith.mulf %[[EXP]], %[[IN1]]
+// CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[OUT]], %[[MUL]]
+// CHECK-NEXT:   linalg.yield %[[ADD]]
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @matmul_map(%in1: tensor<8x10xf32>, %in2: tensor<10x12xf32>) -> tensor<8x12xf32> {
+  %fill1 = tensor.empty() : tensor<8x12xf32>
+  %matmul = linalg.matmul ins(%in1, %in2 : tensor<8x10xf32>, tensor<10x12xf32>) outs(%fill1 : tensor<8x12xf32>) -> tensor<8x12xf32>
+  %exp = linalg.map {math.exp} ins(%matmul : tensor<8x12xf32>) outs(%fill1: tensor<8x12xf32>)
+
+  return %exp : tensor<8x12xf32>
+}
+
+// Should not fuse
+// CHECK-LABEL: func @matmul_map
+// CHECK-NEXT:   tensor.empty
+// CHECK-NEXT:   linalg.matmul
+// CHECK-NEXT:   linalg.map
+// CHECK-NEXT:   return
+
+// -----
+
 // In this test we expect the first two linalg.generic operations to be fused into one, but the third one (the matmul) to remain separate.
 // The reason is that when the pattern is applied the 1st time, the fusion of the first two operations produces a fused operation with
-// an additional result and ana dditional output indexing map that is not a permutation / not invertible.
+// an additional result and an additional output indexing map that is not a permutation / not invertible.
 // The fused op will still produce also the original result (and its output indexing map), which is preserved because the new indexing map
 // is not invertible. Thus the fused op will have 2 results, but only the 2nd one will be used by the following matmul op as an input argument.
 // When trying to apply the fusion pattern again, the matmul op won't be fused because the operand to fuse was not produced with an invertible indexing map.
@@ -1079,4 +1145,4 @@ module {
 // CHECK-NOT: linalg.generic
 // CHECK: tensor.expand_shape
 // CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
\ No newline at end of file
+// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index bd9977f1410b9..7946b555e7439 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -59,3 +59,154 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
 // CHECK: %[[FUSED_OP:.+]] = linalg.generic
 // CHECK-SAME: outs(%[[EMPTY]] :
 // CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+  %fill = tensor.empty() : tensor<8xf32>
+  %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+  %sqrt = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+  return %sqrt : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT:   linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.map
+
+// -----
+
+func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+  %init = tensor.empty() : tensor<8xi1>
+  %initf = tensor.empty() : tensor<8xf32>
+  %0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
+    (%in0 : f32, %in1 : f32, %out : i1) {
+      %cmp = arith.cmpf olt, %in0, %in1 : f32
+      linalg.yield %cmp : i1
+    }
+  %3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  return %3 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops_mixed_types
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[EXP0:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT:   %[[SQRT0:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT:   %[[EXP1:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT:   %[[SQRT1:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT:   %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]]
+// CHECK-NEXT:   %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]]
+// CHECK-NEXT:   linalg.yield %[[RES]]
+// CHECK-NOT: linalg.map
+
+// -----
+
+#identity = affine_map<(d0, d1) -> (d0, d1)>
+#bcast = affine_map<(d0, d1) -> (d0)>
+func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tensor<8x10xf32> {
+  %fill = tensor.empty() : tensor<8x10xf32>
+  %add = linalg.elementwise
+    kind=#linalg.elementwise_kind<add>
+    indexing_maps = [#bcast, #identity, #identity]
+    ins(%in1, %in2: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) -> tensor<8x10xf32>
+  %sqrt = linalg.elementwise
+    kind=#linalg.elementwise_kind<sqrt>
+    indexing_maps = [#identity, #identity]
+    ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>) -> tensor<8x10xf32>
+  return %sqrt : tensor<8x10xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL: func @elementwise_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP0]], #[[$MAP0]]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT:   linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.elementwise
+
+// -----
+
+func.func @map_multi_ops(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
+  %fill = tensor.empty() : tensor<8xf32>
+  %add_exp = linalg.map ins(%arg0, %arg1: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+    (%in0 : f32, %in1 : f32, %out : f32) {
+      %add = arith.addf %in0, %in1 : f32
+      %exp = math.exp %add : f32
+      linalg.yield %exp : f32
+    }
+  %sqrt_mul = linalg.map ins(%add_exp, %arg2 : tensor<8xf32>, tensor<8xf32>) outs(%fill : tensor<8xf32>)
+    (%in0 : f32, %in1 : f32, %out : f32) {
+      %sqrt = math.sqrt %in0 : f32
+      %mul = arith.mulf %sqrt, %in1 : f32
+      linalg.yield %mul : f32
+    }
+  return %sqrt_mul : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_multi_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT:   %[[EXP:.*]] = math.exp %[[ADD]]
+// CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[EXP]]
+// CHECK-NEXT:   %[[MUL:.*]] = arith.mulf %[[SQRT]], %[[IN2]]
+// CHECK-NEXT:   linalg.yield %[[MUL]]
+// CHECK-NOT: linalg.map
+
+// -----
+
+#identity = affine_map<(d0, d1) -> (d0, d1)>
+#bcast = affine_map<(d0, d1) -> (d0)>
+func.func @map_generic_ops(%arg0: tensor<8xf32>, %arg1: tensor<8x10xf32>) -> tensor<8x10xf32> {
+  %fill = tensor.empty() : tensor<8x10xf32>
+  %add = linalg.generic
+    {indexing_maps = [#bcast, #identity, #identity], iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) {
+  ^bb0(%in0: f32, %in1: f32, %out: f32):
+    %add = arith.addf %in0, %in1 : f32
+    linalg.yield %add : f32
+  } -> tensor<8x10xf32>
+  %sqrt = linalg.map { math.sqrt } ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>)
+  return %sqrt : tensor<8x10xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL: func @map_generic_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP0]], #[[$MAP0]]]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT:   linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.map
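
Editor's note (not part of the patch): the diff above moves the elementwise-fusion entry points and the FuseElementwiseOps pattern from GenericOp to the LinalgOp interface, so named ops such as linalg.map, linalg.elementwise, and linalg.matmul can participate as fusion producers or consumers. As a rough orientation for downstream users, here is a minimal sketch of how the pre-existing populateElementwiseOpsFusionPatterns helper might be driven over a function. ControlFusionFn, the populate helper, and the greedy driver are existing MLIR APIs; the function name fuseElementwiseOnFunc and the control policy are purely illustrative, and depending on the MLIR revision the driver entry point may be spelled applyPatternsGreedily rather than applyPatternsAndFoldGreedily.

// Hypothetical helper (illustrative only): run the elementwise fusion
// patterns touched by this patch over a function-like op.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult fuseElementwiseOnFunc(Operation *funcOp) {
  RewritePatternSet patterns(funcOp->getContext());
  // Example control policy: only fuse through ranked-tensor operands that
  // actually have a producer. Real pipelines usually add cost heuristics.
  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    return fusedOperand->get().getDefiningOp() &&
           isa<RankedTensorType>(fusedOperand->get().getType());
  };
  // With this patch, the FuseElementwiseOps pattern registered here matches
  // any LinalgOp (e.g. linalg.map, linalg.elementwise), not just linalg.generic.
  linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}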