Skip to content

Commit 7d4273a

Browse files
committed
[mlir][linalg] Genericize MapOp
1 parent f31bc66 commit 7d4273a

File tree

5 files changed

+29
-30
lines changed

5 files changed

+29
-30
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
293293
// Implement functions necessary for DestinationStyleOpInterface.
294294
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
295295

296-
SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
297-
return getDpsInputOperands();
298-
}
299-
300296
bool payloadUsesValueFromOperand(OpOperand * opOperand) {
301297
if (isDpsInit(opOperand)) return false;
302298
return !getMatchingBlockArgument(opOperand).use_empty();

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
14741474
OpAsmSetValueNameFn setNameFn) {
14751475
for (Value v : getRegionInputArgs())
14761476
setNameFn(v, "in");
1477+
for (Value v : getRegionOutputArgs())
1478+
setNameFn(v, "init");
14771479
}
14781480

14791481
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
14951497

14961498
if (bodyBuild)
14971499
buildGenericRegion(builder, result.location, *result.regions.front(),
1498-
inputs, /*outputs=*/{}, bodyBuild);
1500+
inputs, /*outputs=*/{init}, bodyBuild);
14991501
}
15001502

15011503
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
15021504
const OperationName &payloadOpName,
15031505
const NamedAttrList &payloadOpAttrs,
15041506
ArrayRef<Value> operands,
1505-
bool initFirst = false) {
1507+
bool initFirst = false, bool mapInit = true) {
15061508
OpBuilder b(parser.getContext());
15071509
Region *body = result.addRegion();
15081510
Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
15161518
// If initFirst flag is enabled, we consider init as the first position of
15171519
// payload operands.
15181520
if (initFirst) {
1519-
payloadOpOperands.push_back(block.getArguments().back());
1521+
if (mapInit)
1522+
payloadOpOperands.push_back(block.getArguments().back());
15201523
for (const auto &arg : block.getArguments().drop_back())
15211524
payloadOpOperands.push_back(arg);
15221525
} else {
15231526
payloadOpOperands = {block.getArguments().begin(),
1524-
block.getArguments().end()};
1527+
block.getArguments().end() - int(!mapInit)};
15251528
}
15261529

15271530
Operation *payloadOp = b.create(
@@ -1551,12 +1554,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
15511554
return failure();
15521555

15531556
if (payloadOpName.has_value()) {
1554-
if (!result.operands.empty())
1555-
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1556-
payloadOpAttrs,
1557-
ArrayRef(result.operands).drop_back());
1558-
else
1559-
result.addRegion();
1557+
addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1558+
ArrayRef(result.operands), false, false);
15601559
} else {
15611560
SmallVector<OpAsmParser::Argument> regionArgs;
15621561
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1659,7 +1658,7 @@ LogicalResult MapOp::verify() {
16591658
auto blockArgs = bodyBlock->getArguments();
16601659

16611660
// Checks if the number of `inputs` match the arity of the `mapper` region.
1662-
if (getInputs().size() != blockArgs.size())
1661+
if (getInputs().size() + 1 != blockArgs.size())
16631662
return emitOpError() << "expects number of operands to match the arity of "
16641663
"mapper, but got: "
16651664
<< getInputs().size() << " and " << blockArgs.size();

mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@ using namespace mlir;
3131
using namespace mlir::linalg;
3232

3333
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
34-
// Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
35-
// trivially generalize a `linalg.map`, as it does not use the output as
36-
// region arguments in the block.
37-
if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
34+
// Bailout if `linalgOp` is already a generic.
35+
if (isa<GenericOp>(linalgOp))
3836
return failure();
3937
// Check if the operation has exactly one region.
4038
if (linalgOp->getNumRegions() != 1) {

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem
386386

387387
// -----
388388

389-
// CHECK-LABEL: generalize_linalg_map
390-
func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
389+
func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
391390
%cst = arith.constant 0.000000e+00 : f32
392-
// CHECK: linalg.map
393-
// CHECK-NOT: linalg.generic
394-
linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
395-
() {
396-
linalg.yield %cst : f32
397-
}
391+
linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
398392
return
399393
}
400394

395+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
396+
397+
// CHECK: @generalize_linalg_map
398+
399+
// CHECK: linalg.generic
400+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
401+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
402+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
403+
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
404+
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
405+
// CHECK: linalg.yield %[[ADD]] : f32
406+
401407
// -----
402408

403409
func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,

mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} {
257257
// -----
258258

259259
func.func @map(%lhs: memref<64xf32>,
260-
%rhs: memref<64xf32>, %out: memref<64xf32>) {
260+
%rhs: memref<64xf32>, %init: memref<64xf32>) {
261261
linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
262-
outs(%out : memref<64xf32>)
263-
(%in: f32, %in_0: f32) {
262+
outs(%init : memref<64xf32>)
263+
(%in: f32, %in_0: f32, %out: f32) {
264264
%0 = arith.addf %in, %in_0 : f32
265265
linalg.yield %0 : f32
266266
}

0 commit comments

Comments
 (0)