Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions mlir/test/mlir-tblgen/op-python-bindings.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
}

// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
// CHECK: return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
Expand Down Expand Up @@ -157,7 +157,7 @@ def AttributedOp : TestOp<"attributed_op"> {
}

// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
// CHECK: return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
Expand Down Expand Up @@ -193,7 +193,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
}

// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
// CHECK: return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
Expand All @@ -217,7 +217,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
}

// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
// CHECK: return DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)

// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
Expand All @@ -235,7 +235,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
}

// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
// CHECK: return DeriveResultTypesOp(type_=type_, loc=loc, ip=ip).results

// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
Expand All @@ -262,7 +262,7 @@ def EmptyOp : TestOp<"empty">;
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))

// CHECK: def empty(*, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
// CHECK: return EmptyOp(loc=loc, ip=ip)

// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
Expand All @@ -275,7 +275,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
}

// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
// CHECK: return InferResultTypesImpliedOp(loc=loc, ip=ip).results

// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
Expand All @@ -288,7 +288,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
}

// CHECK: def infer_result_types_op(*, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
// CHECK: return InferResultTypesOp(loc=loc, ip=ip).results

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
Expand Down Expand Up @@ -326,7 +326,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
}

// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
// CHECK: return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
Expand Down Expand Up @@ -357,7 +357,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
}

// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
// CHECK: return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
Expand Down Expand Up @@ -389,7 +389,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
}

// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
// CHECK: return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicResultOp(_ods_ir.OpView):
Expand Down Expand Up @@ -446,7 +446,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
}

// CHECK: def python_keyword(in_, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
// CHECK: return PythonKeywordOp(in_=in_, loc=loc, ip=ip)

// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
Expand All @@ -460,7 +460,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
}

// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
// CHECK: return SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)

// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
Expand Down Expand Up @@ -497,7 +497,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
}

// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
// CHECK: return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
Expand Down Expand Up @@ -563,7 +563,7 @@ def SimpleOp : TestOp<"simple"> {
}

// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
// CHECK: return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results

// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
Expand All @@ -590,7 +590,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
}

// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
// CHECK: return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)

// CHECK: class VariadicRegionOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
Expand All @@ -613,7 +613,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
}

// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
// CHECK: return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
Expand All @@ -622,7 +622,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
}

// CHECK: def _123with__special_characters(*, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
// CHECK: return WithSpecialCharactersOp(loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class WithSuccessorsOp(_ods_ir.OpView):
Expand All @@ -637,4 +637,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
}

// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
// CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)
33 changes: 26 additions & 7 deletions mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,11 @@ constexpr const char *regionAccessorTemplate = R"Py(
)Py";

constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
return {1}({3}){5}
)Py";

constexpr const char *valueBuilderVariadicTemplate = R"Py(
def {0}({2}) -> {4}:
return _get_op_result_or_op_results({1}({3}))
)Py";
Expand Down Expand Up @@ -992,15 +997,29 @@ static void emitValueBuilder(const Operator &op,
auto lhs = *llvm::split(arg, "=").begin();
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
});
std::string nameWithoutDialect =
op.getOperationName().substr(op.getOperationName().find('.') + 1);
os << formatv(
valueBuilderTemplate, sanitizeName(nameWithoutDialect),
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
llvm::join(opBuilderArgs, ", "),
std::string nameWithoutDialect = sanitizeName(
op.getOperationName().substr(op.getOperationName().find('.') + 1));
std::string params = llvm::join(valueBuilderParams, ", ");
std::string args = llvm::join(opBuilderArgs, ", ");
const char *type =
(op.getNumResults() > 1
? "_Sequence[_ods_ir.Value]"
: (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
: (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
if (op.getNumVariableLengthResults() > 0) {
os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
op.getCppClassName(), params, args, type);
} else {
const char *results;
if (op.getNumResults() == 0) {
results = "";
} else if (op.getNumResults() == 1) {
results = ".result";
} else {
results = ".results";
}
os << formatv(valueBuilderTemplate, nameWithoutDialect,
op.getCppClassName(), params, args, type, results);
}
}

/// Emits bindings for a specific Op to the given output stream.
Expand Down
Loading