diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index f3674c3eecfe6..ecd036d452b27 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [ // Implement functions necessary for DestinationStyleOpInterface. MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } - SmallVector getOpOperandsMatchingBBargs() { - return getDpsInputOperands(); - } - bool payloadUsesValueFromOperand(OpOperand * opOperand) { if (isDpsInit(opOperand)) return false; return !getMatchingBlockArgument(opOperand).use_empty(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 59013a23b3e3b..7ccba6143637e 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region ®ion, OpAsmSetValueNameFn setNameFn) { for (Value v : getRegionInputArgs()) setNameFn(v, "in"); + for (Value v : getRegionOutputArgs()) + setNameFn(v, "init"); } void MapOp::getAsmResultNames(function_ref setNameFn) { @@ -1495,14 +1497,14 @@ void MapOp::build( if (bodyBuild) buildGenericRegion(builder, result.location, *result.regions.front(), - inputs, /*outputs=*/{}, bodyBuild); + inputs, /*outputs=*/{init}, bodyBuild); } static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef operands, - bool initFirst = false) { + bool initFirst = false, bool mapInit = true) { OpBuilder b(parser.getContext()); Region *body = result.addRegion(); Block &block = body->emplaceBlock(); @@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, // If initFirst flag is enabled, we consider init as the first position of // payload operands. if (initFirst) { - payloadOpOperands.push_back(block.getArguments().back()); + if (mapInit) + payloadOpOperands.push_back(block.getArguments().back()); for (const auto &arg : block.getArguments().drop_back()) payloadOpOperands.push_back(arg); } else { payloadOpOperands = {block.getArguments().begin(), - block.getArguments().end()}; + block.getArguments().end() - int(!mapInit)}; } Operation *payloadOp = b.create( @@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { if (payloadOpName.has_value()) { if (!result.operands.empty()) addBodyWithPayloadOp(parser, result, payloadOpName.value(), - payloadOpAttrs, - ArrayRef(result.operands).drop_back()); + payloadOpAttrs, ArrayRef(result.operands), false, + false); else result.addRegion(); } else { @@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -static bool canUseShortForm(Block *body, bool initFirst = false) { +static bool canUseShortForm(Block *body, bool initFirst = false, + bool mapInit = true) { + // `intFirst == true` implies that we want to map init arg + if (initFirst && !mapInit) + return false; // Check if the body can be printed in short form. The following 4 conditions // must be satisfied: @@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) { // 2) The payload op must have the same number of operands as the number of // block arguments. if (payload.getNumOperands() == 0 || - payload.getNumOperands() != body->getNumArguments()) + payload.getNumOperands() != body->getNumArguments() - int(!mapInit)) return false; // 3) If `initFirst` is true (e.g., for reduction ops), the init block @@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) { } } else { for (const auto &[operand, bbArg] : - llvm::zip(payload.getOperands(), body->getArguments())) { + llvm::zip(payload.getOperands(), + body->getArguments().drop_back(int(!mapInit)))) { if (bbArg != operand) return false; } @@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { void MapOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - bool useShortForm = canUseShortForm(mapper); + bool useShortForm = + canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false); if (useShortForm) { printShortForm(p, &mapper->getOperations().front()); } @@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() { auto *bodyBlock = getBody(); auto blockArgs = bodyBlock->getArguments(); - // Checks if the number of `inputs` match the arity of the `mapper` region. - if (getInputs().size() != blockArgs.size()) + // Checks if the number of `inputs` + `init` match the arity of the `mapper` + // region. + if (getInputs().size() + 1 != blockArgs.size()) return emitOpError() << "expects number of operands to match the arity of " "mapper, but got: " - << getInputs().size() << " and " << blockArgs.size(); + << getInputs().size() + 1 << " and " + << blockArgs.size(); // The parameters of mapper should all match the element type of inputs. for (const auto &[bbArgType, inputArg] : diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 3e31393fd51ed..75bb1757a55f5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -31,10 +31,8 @@ using namespace mlir; using namespace mlir::linalg; static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) { - // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot - // trivially generalize a `linalg.map`, as it does not use the output as - // region arguments in the block. - if (isa(linalgOp) || isa(linalgOp)) + // Bailout if `linalgOp` is already a generic. + if (isa(linalgOp)) return failure(); // Check if the operation has exactly one region. if (linalgOp->getNumRegions() != 1) { diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index bce964e47a3be..c607ece418dff 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc, linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(), /*init=*/tensorDestination); Block &linalgBody = linalgOp.getMapper().emplaceBlock(); + linalgBody.addArgument(tensorType.getElementType(), loc); // Create linalg::IndexOps. rewriter.setInsertionPointToStart(&linalgBody); @@ -1068,6 +1069,7 @@ struct SplatOpInterface /*inputs=*/ValueRange(), /*init=*/*tensorAlloc); Block &linalgBody = linalgOp.getMapper().emplaceBlock(); + linalgBody.addArgument(tensorType.getElementType(), loc); // Create linalg::IndexOps. rewriter.setInsertionPointToStart(&linalgBody); diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 26d2d98572f47..f4020ede4854e 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1423,7 +1423,7 @@ func.func @transpose_buffer(%input: memref, func.func @recursive_effect(%arg : tensor<1xf32>) { %init = arith.constant dense<0.0> : tensor<1xf32> %mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>) - (%in : f32) { + (%in : f32, %out: f32) { vector.print %in : f32 linalg.yield %in : f32 } diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir index ae07b1b82228c..dcdd6c8db4b21 100644 --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem // ----- -// CHECK-LABEL: generalize_linalg_map -func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) { +func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) { %cst = arith.constant 0.000000e+00 : f32 - // CHECK: linalg.map - // CHECK-NOT: linalg.generic - linalg.map outs(%arg0 : memref<1x8x8x8xf32>) - () { - linalg.yield %cst : f32 - } + linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>) return } +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +// CHECK: @generalize_linalg_map + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]} +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32> +// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32) +// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32 +// CHECK: linalg.yield %[[ADD]] : f32 + // ----- func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>, diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 40bf4d19d6b91..fabc8e610612d 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -681,7 +681,7 @@ func.func @map_binary_wrong_yield_operands( %add = linalg.map ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>) outs(%init:tensor<64xf32>) - (%lhs_elem: f32, %rhs_elem: f32) { + (%lhs_elem: f32, %rhs_elem: f32, %out: f32) { %0 = arith.addf %lhs_elem, %rhs_elem: f32 // expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}} linalg.yield %0, %0: f32, f32 @@ -694,11 +694,11 @@ func.func @map_binary_wrong_yield_operands( func.func @map_input_mapper_arity_mismatch( %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> { - // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}} + // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 3 and 4}} %add = linalg.map ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>) outs(%init:tensor<64xf32>) - (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) { + (%lhs_elem: f32, %rhs_elem: f32, %out: f32, %extra_elem: f32) { %0 = arith.addf %lhs_elem, %rhs_elem: f32 linalg.yield %0: f32 } @@ -714,7 +714,7 @@ func.func @map_input_mapper_type_mismatch( %add = linalg.map ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>) outs(%init:tensor<64xf32>) - (%lhs_elem: f64, %rhs_elem: f64) { + (%lhs_elem: f64, %rhs_elem: f64, %out: f32) { %0 = arith.addf %lhs_elem, %rhs_elem: f64 linalg.yield %0: f64 } @@ -730,7 +730,7 @@ func.func @map_input_output_shape_mismatch( %add = linalg.map ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>) outs(%init:tensor<32xf32>) - (%lhs_elem: f32, %rhs_elem: f32) { + (%lhs_elem: f32, %rhs_elem: f32, %out: f32) { %0 = arith.addf %lhs_elem, %rhs_elem: f32 linalg.yield %0: f32 } diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir index 9616a3e32a064..28d7fdc041766 100644 --- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir @@ -339,7 +339,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %add = linalg.map ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>) outs(%init:tensor<64xf32>) - (%lhs_elem: f32, %rhs_elem: f32) { + (%lhs_elem: f32, %rhs_elem: f32, %out: f32) { %0 = arith.addf %lhs_elem, %rhs_elem: f32 linalg.yield %0: f32 } diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 563013d4083af..74928920c695a 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -341,7 +341,7 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor, func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> { %add = linalg.map outs(%init:tensor<64xf32>) - () { + (%out: f32) { %0 = arith.constant 0.0: f32 linalg.yield %0: f32 } @@ -349,7 +349,7 @@ func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> { } // CHECK-LABEL: func @map_no_inputs // CHECK: linalg.map outs -// CHECK-NEXT: () { +// CHECK-NEXT: (%[[OUT:.*]]: f32) { // CHECK-NEXT: arith.constant // CHECK-NEXT: linalg.yield // CHECK-NEXT: } @@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %add = linalg.map ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>) outs(%init:tensor<64xf32>) - (%lhs_elem: f32, %rhs_elem: f32) { + (%lhs_elem: f32, %rhs_elem: f32, %out: f32) { %0 = arith.addf %lhs_elem, %rhs_elem: f32 linalg.yield %0: f32 } @@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>, linalg.map ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>) outs(%init:memref<64xf32>) - (%lhs_elem: f32, %rhs_elem: f32) { + (%lhs_elem: f32, %rhs_elem: f32, %out: f32) { %0 = arith.addf %lhs_elem, %rhs_elem: f32 linalg.yield %0: f32 } @@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64 %abs = linalg.map ins(%input:tensor<64xf32>) outs(%init:tensor<64xf32>) - (%input_elem: f32) { + (%input_elem: f32, %out: f32) { %0 = math.absf %input_elem: f32 linalg.yield %0: f32 } @@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) { linalg.map ins(%input:memref<64xf32>) outs(%init:memref<64xf32>) - (%input_elem: f32) { + (%input_elem: f32, %out: f32) { %0 = math.absf %input_elem: f32 linalg.yield %0: f32 } @@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %add = linalg.map ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>) outs(%init:tensor<64xf32>) - (%lhs_elem: f32, %rhs_elem: f32) { + (%lhs_elem: f32, %rhs_elem: f32, %out: f32) { %0 = arith.addf %lhs_elem, %rhs_elem fastmath : f32 linalg.yield %0: f32 } @@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> { %mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>) - (%in_1: f32, %in_2: f32) { + (%in_1: f32, %in_2: f32, %out: f32) { %1 = arith.maximumf %in_1, %in_2 : f32 linalg.yield %in_1 : f32 } @@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x // CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32> // CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>) // CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>) -// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) { +// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32) { // CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32 // CHECK-NEXT: linalg.yield %[[IN1]] : f32 // CHECK-NEXT: } diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir index 35f520a9f22a8..704ad10130fc8 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir @@ -381,7 +381,7 @@ func.func @vectorize_map(%arg0: memref<64xf32>, %arg1: memref<64xf32>, %arg2: memref<64xf32>) { linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>) outs(%arg2 : memref<64xf32>) - (%in: f32, %in_0: f32) { + (%in: f32, %in_0: f32, %out: f32) { %0 = arith.addf %in, %in_0 : f32 linalg.yield %0 : f32 } diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index 296ca02564e35..5eb2360a29b8f 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -728,7 +728,7 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor, %g: tensor // CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]] // CHECK: %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor) -// CHECK: () { +// CHECK: (%[[INIT:.*]]: f32) { // CHECK: linalg.yield %[[F]] : f32 // CHECK: } // CHECK: return %[[MAPPED]] : tensor diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir index 8cbee3cbb758b..aa8882d21698c 100644 --- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir @@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} { // ----- func.func @map(%lhs: memref<64xf32>, - %rhs: memref<64xf32>, %out: memref<64xf32>) { + %rhs: memref<64xf32>, %init: memref<64xf32>) { linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>) - outs(%out : memref<64xf32>) - (%in: f32, %in_0: f32) { + outs(%init : memref<64xf32>) + (%in: f32, %in_0: f32, %out: f32) { %0 = arith.addf %in, %in_0 : f32 linalg.yield %0 : f32 }