@srcarroll (Contributor) commented Oct 9, 2025

This PR modifies the definition of linalg::MapOp so that it has the same structure as linalg::GenericOp and all other linalg ops. Mainly, it adds an out bbarg to the body of the op. Although the out arg is never used in the body, there doesn't seem to be much benefit in specializing the op to exclude it. In fact, excluding it only makes things more complicated, because the op then doesn't align with the GenericOp structure. For example, linalg-generalize-named-ops avoided converting linalg.map purely because it didn't have the structure to do so. If GenericOp can have unused bbargs, then ALL linalg ops should be allowed to have them as well.
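To make the new block-argument rule concrete, here is a minimal standalone sketch of the arity check this PR moves to (number of inputs plus one for the init). It is a toy model only: `ToyMapOp`, `expectedMapperArity`, and `verifyMapperArity` are hypothetical names used for illustration and are not part of the MLIR API.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for a linalg.map op. These types are hypothetical and only
// carry what is needed to illustrate the arity rule; they are not MLIR types.
struct ToyMapOp {
  std::vector<std::string> inputs;     // the `ins(...)` operands
  std::string init;                    // the single `outs(...)` operand
  std::vector<std::string> blockArgs;  // arguments of the mapper block
};

// With the out bbarg added, the mapper block is expected to carry one
// argument per input plus one trailing argument for the init.
inline std::size_t expectedMapperArity(const ToyMapOp &op) {
  return op.inputs.size() + 1;
}

// Models the updated arity check: inputs + init must match the number of
// block arguments of the mapper.
inline bool verifyMapperArity(const ToyMapOp &op) {
  return expectedMapperArity(op) == op.blockArgs.size();
}
```

Under this model, a binary map with block arguments (%lhs_elem, %rhs_elem, %out) passes, while the old two-argument form is rejected; a map with no inputs still needs the single %out argument.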

@srcarroll (Contributor, Author) commented Oct 9, 2025

@rengolin this is a follow-up to our discussion on my other PR #144922 (comment). I just want to make sure this is heading in the right direction and doesn't conflict with what you or others have in mind for the linalg refactor before I spend time updating all the tests.

Update: it wasn't that much work, so I just did it anyway.

 const NamedAttrList &payloadOpAttrs,
 ArrayRef<Value> operands,
-bool initFirst = false) {
+bool initFirst = false, bool mapInit = true) {
@srcarroll (Contributor, Author) commented on this change:

This was just the option with the least amount of changes, but I can refactor if necessary.
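As a rough illustration of what the extra `mapInit` flag does to the payload operand list, the following standalone sketch models block arguments as strings, with the init argument last. `selectPayloadOperands` is a hypothetical helper written for this example; it mirrors the shape of the changed code but is not the actual MLIR implementation.

```cpp
#include <string>
#include <vector>

// Hypothetical model of the payload operand selection. `blockArgs` holds the
// block arguments as plain strings, with the init argument in last position.
inline std::vector<std::string>
selectPayloadOperands(const std::vector<std::string> &blockArgs,
                      bool initFirst = false, bool mapInit = true) {
  std::vector<std::string> operands;
  if (blockArgs.empty())
    return operands;
  if (initFirst) {
    // Init-first ordering: the init argument (only if mapped) goes first,
    // followed by the remaining arguments in order.
    if (mapInit)
      operands.push_back(blockArgs.back());
    operands.insert(operands.end(), blockArgs.begin(), blockArgs.end() - 1);
  } else {
    // Arguments in order; the trailing init argument is dropped when
    // mapInit is false.
    operands.assign(blockArgs.begin(), blockArgs.end() - (mapInit ? 0 : 1));
  }
  return operands;
}
```

With block arguments (%in, %in_0, %init), the call used for linalg.map (`initFirst = false`, `mapInit = false`) yields just (%in, %in_0), so the payload op never sees the init argument.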


-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+                            bool mapInit = true) {
@srcarroll (Contributor, Author) commented on this change:

Again, this was the option with the least amount of changes. I will refactor if desired.
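The operand-matching part of the short-form check, for the `initFirst == false` path that linalg.map uses, can be sketched as follows. This is a simplified standalone model under stated assumptions: values are strings, `mapOperandsAllowShortForm` is a hypothetical name, and the other short-form conditions (such as those on the payload op and the yield) are deliberately left out.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical model of the operand checks for the short form when the init
// is not placed first. `payloadOperands` are the operands of the payload op;
// `blockArgs` are the mapper block arguments, with the init argument last.
inline bool
mapOperandsAllowShortForm(const std::vector<std::string> &payloadOperands,
                          const std::vector<std::string> &blockArgs,
                          bool mapInit) {
  // When the init is not mapped, the trailing block argument is ignored.
  const std::size_t skipped = mapInit ? 0 : 1;
  if (blockArgs.size() < skipped)
    return false;
  const std::size_t mapped = blockArgs.size() - skipped;

  // The payload must have at least one operand, and exactly as many operands
  // as there are mapped block arguments.
  if (payloadOperands.empty() || payloadOperands.size() != mapped)
    return false;

  // Operands must match the mapped block arguments one-to-one, in order.
  for (std::size_t i = 0; i < mapped; ++i)
    if (payloadOperands[i] != blockArgs[i])
      return false;
  return true;
}
```

In this model, a payload `arith.addf %in, %in_0` inside a block with arguments (%in, %in_0, %init) qualifies only when `mapInit` is false, which matches how MapOp::print now calls the check.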

@llvmbot (Member) commented Oct 11, 2025

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (srcarroll)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/162742.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (-4)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+24-13)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp (+2-4)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+2)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+14-8)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+5-5)
  • (modified) mlir/test/Dialect/Linalg/one-shot-bufferize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+9-9)
  • (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+1-1)
  • (modified) mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir (+3-3)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f3674c3eecfe6..ecd036d452b27 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
     // Implement functions necessary for DestinationStyleOpInterface.
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
-    SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
-      return getDpsInputOperands();
-    }
-
     bool payloadUsesValueFromOperand(OpOperand * opOperand) {
       if (isDpsInit(opOperand)) return false;
       return !getMatchingBlockArgument(opOperand).use_empty();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a23b3e3b..7ccba6143637e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
                                      OpAsmSetValueNameFn setNameFn) {
   for (Value v : getRegionInputArgs())
     setNameFn(v, "in");
+  for (Value v : getRegionOutputArgs())
+    setNameFn(v, "init");
 }
 
 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
 
   if (bodyBuild)
     buildGenericRegion(builder, result.location, *result.regions.front(),
-                       inputs, /*outputs=*/{}, bodyBuild);
+                       inputs, /*outputs=*/{init}, bodyBuild);
 }
 
 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                  const OperationName &payloadOpName,
                                  const NamedAttrList &payloadOpAttrs,
                                  ArrayRef<Value> operands,
-                                 bool initFirst = false) {
+                                 bool initFirst = false, bool mapInit = true) {
   OpBuilder b(parser.getContext());
   Region *body = result.addRegion();
   Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
   // If initFirst flag is enabled, we consider init as the first position of
   // payload operands.
   if (initFirst) {
-    payloadOpOperands.push_back(block.getArguments().back());
+    if (mapInit)
+      payloadOpOperands.push_back(block.getArguments().back());
     for (const auto &arg : block.getArguments().drop_back())
       payloadOpOperands.push_back(arg);
   } else {
     payloadOpOperands = {block.getArguments().begin(),
-                         block.getArguments().end()};
+                         block.getArguments().end() - int(!mapInit)};
   }
 
   Operation *payloadOp = b.create(
@@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   if (payloadOpName.has_value()) {
     if (!result.operands.empty())
       addBodyWithPayloadOp(parser, result, payloadOpName.value(),
-                           payloadOpAttrs,
-                           ArrayRef(result.operands).drop_back());
+                           payloadOpAttrs, ArrayRef(result.operands), false,
+                           false);
     else
       result.addRegion();
   } else {
@@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+                            bool mapInit = true) {
+  // `intFirst == true` implies that we want to map init arg
+  if (initFirst && !mapInit)
+    return false;
   // Check if the body can be printed in short form. The following 4 conditions
   // must be satisfied:
 
@@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
   // 2) The payload op must have the same number of operands as the number of
   //    block arguments.
   if (payload.getNumOperands() == 0 ||
-      payload.getNumOperands() != body->getNumArguments())
+      payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
     return false;
 
   // 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
     }
   } else {
     for (const auto &[operand, bbArg] :
-         llvm::zip(payload.getOperands(), body->getArguments())) {
+         llvm::zip(payload.getOperands(),
+                   body->getArguments().drop_back(int(!mapInit)))) {
       if (bbArg != operand)
         return false;
     }
@@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
 
 void MapOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  bool useShortForm = canUseShortForm(mapper);
+  bool useShortForm =
+      canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
   if (useShortForm) {
     printShortForm(p, &mapper->getOperations().front());
   }
@@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {
   auto *bodyBlock = getBody();
   auto blockArgs = bodyBlock->getArguments();
 
-  // Checks if the number of `inputs` match the arity of the `mapper` region.
-  if (getInputs().size() != blockArgs.size())
+  // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+  // region.
+  if (getInputs().size() + 1 != blockArgs.size())
     return emitOpError() << "expects number of operands to match the arity of "
                             "mapper, but got: "
-                         << getInputs().size() << " and " << blockArgs.size();
+                         << getInputs().size() + 1 << " and "
+                         << blockArgs.size();
 
   // The parameters of mapper should all match the element type of inputs.
   for (const auto &[bbArgType, inputArg] :
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3e31393fd51ed..75bb1757a55f5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
-  // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
-  // trivially generalize a `linalg.map`, as it does not use the output as
-  // region arguments in the block.
-  if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
+  // Bailout if `linalgOp` is already a generic.
+  if (isa<GenericOp>(linalgOp))
     return failure();
   // Check if the operation has exactly one region.
   if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index bce964e47a3be..c607ece418dff 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
       linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
                             /*init=*/tensorDestination);
   Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+  linalgBody.addArgument(tensorType.getElementType(), loc);
 
   // Create linalg::IndexOps.
   rewriter.setInsertionPointToStart(&linalgBody);
@@ -1068,6 +1069,7 @@ struct SplatOpInterface
                                           /*inputs=*/ValueRange(),
                                           /*init=*/*tensorAlloc);
     Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+    linalgBody.addArgument(tensorType.getElementType(), loc);
 
     // Create linalg::IndexOps.
     rewriter.setInsertionPointToStart(&linalgBody);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 26d2d98572f47..f4020ede4854e 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1423,7 +1423,7 @@ func.func @transpose_buffer(%input: memref<?xf32>,
 func.func @recursive_effect(%arg : tensor<1xf32>) {
   %init = arith.constant dense<0.0> : tensor<1xf32>
   %mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
-            (%in : f32) {
+            (%in : f32, %out: f32) {
               vector.print %in : f32
               linalg.yield %in : f32
             }
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index ae07b1b82228c..dcdd6c8db4b21 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem
 
 // -----
 
-// CHECK-LABEL: generalize_linalg_map
-func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
+func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
   %cst = arith.constant 0.000000e+00 : f32
-  // CHECK: linalg.map
-  // CHECK-NOT: linalg.generic
-  linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
-    () {
-      linalg.yield %cst : f32
-    }
+  linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
   return
 }
 
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK: @generalize_linalg_map
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK:         linalg.yield %[[ADD]] : f32
+
 // -----
 
 func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 40bf4d19d6b91..fabc8e610612d 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -681,7 +681,7 @@ func.func @map_binary_wrong_yield_operands(
    %add = linalg.map
           ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             // expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
             linalg.yield %0, %0: f32, f32
@@ -694,11 +694,11 @@ func.func @map_binary_wrong_yield_operands(
 func.func @map_input_mapper_arity_mismatch(
     %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
     -> tensor<64xf32> {
-  // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+  // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 3 and 4}}
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
       outs(%init:tensor<64xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32, %extra_elem: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
@@ -714,7 +714,7 @@ func.func @map_input_mapper_type_mismatch(
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
       outs(%init:tensor<64xf32>)
-      (%lhs_elem: f64, %rhs_elem: f64) {
+      (%lhs_elem: f64, %rhs_elem: f64, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f64
         linalg.yield %0: f64
       }
@@ -730,7 +730,7 @@ func.func @map_input_output_shape_mismatch(
   %add = linalg.map
       ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
       outs(%init:tensor<32xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 9616a3e32a064..28d7fdc041766 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -339,7 +339,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             linalg.yield %0: f32
           }
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 563013d4083af..74928920c695a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -341,7 +341,7 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
 func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
    %add = linalg.map
       outs(%init:tensor<64xf32>)
-      () {
+      (%out: f32) {
         %0 = arith.constant 0.0: f32
         linalg.yield %0: f32
       }
@@ -349,7 +349,7 @@ func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
 }
 // CHECK-LABEL: func @map_no_inputs
 //       CHECK:   linalg.map outs
-//  CHECK-NEXT:   () {
+//  CHECK-NEXT:   (%[[OUT:.*]]: f32) {
 //  CHECK-NEXT:     arith.constant
 //  CHECK-NEXT:     linalg.yield
 //  CHECK-NEXT:   }
@@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem: f32
             linalg.yield %0: f32
           }
@@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
    linalg.map
       ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
       outs(%init:memref<64xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32) {
+      (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
         %0 = arith.addf %lhs_elem, %rhs_elem: f32
         linalg.yield %0: f32
       }
@@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64
    %abs = linalg.map
           ins(%input:tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%input_elem: f32) {
+          (%input_elem: f32, %out: f32) {
             %0 = math.absf %input_elem: f32
             linalg.yield %0: f32
           }
@@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
    linalg.map
       ins(%input:memref<64xf32>)
       outs(%init:memref<64xf32>)
-      (%input_elem: f32) {
+      (%input_elem: f32, %out: f32) {
         %0 = math.absf %input_elem: f32
         linalg.yield %0: f32
       }
@@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
   %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
+          (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
             %0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
             linalg.yield %0: f32
           }
@@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 
 func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
   %mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
-    (%in_1: f32, %in_2: f32) {
+    (%in_1: f32, %in_2: f32, %out: f32) {
       %1 = arith.maximumf %in_1, %in_2 : f32
       linalg.yield %in_1 : f32
     }
@@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x
 // CHECK-NOT:     linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
 // CHECK:         linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>) 
 // CHECK-SAME:               outs(%[[INIT]] : tensor<1x32xf32>)
-// CHECK-NEXT:      (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT:      (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32) {
 // CHECK-NEXT:        %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
 // CHECK-NEXT:        linalg.yield %[[IN1]] : f32
 // CHECK-NEXT:    }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 35f520a9f22a8..704ad10130fc8 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -381,7 +381,7 @@ func.func @vectorize_map(%arg0: memref<64xf32>,
     %arg1: memref<64xf32>, %arg2: memref<64xf32>) {
   linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
              outs(%arg2 : memref<64xf32>)
-    (%in: f32, %in_0: f32) {
+    (%in: f32, %in_0: f32, %out: f32) {
       %0 = arith.addf %in, %in_0 : f32
       linalg.yield %0 : f32
     }
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 296ca02564e35..5eb2360a29b8f 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -728,7 +728,7 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
 // CHECK-DAG:     %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
 // CHECK:         %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK:         %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
-// CHECK:         () {
+// CHECK:         (%[[INIT:.*]]: f32) {
 // CHECK:           linalg.yield %[[F]] : f32
 // CHECK:         }
 // CHECK:         return %[[MAPPED]] : tensor<?x3x?xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 8cbee3cbb758b..aa8882d21698c 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 func.func @map(%lhs: memref<64xf32>,
-    %rhs: memref<64xf32>, %out: memref<64xf32>) {
+    %rhs: memref<64xf32>, %init: memref<64xf32>) {
   linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
-             outs(%out : memref<64xf32>)
-    (%in: f32, %in_0: f32) {
+             outs(%init : memref<64xf32>)
+    (%in: f32, %in_0: f32, %out: f32) {
       %0 = arith.addf %in, %in_0 : f32
       linalg.yield %0 : f32
     }

@rengolin (Member) commented

If the map operation doesn't use the init, then bufferization will always create a new allocation.

Adding @javedabsar1 who wrote the generalization code and @matthias-springer who knows more about the bufferization part.

@srcarroll (Contributor, Author) commented Oct 11, 2025

> If the map operation doesn't use the init, then bufferization will always create a new allocation.
>
> Adding @javedabsar1 who wrote the generalization code and @matthias-springer who knows more about the bufferization part.

Hmm, good point, but that same issue would exist for ALL linalg ops that don't use init, like linalg.elementwise for example. So I don't see that as a good reason to not make linalg.map work like linalg.generic.

Also, not sure if that's an actual issue. It should allocate because the operation needs to write results to something. I just figured that was part of the semantics of linalg ops. It either reads and writes an already allocated init, or you create a new one to write to. Edit: Eh this isn't actually right. I forgot about contexts like when init is the result of tensor.empty or when it is a function arg and bufferize-function-boundaries is on, etc. But also remember that the init args are always implicitly used because that's where results go.

Furthermore, as you can see, my changes don't affect the behavior of any current tests (except for linalg-generalize-named-ops). So either this isn't tested or the thing you are worrying about is pre-existing anyway. I will double check that my changes don't affect bufferization behavior, but I'm fairly certain they don't.

@srcarroll (Contributor, Author) commented Oct 12, 2025

I confirmed that the following two examples are unaffected:

func.func @map(%arg0: tensor<32x1xf32>, %arg1: tensor<32x1xf32>, %arg2: tensor<32x1xf32>) -> tensor<32x1xf32> {
    %0 = linalg.map { arith.subf } ins(%arg0, %arg2 : tensor<32x1xf32>, tensor<32x1xf32>) outs(%arg2 : tensor<32x1xf32>)
    return %0 : tensor<32x1xf32>
}

func.func @map2(%arg0: tensor<32x1xf32>, %arg1: tensor<32x1xf32>) -> tensor<32x1xf32> {
    %arg2 = tensor.empty() : tensor<32x1xf32>
    %0 = linalg.map { arith.subf } ins(%arg0, %arg2 : tensor<32x1xf32>, tensor<32x1xf32>) outs(%arg2 : tensor<32x1xf32>)
    return %0 : tensor<32x1xf32>
}

Using --one-shot-bufferize="bufferize-function-boundaries" produces:

  func.func @map(%arg0: memref<32x1xf32, strided<[?, ?], offset: ?>>, %arg1: memref<32x1xf32, strided<[?, ?], offset: ?>>, %arg2: memref<32x1xf32, strided<[?, ?], offset: ?>>) -> memref<32x1xf32, strided<[?, ?], offset: ?>> {
    linalg.map { arith.subf } ins(%arg0, %arg2 : memref<32x1xf32, strided<[?, ?], offset: ?>>, memref<32x1xf32, strided<[?, ?], offset: ?>>) outs(%arg2 : memref<32x1xf32, strided<[?, ?], offset: ?>>)
    return %arg2 : memref<32x1xf32, strided<[?, ?], offset: ?>>
  }
  func.func @map2(%arg0: memref<32x1xf32, strided<[?, ?], offset: ?>>, %arg1: memref<32x1xf32, strided<[?, ?], offset: ?>>) -> memref<32x1xf32> {
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x1xf32>
    linalg.map { arith.subf } ins(%arg0, %alloc : memref<32x1xf32, strided<[?, ?], offset: ?>>, memref<32x1xf32>) outs(%alloc : memref<32x1xf32>)
    %cast = memref.cast %alloc : memref<32x1xf32> to memref<32x1xf32, strided<[?, ?], offset: ?>>
    return %alloc : memref<32x1xf32>
  }

Obviously @matthias-springer should confirm or correct, as I have not looked deeply into the bufferization framework. I would doubt that it depends on linalg body block arguments, but rather on the op operands. If that's true, it makes sense that the changes here would not affect it because they are only at body block level. Edit: Moreover, memory effects haven't changed since getGenericEffectsImpl is used.
