diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index c862ec84fcbc5..4000b1b92fb8e 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1478,12 +1478,11 @@ static void maybeInsertOperation(PyOperationRef &op, nb::object PyOperation::create(std::string_view name, std::optional> results, - std::optional> operands, + llvm::ArrayRef operands, std::optional attributes, std::optional> successors, int regions, DefaultingPyLocation location, const nb::object &maybeIp, bool inferType) { - llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; llvm::SmallVector, 4> mlirAttributes; @@ -1492,16 +1491,6 @@ nb::object PyOperation::create(std::string_view name, if (regions < 0) throw nb::value_error("number of regions must be >= 0"); - // Unpack/validate operands. - if (operands) { - mlirOperands.reserve(operands->size()); - for (PyValue *operand : *operands) { - if (!operand) - throw nb::value_error("operand value cannot be None"); - mlirOperands.push_back(operand->get()); - } - } - // Unpack/validate results. if (results) { mlirResults.reserve(results->size()); @@ -1559,9 +1548,8 @@ nb::object PyOperation::create(std::string_view name, // point, exceptions cannot be thrown or else the state will leak. MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), location); - if (!mlirOperands.empty()) - mlirOperationStateAddOperands(&state, mlirOperands.size(), - mlirOperands.data()); + if (!operands.empty()) + mlirOperationStateAddOperands(&state, operands.size(), operands.data()); state.enableResultTypeInference = inferType; if (!mlirResults.empty()) mlirOperationStateAddResults(&state, mlirResults.size(), @@ -1629,6 +1617,143 @@ void PyOperation::erase() { mlirOperationDestroy(operation); } +namespace { +/// CRTP base class for Python MLIR values that subclass Value and should be +/// castable from it. The value hierarchy is one level deep and is not supposed +/// to accommodate other levels unless core MLIR changes. +template +class PyConcreteValue : public PyValue { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + // and redefine bindDerived. + using ClassTy = nb::class_; + using IsAFunctionTy = bool (*)(MlirValue); + + PyConcreteValue() = default; + PyConcreteValue(PyOperationRef operationRef, MlirValue value) + : PyValue(operationRef, value) {} + PyConcreteValue(PyValue &orig) + : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} + + /// Attempts to cast the original value to the derived type and throws on + /// type mismatches. + static MlirValue castFrom(PyValue &orig) { + if (!DerivedTy::isaFunction(orig.get())) { + auto origRepr = nb::cast(nb::repr(nb::cast(orig))); + throw nb::value_error((Twine("Cannot cast value to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str() + .c_str()); + } + return orig.get(); + } + + /// Binds the Python module objects to functions of this class. + static void bind(nb::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); + cls.def_static( + "isinstance", + [](PyValue &otherValue) -> bool { + return DerivedTy::isaFunction(otherValue); + }, + nb::arg("other_value")); + cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](DerivedTy &self) { return self.maybeDownCast(); }); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + +} // namespace + +/// Python wrapper for MlirOpResult. +class PyOpResult : public PyConcreteValue { +public: + static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; + static constexpr const char *pyClassName = "OpResult"; + using PyConcreteValue::PyConcreteValue; + + static void bindDerived(ClassTy &c) { + c.def_prop_ro("owner", [](PyOpResult &self) { + assert( + mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match that in the IR"); + return self.getParentOperation().getObject(); + }); + c.def_prop_ro("result_number", [](PyOpResult &self) { + return mlirOpResultGetResultNumber(self.get()); + }); + } +}; + +/// Returns the list of types of the values held by container. +template +static std::vector getValueTypes(Container &container, + PyMlirContextRef &context) { + std::vector result; + result.reserve(container.size()); + for (int i = 0, e = container.size(); i < e; ++i) { + result.push_back(mlirValueGetType(container.getElement(i).get())); + } + return result; +} + +/// A list of operation results. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) result list is associated +/// with the operation whose results these are, and thus extends the lifetime of +/// this operation. +class PyOpResultList : public Sliceable { +public: + static constexpr const char *pyClassName = "OpResultList"; + using SliceableT = Sliceable; + + PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumResults(operation->get()) + : length, + step), + operation(std::move(operation)) {} + + static void bindDerived(ClassTy &c) { + c.def_prop_ro("types", [](PyOpResultList &self) { + return getValueTypes(self, self.operation->getContext()); + }); + c.def_prop_ro("owner", [](PyOpResultList &self) { + return self.operation->createOpView(); + }); + } + + PyOperationRef &getOperation() { return operation; } + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { + operation->checkValid(); + return mlirOperationGetNumResults(operation->get()); + } + + PyOpResult getRawElement(intptr_t index) { + PyValue value(operation, mlirOperationGetResult(operation->get(), index)); + return PyOpResult(value); + } + + PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyOpResultList(operation, startIndex, length, step); + } + + PyOperationRef operation; +}; + //------------------------------------------------------------------------------ // PyOpView //------------------------------------------------------------------------------ @@ -1730,6 +1855,40 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList, } } +static MlirValue getUniqueResult(MlirOperation operation) { + auto numResults = mlirOperationGetNumResults(operation); + if (numResults != 1) { + auto name = mlirIdentifierStr(mlirOperationGetName(operation)); + throw nb::value_error((Twine("Cannot call .result on operation ") + + StringRef(name.data, name.length) + " which has " + + Twine(numResults) + + " results (it is only valid for operations with a " + "single result)") + .str() + .c_str()); + } + return mlirOperationGetResult(operation, 0); +} + +static MlirValue getOpResultOrValue(nb::handle operand) { + if (operand.is_none()) { + throw nb::value_error("contained a None item"); + } + PyOperationBase *op; + if (nb::try_cast(operand, op)) { + return getUniqueResult(op->getOperation()); + } + PyOpResultList *opResultList; + if (nb::try_cast(operand, opResultList)) { + return getUniqueResult(opResultList->getOperation()->get()); + } + PyValue *value; + if (nb::try_cast(operand, value)) { + return value->get(); + } + throw nb::value_error("is not a Value"); +} + nb::object PyOpView::buildGeneric( std::string_view name, std::tuple opRegionSpec, nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj, @@ -1780,16 +1939,14 @@ nb::object PyOpView::buildGeneric( } // Unpack operands. - std::vector operands; + llvm::SmallVector operands; operands.reserve(operands.size()); if (operandSegmentSpecObj.is_none()) { // Non-sized operand unpacking. for (const auto &it : llvm::enumerate(operandList)) { try { - operands.push_back(nb::cast(it.value())); - if (!operands.back()) - throw nb::cast_error(); - } catch (nb::cast_error &err) { + operands.push_back(getOpResultOrValue(it.value())); + } catch (nb::builtin_exception &err) { throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Value (" + err.what() + ")") @@ -1815,29 +1972,31 @@ nb::object PyOpView::buildGeneric( int segmentSpec = std::get<1>(it.value()); if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. - try { - auto *operandValue = nb::cast(std::get<0>(it.value())); - if (operandValue) { - operands.push_back(operandValue); - operandSegmentLengths.push_back(1); - } else if (segmentSpec == 0) { - // Allowed to be optional. - operandSegmentLengths.push_back(0); - } else { - throw nb::value_error( - (llvm::Twine("Operand ") + llvm::Twine(it.index()) + - " of operation \"" + name + - "\" must be a Value (was None and operand is not optional)") - .str() - .c_str()); + auto &operand = std::get<0>(it.value()); + if (!operand.is_none()) { + try { + + operands.push_back(getOpResultOrValue(operand)); + } catch (nb::builtin_exception &err) { + throw nb::value_error((llvm::Twine("Operand ") + + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Value (" + err.what() + ")") + .str() + .c_str()); } - } catch (nb::cast_error &err) { - throw nb::value_error((llvm::Twine("Operand ") + - llvm::Twine(it.index()) + " of operation \"" + - name + "\" must be a Value (" + err.what() + - ")") - .str() - .c_str()); + + operandSegmentLengths.push_back(1); + } else if (segmentSpec == 0) { + // Allowed to be optional. + operandSegmentLengths.push_back(0); + } else { + throw nb::value_error( + (llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Value (was None and operand is not optional)") + .str() + .c_str()); } } else if (segmentSpec == -1) { // Unpack sequence by appending. @@ -1849,10 +2008,7 @@ nb::object PyOpView::buildGeneric( // Unpack the list. auto segment = nb::cast(std::get<0>(it.value())); for (nb::handle segmentItem : segment) { - operands.push_back(nb::cast(segmentItem)); - if (!operands.back()) { - throw nb::type_error("contained a None item"); - } + operands.push_back(getOpResultOrValue(segmentItem)); } operandSegmentLengths.push_back(nb::len(segment)); } @@ -2266,57 +2422,6 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from, } namespace { -/// CRTP base class for Python MLIR values that subclass Value and should be -/// castable from it. The value hierarchy is one level deep and is not supposed -/// to accommodate other levels unless core MLIR changes. -template -class PyConcreteValue : public PyValue { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - // and redefine bindDerived. - using ClassTy = nb::class_; - using IsAFunctionTy = bool (*)(MlirValue); - - PyConcreteValue() = default; - PyConcreteValue(PyOperationRef operationRef, MlirValue value) - : PyValue(operationRef, value) {} - PyConcreteValue(PyValue &orig) - : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} - - /// Attempts to cast the original value to the derived type and throws on - /// type mismatches. - static MlirValue castFrom(PyValue &orig) { - if (!DerivedTy::isaFunction(orig.get())) { - auto origRepr = nb::cast(nb::repr(nb::cast(orig))); - throw nb::value_error((Twine("Cannot cast value to ") + - DerivedTy::pyClassName + " (from " + origRepr + - ")") - .str() - .c_str()); - } - return orig.get(); - } - - /// Binds the Python module objects to functions of this class. - static void bind(nb::module_ &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); - cls.def_static( - "isinstance", - [](PyValue &otherValue) -> bool { - return DerivedTy::isaFunction(otherValue); - }, - nb::arg("other_value")); - cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](DerivedTy &self) { return self.maybeDownCast(); }); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. - static void bindDerived(ClassTy &m) {} -}; /// Python wrapper for MlirBlockArgument. class PyBlockArgument : public PyConcreteValue { @@ -2342,39 +2447,6 @@ class PyBlockArgument : public PyConcreteValue { } }; -/// Python wrapper for MlirOpResult. -class PyOpResult : public PyConcreteValue { -public: - static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; - static constexpr const char *pyClassName = "OpResult"; - using PyConcreteValue::PyConcreteValue; - - static void bindDerived(ClassTy &c) { - c.def_prop_ro("owner", [](PyOpResult &self) { - assert( - mlirOperationEqual(self.getParentOperation()->get(), - mlirOpResultGetOwner(self.get())) && - "expected the owner of the value in Python to match that in the IR"); - return self.getParentOperation().getObject(); - }); - c.def_prop_ro("result_number", [](PyOpResult &self) { - return mlirOpResultGetResultNumber(self.get()); - }); - } -}; - -/// Returns the list of types of the values held by container. -template -static std::vector getValueTypes(Container &container, - PyMlirContextRef &context) { - std::vector result; - result.reserve(container.size()); - for (int i = 0, e = container.size(); i < e; ++i) { - result.push_back(mlirValueGetType(container.getElement(i).get())); - } - return result; -} - /// A list of block arguments. Internally, these are stored as consecutive /// elements, random access is cheap. The argument list is associated with the /// operation that contains the block (detached blocks are not allowed in @@ -2481,53 +2553,6 @@ class PyOpOperandList : public Sliceable { PyOperationRef operation; }; -/// A list of operation results. Internally, these are stored as consecutive -/// elements, random access is cheap. The (returned) result list is associated -/// with the operation whose results these are, and thus extends the lifetime of -/// this operation. -class PyOpResultList : public Sliceable { -public: - static constexpr const char *pyClassName = "OpResultList"; - using SliceableT = Sliceable; - - PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirOperationGetNumResults(operation->get()) - : length, - step), - operation(std::move(operation)) {} - - static void bindDerived(ClassTy &c) { - c.def_prop_ro("types", [](PyOpResultList &self) { - return getValueTypes(self, self.operation->getContext()); - }); - c.def_prop_ro("owner", [](PyOpResultList &self) { - return self.operation->createOpView(); - }); - } - -private: - /// Give the parent CRTP class access to hook implementations below. - friend class Sliceable; - - intptr_t getRawNumElements() { - operation->checkValid(); - return mlirOperationGetNumResults(operation->get()); - } - - PyOpResult getRawElement(intptr_t index) { - PyValue value(operation, mlirOperationGetResult(operation->get(), index)); - return PyOpResult(value); - } - - PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { - return PyOpResultList(operation, startIndex, length, step); - } - - PyOperationRef operation; -}; - /// A list of operation successors. Internally, these are stored as consecutive /// elements, random access is cheap. The (returned) successor list is /// associated with the operation whose successors these are, and thus extends @@ -3108,20 +3133,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { "result", [](PyOperationBase &self) { auto &operation = self.getOperation(); - auto numResults = mlirOperationGetNumResults(operation); - if (numResults != 1) { - auto name = mlirIdentifierStr(mlirOperationGetName(operation)); - throw nb::value_error( - (Twine("Cannot call .result on operation ") + - StringRef(name.data, name.length) + " which has " + - Twine(numResults) + - " results (it is only valid for operations with a " - "single result)") - .str() - .c_str()); - } - return PyOpResult(operation.getRef(), - mlirOperationGetResult(operation, 0)) + return PyOpResult(operation.getRef(), getUniqueResult(operation)) .maybeDownCast(); }, "Shortcut to get an op result if it has only one (throws an error " @@ -3218,14 +3230,36 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("walk_order") = MlirWalkPostOrder); nb::class_(m, "Operation") - .def_static("create", &PyOperation::create, nb::arg("name"), - nb::arg("results").none() = nb::none(), - nb::arg("operands").none() = nb::none(), - nb::arg("attributes").none() = nb::none(), - nb::arg("successors").none() = nb::none(), - nb::arg("regions") = 0, nb::arg("loc").none() = nb::none(), - nb::arg("ip").none() = nb::none(), - nb::arg("infer_type") = false, kOperationCreateDocstring) + .def_static( + "create", + [](std::string_view name, + std::optional> results, + std::optional> operands, + std::optional attributes, + std::optional> successors, int regions, + DefaultingPyLocation location, const nb::object &maybeIp, + bool inferType) { + // Unpack/validate operands. + llvm::SmallVector mlirOperands; + if (operands) { + mlirOperands.reserve(operands->size()); + for (PyValue *operand : *operands) { + if (!operand) + throw nb::value_error("operand value cannot be None"); + mlirOperands.push_back(operand->get()); + } + } + + return PyOperation::create(name, results, mlirOperands, attributes, + successors, regions, location, maybeIp, + inferType); + }, + nb::arg("name"), nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0, + nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), + nb::arg("infer_type") = false, kOperationCreateDocstring) .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index fd70ac7ac6ec3..dd6e7ef912374 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -686,7 +686,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Creates an operation. See corresponding python docstring. static nanobind::object create(std::string_view name, std::optional> results, - std::optional> operands, + llvm::ArrayRef operands, std::optional attributes, std::optional> successors, int regions, DefaultingPyLocation location, const nanobind::object &ip, diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 5b67ab03d6f49..d3dbdc604ef4c 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -115,7 +115,10 @@ def get_op_results_or_values( _cext.ir.Operation, _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]], ] -) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]: +) -> _Union[ + _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]], + _cext.ir.OpResultList, +]: """Returns the given sequence of values or the results of the given op. This is useful to implement op constructors so that they can take other ops as @@ -127,7 +130,7 @@ def get_op_results_or_values( elif isinstance(arg, _cext.ir.Operation): return arg.results else: - return [get_op_result_or_value(element) for element in arg] + return arg def get_op_result_or_op_results( diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 25833779c2f71..72963cac64d54 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -27,8 +27,8 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands", // CHECK: attributes = {} // CHECK: regions = None // CHECK: operands.append(_get_op_results_or_values(variadic1)) - // CHECK: operands.append(_get_op_result_or_value(non_variadic)) - // CHECK: operands.append(_get_op_result_or_value(variadic2) if variadic2 is not None else None) + // CHECK: operands.append(non_variadic) + // CHECK: operands.append(variadic2) // CHECK: _ods_successors = None // CHECK: super().__init__( // CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, @@ -173,8 +173,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None - // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0)) - // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2)) + // CHECK: operands.append(_gen_arg_0) + // CHECK: operands.append(_gen_arg_2) // CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get( // CHECK: _ods_get_default_loc_context(loc)) // CHECK: if is_ is not None: attributes["is"] = (is_ @@ -307,9 +307,9 @@ def MissingNamesOp : TestOp<"missing_names"> { // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None - // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0)) - // CHECK: operands.append(_get_op_result_or_value(f32)) - // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2)) + // CHECK: operands.append(_gen_arg_0) + // CHECK: operands.append(f32) + // CHECK: operands.append(_gen_arg_2) // CHECK: results.append(i32) // CHECK: results.append(_gen_res_1) // CHECK: results.append(i64) @@ -349,8 +349,8 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> { // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None - // CHECK: operands.append(_get_op_result_or_value(non_optional)) - // CHECK: if optional is not None: operands.append(_get_op_result_or_value(optional)) + // CHECK: operands.append(non_optional) + // CHECK: if optional is not None: operands.append(optional) // CHECK: _ods_successors = None // CHECK: super().__init__( // CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS @@ -380,7 +380,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None - // CHECK: operands.append(_get_op_result_or_value(non_variadic)) + // CHECK: operands.append(non_variadic) // CHECK: operands.extend(_get_op_results_or_values(variadic)) // CHECK: _ods_successors = None // CHECK: super().__init__( @@ -445,7 +445,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> { // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None - // CHECK: operands.append(_get_op_result_or_value(in_)) + // CHECK: operands.append(in_) // CHECK: _ods_successors = None // CHECK: super().__init__( // CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS @@ -547,8 +547,8 @@ def SimpleOp : TestOp<"simple"> { // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None - // CHECK: operands.append(_get_op_result_or_value(i32)) - // CHECK: operands.append(_get_op_result_or_value(f32)) + // CHECK: operands.append(i32) + // CHECK: operands.append(f32) // CHECK: results.append(i64) // CHECK: results.append(f64) // CHECK: _ods_successors = None diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index e1540d1750ff1..604d2376052a8 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -37,7 +37,6 @@ from ._ods_common import ( equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_op_results as _get_op_result_or_op_results, - get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, segmented_accessor as _ods_segmented_accessor, ) @@ -501,17 +500,15 @@ constexpr const char *initTemplate = R"Py( /// Template for appending a single element to the operand/result list. /// {0} is the field name. -constexpr const char *singleOperandAppendTemplate = - "operands.append(_get_op_result_or_value({0}))"; +constexpr const char *singleOperandAppendTemplate = "operands.append({0})"; constexpr const char *singleResultAppendTemplate = "results.append({0})"; /// Template for appending an optional element to the operand/result list. /// {0} is the field name. constexpr const char *optionalAppendOperandTemplate = - "if {0} is not None: operands.append(_get_op_result_or_value({0}))"; + "if {0} is not None: operands.append({0})"; constexpr const char *optionalAppendAttrSizedOperandsTemplate = - "operands.append(_get_op_result_or_value({0}) if {0} is not None else " - "None)"; + "operands.append({0})"; constexpr const char *optionalAppendResultTemplate = "if {0} is not None: results.append({0})";