Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
return getDpsInputOperands();
}

bool payloadUsesValueFromOperand(OpOperand * opOperand) {
if (isDpsInit(opOperand)) return false;
return !getMatchingBlockArgument(opOperand).use_empty();
Expand Down
21 changes: 10 additions & 11 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
setNameFn(v, "in");
for (Value v : getRegionOutputArgs())
setNameFn(v, "init");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
Expand All @@ -1495,14 +1497,14 @@ void MapOp::build(

if (bodyBuild)
buildGenericRegion(builder, result.location, *result.regions.front(),
inputs, /*outputs=*/{}, bodyBuild);
inputs, /*outputs=*/{init}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
const OperationName &payloadOpName,
const NamedAttrList &payloadOpAttrs,
ArrayRef<Value> operands,
bool initFirst = false) {
bool initFirst = false, bool mapInit = true) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was just the option with least amount of changes but can refactor if necessary

OpBuilder b(parser.getContext());
Region *body = result.addRegion();
Block &block = body->emplaceBlock();
Expand All @@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
// If initFirst flag is enabled, we consider init as the first position of
// payload operands.
if (initFirst) {
payloadOpOperands.push_back(block.getArguments().back());
if (mapInit)
payloadOpOperands.push_back(block.getArguments().back());
for (const auto &arg : block.getArguments().drop_back())
payloadOpOperands.push_back(arg);
} else {
payloadOpOperands = {block.getArguments().begin(),
block.getArguments().end()};
block.getArguments().end() - int(!mapInit)};
}

Operation *payloadOp = b.create(
Expand Down Expand Up @@ -1551,12 +1554,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();

if (payloadOpName.has_value()) {
if (!result.operands.empty())
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
payloadOpAttrs,
ArrayRef(result.operands).drop_back());
else
result.addRegion();
addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
ArrayRef(result.operands), false, false);
} else {
SmallVector<OpAsmParser::Argument> regionArgs;
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
Expand Down Expand Up @@ -1659,7 +1658,7 @@ LogicalResult MapOp::verify() {
auto blockArgs = bodyBlock->getArguments();

// Checks if the number of `inputs` match the arity of the `mapper` region.
if (getInputs().size() != blockArgs.size())
if (getInputs().size() + 1 != blockArgs.size())
return emitOpError() << "expects number of operands to match the arity of "
"mapper, but got: "
<< getInputs().size() << " and " << blockArgs.size();
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ using namespace mlir;
using namespace mlir::linalg;

static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
// Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
// trivially generalize a `linalg.map`, as it does not use the output as
// region arguments in the block.
if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
// Bailout if `linalgOp` is already a generic.
if (isa<GenericOp>(linalgOp))
return failure();
// Check if the operation has exactly one region.
if (linalgOp->getNumRegions() != 1) {
Expand Down
22 changes: 14 additions & 8 deletions mlir/test/Dialect/Linalg/generalize-named-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem

// -----

// CHECK-LABEL: generalize_linalg_map
func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
%cst = arith.constant 0.000000e+00 : f32
// CHECK: linalg.map
// CHECK-NOT: linalg.generic
linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
() {
linalg.yield %cst : f32
}
linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// CHECK: @generalize_linalg_map

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK: linalg.yield %[[ADD]] : f32

// -----

func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} {
// -----

func.func @map(%lhs: memref<64xf32>,
%rhs: memref<64xf32>, %out: memref<64xf32>) {
%rhs: memref<64xf32>, %init: memref<64xf32>) {
linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
outs(%out : memref<64xf32>)
(%in: f32, %in_0: f32) {
outs(%init : memref<64xf32>)
(%in: f32, %in_0: f32, %out: f32) {
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
Expand Down
Loading