Skip to content

Commit f5e175f

Browse files
authored
[mlir][linalg] Genericize MapOp (#162742)
This PR modifies the definition of `linalg::MapOp` so that it has the same structure as `linalg::GenericOp` and all other linalg ops. Mainly, it adds an `out` bbarg for the body of the op. Although the `out` arg is never used in the body, there doesn't seem to be much benefit in specializing the op to exclude it. In fact, it only makes things more complicated because it doesn't align with the `GenericOp` structure. For example, `linalg-generalize-named-ops` avoided converting `linalg.map` purely because it didn't have the structure to do so. Moreover, although some fusion patterns are applied explicitly to `GenericOp`, we can change them to be applied to the base `LinalgOp`, which will enable fusion for any fusion-compatible linalg op, but that requires the op having a generic structure. So these changes will enable us to use existing generic transformation patterns on `MapOp` that weren't possible before. They can either be applied to `MapOp` directly or applied after converting to `GenericOp`.
1 parent ba5cde7 commit f5e175f

File tree

12 files changed

+63
-50
lines changed

12 files changed

+63
-50
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
293293
// Implement functions necessary for DestinationStyleOpInterface.
294294
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
295295

296-
SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
297-
return getDpsInputOperands();
298-
}
299-
300296
bool payloadUsesValueFromOperand(OpOperand * opOperand) {
301297
if (isDpsInit(opOperand)) return false;
302298
return !getMatchingBlockArgument(opOperand).use_empty();

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
14741474
OpAsmSetValueNameFn setNameFn) {
14751475
for (Value v : getRegionInputArgs())
14761476
setNameFn(v, "in");
1477+
for (Value v : getRegionOutputArgs())
1478+
setNameFn(v, "init");
14771479
}
14781480

14791481
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
14951497

14961498
if (bodyBuild)
14971499
buildGenericRegion(builder, result.location, *result.regions.front(),
1498-
inputs, /*outputs=*/{}, bodyBuild);
1500+
inputs, /*outputs=*/{init}, bodyBuild);
14991501
}
15001502

15011503
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
15021504
const OperationName &payloadOpName,
15031505
const NamedAttrList &payloadOpAttrs,
15041506
ArrayRef<Value> operands,
1505-
bool initFirst = false) {
1507+
bool initFirst = false, bool mapInit = true) {
15061508
OpBuilder b(parser.getContext());
15071509
Region *body = result.addRegion();
15081510
Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
15161518
// If initFirst flag is enabled, we consider init as the first position of
15171519
// payload operands.
15181520
if (initFirst) {
1519-
payloadOpOperands.push_back(block.getArguments().back());
1521+
if (mapInit)
1522+
payloadOpOperands.push_back(block.getArguments().back());
15201523
for (const auto &arg : block.getArguments().drop_back())
15211524
payloadOpOperands.push_back(arg);
15221525
} else {
15231526
payloadOpOperands = {block.getArguments().begin(),
1524-
block.getArguments().end()};
1527+
block.getArguments().end() - int(!mapInit)};
15251528
}
15261529

15271530
Operation *payloadOp = b.create(
@@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
15531556
if (payloadOpName.has_value()) {
15541557
if (!result.operands.empty())
15551558
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1556-
payloadOpAttrs,
1557-
ArrayRef(result.operands).drop_back());
1559+
payloadOpAttrs, ArrayRef(result.operands), false,
1560+
false);
15581561
else
15591562
result.addRegion();
15601563
} else {
@@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
15701573
return success();
15711574
}
15721575

1573-
static bool canUseShortForm(Block *body, bool initFirst = false) {
1576+
static bool canUseShortForm(Block *body, bool initFirst = false,
1577+
bool mapInit = true) {
1578+
// `initFirst == true` implies that we want to map the init arg
1579+
if (initFirst && !mapInit)
1580+
return false;
15741581
// Check if the body can be printed in short form. The following 4 conditions
15751582
// must be satisfied:
15761583

@@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
15821589
// 2) The payload op must have the same number of operands as the number of
15831590
// block arguments.
15841591
if (payload.getNumOperands() == 0 ||
1585-
payload.getNumOperands() != body->getNumArguments())
1592+
payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
15861593
return false;
15871594

15881595
// 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
16001607
}
16011608
} else {
16021609
for (const auto &[operand, bbArg] :
1603-
llvm::zip(payload.getOperands(), body->getArguments())) {
1610+
llvm::zip(payload.getOperands(),
1611+
body->getArguments().drop_back(int(!mapInit)))) {
16041612
if (bbArg != operand)
16051613
return false;
16061614
}
@@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
16321640

16331641
void MapOp::print(OpAsmPrinter &p) {
16341642
Block *mapper = getBody();
1635-
bool useShortForm = canUseShortForm(mapper);
1643+
bool useShortForm =
1644+
canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
16361645
if (useShortForm) {
16371646
printShortForm(p, &mapper->getOperations().front());
16381647
}
@@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {
16581667
auto *bodyBlock = getBody();
16591668
auto blockArgs = bodyBlock->getArguments();
16601669

1661-
// Checks if the number of `inputs` match the arity of the `mapper` region.
1662-
if (getInputs().size() != blockArgs.size())
1670+
// Checks if the number of `inputs` + `init` match the arity of the `mapper`
1671+
// region.
1672+
if (getInputs().size() + 1 != blockArgs.size())
16631673
return emitOpError() << "expects number of operands to match the arity of "
16641674
"mapper, but got: "
1665-
<< getInputs().size() << " and " << blockArgs.size();
1675+
<< getInputs().size() + 1 << " and "
1676+
<< blockArgs.size();
16661677

16671678
// The parameters of mapper should all match the element type of inputs.
16681679
for (const auto &[bbArgType, inputArg] :

mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@ using namespace mlir;
3131
using namespace mlir::linalg;
3232

3333
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
34-
// Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
35-
// trivially generalize a `linalg.map`, as it does not use the output as
36-
// region arguments in the block.
37-
if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
34+
// Bailout if `linalgOp` is already a generic.
35+
if (isa<GenericOp>(linalgOp))
3836
return failure();
3937
// Check if the operation has exactly one region.
4038
if (linalgOp->getNumRegions() != 1) {

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
579579
linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
580580
/*init=*/tensorDestination);
581581
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
582+
linalgBody.addArgument(tensorType.getElementType(), loc);
582583

583584
// Create linalg::IndexOps.
584585
rewriter.setInsertionPointToStart(&linalgBody);
@@ -1068,6 +1069,7 @@ struct SplatOpInterface
10681069
/*inputs=*/ValueRange(),
10691070
/*init=*/*tensorAlloc);
10701071
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
1072+
linalgBody.addArgument(tensorType.getElementType(), loc);
10711073

10721074
// Create linalg::IndexOps.
10731075
rewriter.setInsertionPointToStart(&linalgBody);

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1423,7 +1423,7 @@ func.func @transpose_buffer(%input: memref<?xf32>,
14231423
func.func @recursive_effect(%arg : tensor<1xf32>) {
14241424
%init = arith.constant dense<0.0> : tensor<1xf32>
14251425
%mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
1426-
(%in : f32) {
1426+
(%in : f32, %out: f32) {
14271427
vector.print %in : f32
14281428
linalg.yield %in : f32
14291429
}

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem
386386

387387
// -----
388388

389-
// CHECK-LABEL: generalize_linalg_map
390-
func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
389+
func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
391390
%cst = arith.constant 0.000000e+00 : f32
392-
// CHECK: linalg.map
393-
// CHECK-NOT: linalg.generic
394-
linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
395-
() {
396-
linalg.yield %cst : f32
397-
}
391+
linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
398392
return
399393
}
400394

395+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
396+
397+
// CHECK: @generalize_linalg_map
398+
399+
// CHECK: linalg.generic
400+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
401+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
402+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
403+
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
404+
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
405+
// CHECK: linalg.yield %[[ADD]] : f32
406+
401407
// -----
402408

403409
func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ func.func @map_binary_wrong_yield_operands(
681681
%add = linalg.map
682682
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
683683
outs(%init:tensor<64xf32>)
684-
(%lhs_elem: f32, %rhs_elem: f32) {
684+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
685685
%0 = arith.addf %lhs_elem, %rhs_elem: f32
686686
// expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
687687
linalg.yield %0, %0: f32, f32
@@ -694,11 +694,11 @@ func.func @map_binary_wrong_yield_operands(
694694
func.func @map_input_mapper_arity_mismatch(
695695
%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
696696
-> tensor<64xf32> {
697-
// expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
697+
// expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 3 and 4}}
698698
%add = linalg.map
699699
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
700700
outs(%init:tensor<64xf32>)
701-
(%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
701+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32, %extra_elem: f32) {
702702
%0 = arith.addf %lhs_elem, %rhs_elem: f32
703703
linalg.yield %0: f32
704704
}
@@ -714,7 +714,7 @@ func.func @map_input_mapper_type_mismatch(
714714
%add = linalg.map
715715
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
716716
outs(%init:tensor<64xf32>)
717-
(%lhs_elem: f64, %rhs_elem: f64) {
717+
(%lhs_elem: f64, %rhs_elem: f64, %out: f32) {
718718
%0 = arith.addf %lhs_elem, %rhs_elem: f64
719719
linalg.yield %0: f64
720720
}
@@ -730,7 +730,7 @@ func.func @map_input_output_shape_mismatch(
730730
%add = linalg.map
731731
ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
732732
outs(%init:tensor<32xf32>)
733-
(%lhs_elem: f32, %rhs_elem: f32) {
733+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
734734
%0 = arith.addf %lhs_elem, %rhs_elem: f32
735735
linalg.yield %0: f32
736736
}

mlir/test/Dialect/Linalg/one-shot-bufferize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
339339
%add = linalg.map
340340
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
341341
outs(%init:tensor<64xf32>)
342-
(%lhs_elem: f32, %rhs_elem: f32) {
342+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
343343
%0 = arith.addf %lhs_elem, %rhs_elem: f32
344344
linalg.yield %0: f32
345345
}

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -341,15 +341,15 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
341341
func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
342342
%add = linalg.map
343343
outs(%init:tensor<64xf32>)
344-
() {
344+
(%out: f32) {
345345
%0 = arith.constant 0.0: f32
346346
linalg.yield %0: f32
347347
}
348348
func.return %add : tensor<64xf32>
349349
}
350350
// CHECK-LABEL: func @map_no_inputs
351351
// CHECK: linalg.map outs
352-
// CHECK-NEXT: () {
352+
// CHECK-NEXT: (%[[OUT:.*]]: f32) {
353353
// CHECK-NEXT: arith.constant
354354
// CHECK-NEXT: linalg.yield
355355
// CHECK-NEXT: }
@@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
361361
%add = linalg.map
362362
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
363363
outs(%init:tensor<64xf32>)
364-
(%lhs_elem: f32, %rhs_elem: f32) {
364+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
365365
%0 = arith.addf %lhs_elem, %rhs_elem: f32
366366
linalg.yield %0: f32
367367
}
@@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
378378
linalg.map
379379
ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
380380
outs(%init:memref<64xf32>)
381-
(%lhs_elem: f32, %rhs_elem: f32) {
381+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
382382
%0 = arith.addf %lhs_elem, %rhs_elem: f32
383383
linalg.yield %0: f32
384384
}
@@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64
393393
%abs = linalg.map
394394
ins(%input:tensor<64xf32>)
395395
outs(%init:tensor<64xf32>)
396-
(%input_elem: f32) {
396+
(%input_elem: f32, %out: f32) {
397397
%0 = math.absf %input_elem: f32
398398
linalg.yield %0: f32
399399
}
@@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
408408
linalg.map
409409
ins(%input:memref<64xf32>)
410410
outs(%init:memref<64xf32>)
411-
(%input_elem: f32) {
411+
(%input_elem: f32, %out: f32) {
412412
%0 = math.absf %input_elem: f32
413413
linalg.yield %0: f32
414414
}
@@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
604604
%add = linalg.map
605605
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
606606
outs(%init:tensor<64xf32>)
607-
(%lhs_elem: f32, %rhs_elem: f32) {
607+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
608608
%0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
609609
linalg.yield %0: f32
610610
}
@@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
622622

623623
func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
624624
%mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
625-
(%in_1: f32, %in_2: f32) {
625+
(%in_1: f32, %in_2: f32, %out: f32) {
626626
%1 = arith.maximumf %in_1, %in_2 : f32
627627
linalg.yield %in_1 : f32
628628
}
@@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x
634634
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
635635
// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
636636
// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
637-
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
637+
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32) {
638638
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
639639
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
640640
// CHECK-NEXT: }

mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ func.func @vectorize_map(%arg0: memref<64xf32>,
356356
%arg1: memref<64xf32>, %arg2: memref<64xf32>) {
357357
linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
358358
outs(%arg2 : memref<64xf32>)
359-
(%in: f32, %in_0: f32) {
359+
(%in: f32, %in_0: f32, %out: f32) {
360360
%0 = arith.addf %in, %in_0 : f32
361361
linalg.yield %0 : f32
362362
}

0 commit comments

Comments
 (0)