Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
return getDpsInputOperands();
}

bool payloadUsesValueFromOperand(OpOperand * opOperand) {
if (isDpsInit(opOperand)) return false;
return !getMatchingBlockArgument(opOperand).use_empty();
Expand Down
37 changes: 24 additions & 13 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
setNameFn(v, "in");
for (Value v : getRegionOutputArgs())
setNameFn(v, "init");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
Expand All @@ -1495,14 +1497,14 @@ void MapOp::build(

if (bodyBuild)
buildGenericRegion(builder, result.location, *result.regions.front(),
inputs, /*outputs=*/{}, bodyBuild);
inputs, /*outputs=*/{init}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
const OperationName &payloadOpName,
const NamedAttrList &payloadOpAttrs,
ArrayRef<Value> operands,
bool initFirst = false) {
bool initFirst = false, bool mapInit = true) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was just the option with least amount of changes but can refactor if necessary

OpBuilder b(parser.getContext());
Region *body = result.addRegion();
Block &block = body->emplaceBlock();
Expand All @@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
// If initFirst flag is enabled, we consider init as the first position of
// payload operands.
if (initFirst) {
payloadOpOperands.push_back(block.getArguments().back());
if (mapInit)
payloadOpOperands.push_back(block.getArguments().back());
for (const auto &arg : block.getArguments().drop_back())
payloadOpOperands.push_back(arg);
} else {
payloadOpOperands = {block.getArguments().begin(),
block.getArguments().end()};
block.getArguments().end() - int(!mapInit)};
}

Operation *payloadOp = b.create(
Expand Down Expand Up @@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
if (payloadOpName.has_value()) {
if (!result.operands.empty())
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
payloadOpAttrs,
ArrayRef(result.operands).drop_back());
payloadOpAttrs, ArrayRef(result.operands), false,
false);
else
result.addRegion();
} else {
Expand All @@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}

static bool canUseShortForm(Block *body, bool initFirst = false) {
static bool canUseShortForm(Block *body, bool initFirst = false,
bool mapInit = true) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, this was the option with least amount of changes. will refactor if desired

// `intFirst == true` implies that we want to map init arg
if (initFirst && !mapInit)
return false;
// Check if the body can be printed in short form. The following 4 conditions
// must be satisfied:

Expand All @@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
// 2) The payload op must have the same number of operands as the number of
// block arguments.
if (payload.getNumOperands() == 0 ||
payload.getNumOperands() != body->getNumArguments())
payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
return false;

// 3) If `initFirst` is true (e.g., for reduction ops), the init block
Expand All @@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
}
} else {
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments())) {
llvm::zip(payload.getOperands(),
body->getArguments().drop_back(int(!mapInit)))) {
if (bbArg != operand)
return false;
}
Expand Down Expand Up @@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {

void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
bool useShortForm = canUseShortForm(mapper);
bool useShortForm =
canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
if (useShortForm) {
printShortForm(p, &mapper->getOperations().front());
}
Expand All @@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {
auto *bodyBlock = getBody();
auto blockArgs = bodyBlock->getArguments();

// Checks if the number of `inputs` match the arity of the `mapper` region.
if (getInputs().size() != blockArgs.size())
// Checks if the number of `inputs` + `init` match the arity of the `mapper`
// region.
if (getInputs().size() + 1 != blockArgs.size())
return emitOpError() << "expects number of operands to match the arity of "
"mapper, but got: "
<< getInputs().size() << " and " << blockArgs.size();
<< getInputs().size() + 1 << " and "
<< blockArgs.size();

// The parameters of mapper should all match the element type of inputs.
for (const auto &[bbArgType, inputArg] :
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ using namespace mlir;
using namespace mlir::linalg;

static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
// Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
// trivially generalize a `linalg.map`, as it does not use the output as
// region arguments in the block.
if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
// Bailout if `linalgOp` is already a generic.
if (isa<GenericOp>(linalgOp))
return failure();
// Check if the operation has exactly one region.
if (linalgOp->getNumRegions() != 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
/*init=*/tensorDestination);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
linalgBody.addArgument(tensorType.getElementType(), loc);

// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
Expand Down Expand Up @@ -1068,6 +1069,7 @@ struct SplatOpInterface
/*inputs=*/ValueRange(),
/*init=*/*tensorAlloc);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
linalgBody.addArgument(tensorType.getElementType(), loc);

// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,7 @@ func.func @transpose_buffer(%input: memref<?xf32>,
func.func @recursive_effect(%arg : tensor<1xf32>) {
%init = arith.constant dense<0.0> : tensor<1xf32>
%mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
(%in : f32) {
(%in : f32, %out: f32) {
vector.print %in : f32
linalg.yield %in : f32
}
Expand Down
22 changes: 14 additions & 8 deletions mlir/test/Dialect/Linalg/generalize-named-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem

// -----

// CHECK-LABEL: generalize_linalg_map
func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
%cst = arith.constant 0.000000e+00 : f32
// CHECK: linalg.map
// CHECK-NOT: linalg.generic
linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
() {
linalg.yield %cst : f32
}
linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// CHECK: @generalize_linalg_map

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK: linalg.yield %[[ADD]] : f32

// -----

func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/Linalg/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ func.func @map_binary_wrong_yield_operands(
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
(%lhs_elem: f32, %rhs_elem: f32) {
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
// expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
linalg.yield %0, %0: f32, f32
Expand All @@ -694,11 +694,11 @@ func.func @map_binary_wrong_yield_operands(
func.func @map_input_mapper_arity_mismatch(
%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
-> tensor<64xf32> {
// expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
// expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 3 and 4}}
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
(%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
(%lhs_elem: f32, %rhs_elem: f32, %out: f32, %extra_elem: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
Expand All @@ -714,7 +714,7 @@ func.func @map_input_mapper_type_mismatch(
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
(%lhs_elem: f64, %rhs_elem: f64) {
(%lhs_elem: f64, %rhs_elem: f64, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f64
linalg.yield %0: f64
}
Expand All @@ -730,7 +730,7 @@ func.func @map_input_output_shape_mismatch(
%add = linalg.map
ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
outs(%init:tensor<32xf32>)
(%lhs_elem: f32, %rhs_elem: f32) {
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
(%lhs_elem: f32, %rhs_elem: f32) {
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
Expand Down
18 changes: 9 additions & 9 deletions mlir/test/Dialect/Linalg/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -341,15 +341,15 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
%add = linalg.map
outs(%init:tensor<64xf32>)
() {
(%out: f32) {
%0 = arith.constant 0.0: f32
linalg.yield %0: f32
}
func.return %add : tensor<64xf32>
}
// CHECK-LABEL: func @map_no_inputs
// CHECK: linalg.map outs
// CHECK-NEXT: () {
// CHECK-NEXT: (%[[OUT:.*]]: f32) {
// CHECK-NEXT: arith.constant
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }
Expand All @@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
(%lhs_elem: f32, %rhs_elem: f32) {
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
Expand All @@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
linalg.map
ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
outs(%init:memref<64xf32>)
(%lhs_elem: f32, %rhs_elem: f32) {
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
Expand All @@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64
%abs = linalg.map
ins(%input:tensor<64xf32>)
outs(%init:tensor<64xf32>)
(%input_elem: f32) {
(%input_elem: f32, %out: f32) {
%0 = math.absf %input_elem: f32
linalg.yield %0: f32
}
Expand All @@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
linalg.map
ins(%input:memref<64xf32>)
outs(%init:memref<64xf32>)
(%input_elem: f32) {
(%input_elem: f32, %out: f32) {
%0 = math.absf %input_elem: f32
linalg.yield %0: f32
}
Expand Down Expand Up @@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
(%lhs_elem: f32, %rhs_elem: f32) {
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
linalg.yield %0: f32
}
Expand All @@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,

func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
%mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
(%in_1: f32, %in_2: f32) {
(%in_1: f32, %in_2: f32, %out: f32) {
%1 = arith.maximumf %in_1, %in_2 : f32
linalg.yield %in_1 : f32
}
Expand All @@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32) {
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
// CHECK-NEXT: }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ func.func @vectorize_map(%arg0: memref<64xf32>,
%arg1: memref<64xf32>, %arg2: memref<64xf32>) {
linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
outs(%arg2 : memref<64xf32>)
(%in: f32, %in_0: f32) {
(%in: f32, %in_0: f32, %out: f32) {
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tensor/bufferize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
// CHECK: () {
// CHECK: (%[[INIT:.*]]: f32) {
// CHECK: linalg.yield %[[F]] : f32
// CHECK: }
// CHECK: return %[[MAPPED]] : tensor<?x3x?xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} {
// -----

func.func @map(%lhs: memref<64xf32>,
%rhs: memref<64xf32>, %out: memref<64xf32>) {
%rhs: memref<64xf32>, %init: memref<64xf32>) {
linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
outs(%out : memref<64xf32>)
(%in: f32, %in_0: f32) {
outs(%init : memref<64xf32>)
(%in: f32, %in_0: f32, %out: f32) {
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
Expand Down
Loading