Skip to content

Commit 260eca3

Browse files
ricejasonfermilindwalekar
authored andcommitted
[MLIR] Forward generated OpTy::create arguments (llvm#170012)
The recent changes in the MLIR TableGen interface for generated OpTy::build functions involves a new OpTy::create function that is generated passing arguments without forwarding. This is problematic with arguments that are move only such as `std::unique_ptr`. My particular use case involves `std::unique_ptr<mlir::Region>` which is desirable as the `mlir::OperationState` object accepts calls to `addRegion(std::unique_ptr<mlir::Region>`. In Discord, the use of `extraClassDeclarations` was suggested which I may go with regardless since I still have to define the builder function anyways, but perhaps you would consider this trivial change as it supports a broader class of argument types for this approach. Consider the declaration in TableGen: ``` let builders = [ OpBuilder<(ins "::mlir::Value":$cdr, "::mlir::ValueRange":$packs, "std::unique_ptr<::mlir::Region>":$body)> ]; ``` Which currently generates: ```cpp ExpandPacksOp ExpandPacksOp::create(::mlir::OpBuilder &builder, ::mlir::Location location, ::mlir::Value cdr, ::mlir::ValueRange packs, std::unique_ptr<::mlir::Region> body) { ::mlir::OperationState __state__(location, getOperationName()); build(builder, __state__, std::forward<decltype(cdr)>(cdr), std::forward<decltype(packs)>(packs), std::forward<decltype(body)>(body)); auto __res__ = ::llvm::dyn_cast<ExpandPacksOp>(builder.create(__state__)); assert(__res__ && "builder didn't return the right type"); return __res__; } ``` With this change it will generate: ```cpp ExpandPacksOp ExpandPacksOp::create(::mlir::OpBuilder &builder, ::mlir::Location location, ::mlir::Value cdr, ::mlir::ValueRange packs, std::unique_ptr<::mlir::Region>&&body) { ::mlir::OperationState __state__(location, getOperationName()); build(builder, __state__, static_cast<decltype(cdr)>(cdr), std::forward<decltype(packs)>(packs), std::forward<decltype(body)>(body)); auto __res__ = ::llvm::dyn_cast<ExpandPacksOp>(builder.create(__state__)); assert(__res__ && "builder didn't return the right type"); return __res__; } ``` Another option could be to make this function a template but then it would not be hidden in the generated translation unit. I don't know if that was the original intent. Thank you for your consideration.
1 parent d186c40 commit 260eca3

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

mlir/test/mlir-tblgen/op-decl-and-defs.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,14 +236,14 @@ def NS_FOp : NS_Op<"op_with_all_types_constraint",
236236

237237
// DEFS: FOp FOp::create(::mlir::OpBuilder &builder, ::mlir::Location location, ::mlir::Value a) {
238238
// DEFS: ::mlir::OperationState __state__(location, getOperationName());
239-
// DEFS: build(builder, __state__, a);
239+
// DEFS: build(builder, __state__, std::forward<decltype(a)>(a));
240240
// DEFS: auto __res__ = ::llvm::dyn_cast<FOp>(builder.create(__state__));
241241
// DEFS: assert(__res__ && "builder didn't return the right type");
242242
// DEFS: return __res__;
243243
// DEFS: }
244244

245245
// DEFS: FOp FOp::create(::mlir::ImplicitLocOpBuilder &builder, ::mlir::Value a) {
246-
// DEFS: return create(builder, builder.getLoc(), a);
246+
// DEFS: return create(builder, builder.getLoc(), std::forward<decltype(a)>(a));
247247
// DEFS: }
248248

249249
def NS_GOp : NS_Op<"op_with_fixed_return_type", []> {

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2616,7 +2616,14 @@ void OpEmitter::genInlineCreateBody(
26162616
std::string nonBuilderStateArgs = "";
26172617
if (!nonBuilderStateArgsList.empty()) {
26182618
llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs);
2619-
interleaveComma(nonBuilderStateArgsList, nonBuilderStateArgsOS);
2619+
interleave(
2620+
nonBuilderStateArgsList,
2621+
[&](StringRef name) {
2622+
nonBuilderStateArgsOS << "std::forward<decltype(" << name << ")>("
2623+
<< name << ')';
2624+
},
2625+
[&] { nonBuilderStateArgsOS << ", "; });
2626+
26202627
nonBuilderStateArgs = ", " + nonBuilderStateArgs;
26212628
}
26222629
cWithLoc->body() << llvm::formatv(inlineCreateBody, locParamName,

0 commit comments

Comments
 (0)