Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 714d0d7

Browse files
authored
[MLIR][Python] fix generated value builder type hints (#158449)
Currently the type hints on the returns of the "value builders" are `ir.Value`, `Sequence[ir.Value]`, and `ir.Operation`, none of which are correct. The correct possibilities are `ir.OpResult`, `ir.OpResultList`, the OpView class itself (e.g., `AttrSizedResultsOp`) or the union of the 3 (for variadic results). This PR fixes those hints.
1 parent fb84e3a commit 714d0d7

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ from ._ods_common import _cext as _ods_cext
3636
from ._ods_common import (
3737
equally_sized_accessor as _ods_equally_sized_accessor,
3838
get_default_loc_context as _ods_get_default_loc_context,
39-
get_op_result_or_op_results as _get_op_result_or_op_results,
4039
get_op_results_or_values as _get_op_results_or_values,
4140
segmented_accessor as _ods_segmented_accessor,
4241
)
@@ -276,8 +275,9 @@ def {0}({2}) -> {4}:
276275
)Py";
277276

278277
constexpr const char *valueBuilderVariadicTemplate = R"Py(
279-
def {0}({2}) -> {4}:
280-
return _get_op_result_or_op_results({1}({3}))
278+
def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]:
279+
op = {1}({3}); results = op.results
280+
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
281281
)Py";
282282

283283
static llvm::cl::OptionCategory
@@ -1013,21 +1013,18 @@ static void emitValueBuilder(const Operator &op,
10131013
nameWithoutDialect += "_";
10141014
std::string params = llvm::join(valueBuilderParams, ", ");
10151015
std::string args = llvm::join(opBuilderArgs, ", ");
1016-
const char *type =
1017-
(op.getNumResults() > 1
1018-
? "_Sequence[_ods_ir.Value]"
1019-
: (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
1020-
if (op.getNumVariableLengthResults() > 0) {
1016+
if (op.getNumVariableLengthResults()) {
10211017
os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
1022-
op.getCppClassName(), params, args, type);
1018+
op.getCppClassName(), params, args);
10231019
} else {
1024-
const char *results;
1025-
if (op.getNumResults() == 0) {
1026-
results = "";
1020+
std::string type = op.getCppClassName().str();
1021+
const char *results = "";
1022+
if (op.getNumResults() > 1) {
1023+
type = "_ods_ir.OpResultList";
1024+
results = ".results";
10271025
} else if (op.getNumResults() == 1) {
1026+
type = "_ods_ir.OpResult";
10281027
results = ".result";
1029-
} else {
1030-
results = ".results";
10311028
}
10321029
os << formatv(valueBuilderTemplate, nameWithoutDialect,
10331030
op.getCppClassName(), params, args, type, results);

0 commit comments

Comments
 (0)