Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit a600a09

Browse files
authored
[MLIR][Python] Add optional results parameter for building op with inferable result types (#156818)
Currently in MLIR python bindings, operations with inferable result types (e.g. with `InferTypeOpInterface` or `SameOperandsAndResultType`) will generate such builder functions: ```python def my_op(arg1, arg2 .. argN, *, loc=None, ip=None): ... # result types will be inferred automatically ``` However, in some cases we may want to provide the result types explicitly. For example, the implementation of interface method `inferResultTypes(..)` can return a failure and then we cannot build the op in that way. Also, in the C++ side we have multiple `build` methods for both explicitly specify the result types and automatically inferring them. In this PR, we change the signature of this builder function to: ```python def my_op(arg1, arg2 .. argN, *, results=None, loc=None, ip=None): ... # result types will be inferred automatically if results is None ``` If the `results` is not provided, it will be inferred automatically, otherwise the provided result types will be utilized. Also, `__init__` methods of the generated op classes are changed correspondingly. Note that for operations without inferable result types, the signature remain unchanged, i.e. `def my_op(res1 .. resN, arg1 .. argN, *, loc=None, ip=None)`. --- Previously I have considered an approach like `my_op(arg, *, res1=None, res2=None, loc=None, ip=None)`, but I quickly realized it had some issues. For example, if the user only provides some of the arguments—say `my_op(v1, res1=i32)`—this could lead to problems. Moreover, we don’t seem to have a mechanism for inferring only part of result types. A unified `results` parameter seems to be more simple and straightforward.
1 parent 7b42861 commit a600a09

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,6 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
492492
constexpr const char *initTemplate = R"Py(
493493
def __init__(self, {0}):
494494
operands = []
495-
results = []
496495
attributes = {{}
497496
regions = None
498497
{1}
@@ -738,18 +737,24 @@ populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
738737
}
739738
}
740739

741-
/// Python code template for deriving the operation result types from its
742-
/// attribute:
740+
/// Python code template of generating result types for
741+
/// FirstAttrDerivedResultType trait
743742
/// - {0} is the name of the attribute from which to derive the types.
744-
constexpr const char *deriveTypeFromAttrTemplate =
745-
R"Py(_ods_result_type_source_attr = attributes["{0}"]
746-
_ods_derived_result_type = (
743+
/// - {1} is the number of results.
744+
constexpr const char *firstAttrDerivedResultTypeTemplate =
745+
R"Py(if results is None:
746+
_ods_result_type_source_attr = attributes["{0}"]
747+
_ods_derived_result_type = (
747748
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
748749
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
749-
_ods_result_type_source_attr.type))Py";
750+
_ods_result_type_source_attr.type)
751+
results = [_ods_derived_result_type] * {1})Py";
750752

751-
/// Python code template appending {0} type {1} times to the results list.
752-
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
753+
/// Python code template of generating result types for
754+
/// SameOperandsAndResultType trait
755+
/// - {0} is the number of results.
756+
constexpr const char *sameOperandsAndResultTypeTemplate =
757+
R"Py(if results is None: results = [operands[0].type] * {0})Py";
753758

754759
/// Appends the given multiline string as individual strings into
755760
/// `builderLines`.
@@ -768,29 +773,30 @@ static void appendLineByLine(StringRef string,
768773
static void
769774
populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
770775
SmallVectorImpl<std::string> &builderLines) {
771-
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
772-
773776
if (hasSameArgumentAndResultTypes(op)) {
774-
builderLines.push_back(formatv(appendSameResultsTemplate,
775-
"operands[0].type", op.getNumResults()));
777+
appendLineByLine(
778+
formatv(sameOperandsAndResultTypeTemplate, op.getNumResults()).str(),
779+
builderLines);
776780
return;
777781
}
778782

779783
if (hasFirstAttrDerivedResultTypes(op)) {
780784
const NamedAttribute &firstAttr = op.getAttribute(0);
781785
assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
782786
"from which the type is derived");
783-
appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
787+
appendLineByLine(formatv(firstAttrDerivedResultTypeTemplate, firstAttr.name,
788+
op.getNumResults())
789+
.str(),
784790
builderLines);
785-
builderLines.push_back(formatv(appendSameResultsTemplate,
786-
"_ods_derived_result_type",
787-
op.getNumResults()));
788791
return;
789792
}
790793

791794
if (hasInferTypeInterface(op))
792795
return;
793796

797+
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
798+
builderLines.push_back("results = []");
799+
794800
// For each element, find or generate a name.
795801
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
796802
const NamedTypeConstraint &element = op.getResult(i);
@@ -909,6 +915,9 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
909915
functionArgs.push_back(builderArgs[i]);
910916
}
911917
}
918+
if (canInferType(op)) {
919+
functionArgs.push_back("results=None");
920+
}
912921
functionArgs.push_back("loc=None");
913922
functionArgs.push_back("ip=None");
914923

@@ -918,8 +927,7 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
918927
initArgs.push_back("self._ODS_OPERAND_SEGMENTS");
919928
initArgs.push_back("self._ODS_RESULT_SEGMENTS");
920929
initArgs.push_back("attributes=attributes");
921-
if (!hasInferTypeInterface(op))
922-
initArgs.push_back("results=results");
930+
initArgs.push_back("results=results");
923931
initArgs.push_back("operands=operands");
924932
initArgs.push_back("successors=_ods_successors");
925933
initArgs.push_back("regions=regions");

0 commit comments

Comments
 (0)