Skip to content

Conversation

@hawkinsp
Copy link
Contributor

This logic is in the critical path for constructing an operation from Python. It is faster to compute this in C++ than it is in Python, and it is a minor change to do this.

This change also alters the API contract of
_ods_common.get_op_results_or_values to avoid calling get_op_result_or_value on each element of a sequence, since the C++ code will now do this.

Most of the diff here is simply reordering the code in IRCore.cpp.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:python MLIR Python bindings mlir labels Jan 22, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 22, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Peter Hawkins (hawkinsp)

Changes

This logic is in the critical path for constructing an operation from Python. It is faster to compute this in C++ than it is in Python, and it is a minor change to do this.

This change also alters the API contract of
_ods_common.get_op_results_or_values to avoid calling get_op_result_or_value on each element of a sequence, since the C++ code will now do this.

Most of the diff here is simply reordering the code in IRCore.cpp.


Patch is 29.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123953.diff

5 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+232-199)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+1-1)
  • (modified) mlir/python/mlir/dialects/_ods_common.py (+2-2)
  • (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+13-13)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+3-6)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index c862ec84fcbc55..53f2031d5df055 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1478,12 +1478,11 @@ static void maybeInsertOperation(PyOperationRef &op,
 
 nb::object PyOperation::create(std::string_view name,
                                std::optional<std::vector<PyType *>> results,
-                               std::optional<std::vector<PyValue *>> operands,
+                               const llvm::SmallVector<MlirValue, 4> &operands,
                                std::optional<nb::dict> attributes,
                                std::optional<std::vector<PyBlock *>> successors,
                                int regions, DefaultingPyLocation location,
                                const nb::object &maybeIp, bool inferType) {
-  llvm::SmallVector<MlirValue, 4> mlirOperands;
   llvm::SmallVector<MlirType, 4> mlirResults;
   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
@@ -1492,16 +1491,6 @@ nb::object PyOperation::create(std::string_view name,
   if (regions < 0)
     throw nb::value_error("number of regions must be >= 0");
 
-  // Unpack/validate operands.
-  if (operands) {
-    mlirOperands.reserve(operands->size());
-    for (PyValue *operand : *operands) {
-      if (!operand)
-        throw nb::value_error("operand value cannot be None");
-      mlirOperands.push_back(operand->get());
-    }
-  }
-
   // Unpack/validate results.
   if (results) {
     mlirResults.reserve(results->size());
@@ -1559,9 +1548,8 @@ nb::object PyOperation::create(std::string_view name,
   // point, exceptions cannot be thrown or else the state will leak.
   MlirOperationState state =
       mlirOperationStateGet(toMlirStringRef(name), location);
-  if (!mlirOperands.empty())
-    mlirOperationStateAddOperands(&state, mlirOperands.size(),
-                                  mlirOperands.data());
+  if (!operands.empty())
+    mlirOperationStateAddOperands(&state, operands.size(), operands.data());
   state.enableResultTypeInference = inferType;
   if (!mlirResults.empty())
     mlirOperationStateAddResults(&state, mlirResults.size(),
@@ -1629,6 +1617,142 @@ void PyOperation::erase() {
   mlirOperationDestroy(operation);
 }
 
+namespace {
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+template <typename DerivedTy> class PyConcreteValue : public PyValue {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  // and redefine bindDerived.
+  using ClassTy = nb::class_<DerivedTy, PyValue>;
+  using IsAFunctionTy = bool (*)(MlirValue);
+
+  PyConcreteValue() = default;
+  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+      : PyValue(operationRef, value) {}
+  PyConcreteValue(PyValue &orig)
+      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+  /// Attempts to cast the original value to the derived type and throws on
+  /// type mismatches.
+  static MlirValue castFrom(PyValue &orig) {
+    if (!DerivedTy::isaFunction(orig.get())) {
+      auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
+      throw nb::value_error((Twine("Cannot cast value to ") +
+                             DerivedTy::pyClassName + " (from " + origRepr +
+                             ")")
+                                .str()
+                                .c_str());
+    }
+    return orig.get();
+  }
+
+  /// Binds the Python module objects to functions of this class.
+  static void bind(nb::module_ &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
+    cls.def_static(
+        "isinstance",
+        [](PyValue &otherValue) -> bool {
+          return DerivedTy::isaFunction(otherValue);
+        },
+        nb::arg("other_value"));
+    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+            [](DerivedTy &self) { return self.maybeDownCast(); });
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+} // namespace
+
+/// Python wrapper for MlirOpResult.
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+  static constexpr const char *pyClassName = "OpResult";
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro("owner", [](PyOpResult &self) {
+      assert(
+          mlirOperationEqual(self.getParentOperation()->get(),
+                             mlirOpResultGetOwner(self.get())) &&
+          "expected the owner of the value in Python to match that in the IR");
+      return self.getParentOperation().getObject();
+    });
+    c.def_prop_ro("result_number", [](PyOpResult &self) {
+      return mlirOpResultGetResultNumber(self.get());
+    });
+  }
+};
+
+/// Returns the list of types of the values held by container.
+template <typename Container>
+static std::vector<MlirType> getValueTypes(Container &container,
+                                           PyMlirContextRef &context) {
+  std::vector<MlirType> result;
+  result.reserve(container.size());
+  for (int i = 0, e = container.size(); i < e; ++i) {
+    result.push_back(mlirValueGetType(container.getElement(i).get()));
+  }
+  return result;
+}
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) result list is associated
+/// with the operation whose results these are, and thus extends the lifetime of
+/// this operation.
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+public:
+  static constexpr const char *pyClassName = "OpResultList";
+  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
+
+  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+                 intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumResults(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
+
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro("types", [](PyOpResultList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+    c.def_prop_ro("owner", [](PyOpResultList &self) {
+      return self.operation->createOpView();
+    });
+  }
+
+  PyOperationRef &getOperation() { return operation; }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpResultList, PyOpResult>;
+
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumResults(operation->get());
+  }
+
+  PyOpResult getRawElement(intptr_t index) {
+    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+    return PyOpResult(value);
+  }
+
+  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpResultList(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
 //------------------------------------------------------------------------------
 // PyOpView
 //------------------------------------------------------------------------------
@@ -1730,6 +1854,40 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
   }
 }
 
+static MlirValue getUniqueResult(MlirOperation operation) {
+  auto numResults = mlirOperationGetNumResults(operation);
+  if (numResults != 1) {
+    auto name = mlirIdentifierStr(mlirOperationGetName(operation));
+    throw nb::value_error((Twine("Cannot call .result on operation ") +
+                           StringRef(name.data, name.length) + " which has " +
+                           Twine(numResults) +
+                           " results (it is only valid for operations with a "
+                           "single result)")
+                              .str()
+                              .c_str());
+  }
+  return mlirOperationGetResult(operation, 0);
+}
+
+static MlirValue getOpResultOrValue(nb::handle operand) {
+  if (operand.is_none()) {
+    throw nb::value_error("contained a None item");
+  }
+  PyOperationBase *op;
+  if (nb::try_cast<PyOperationBase *>(operand, op)) {
+    return getUniqueResult(op->getOperation());
+  }
+  PyOpResultList *opResultList;
+  if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
+    return getUniqueResult(opResultList->getOperation()->get());
+  }
+  PyValue *value;
+  if (nb::try_cast<PyValue *>(operand, value)) {
+    return value->get();
+  }
+  throw nb::value_error("is not a Value");
+}
+
 nb::object PyOpView::buildGeneric(
     std::string_view name, std::tuple<int, bool> opRegionSpec,
     nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
@@ -1780,16 +1938,14 @@ nb::object PyOpView::buildGeneric(
   }
 
   // Unpack operands.
-  std::vector<PyValue *> operands;
+  llvm::SmallVector<MlirValue, 4> operands;
   operands.reserve(operands.size());
   if (operandSegmentSpecObj.is_none()) {
     // Non-sized operand unpacking.
     for (const auto &it : llvm::enumerate(operandList)) {
       try {
-        operands.push_back(nb::cast<PyValue *>(it.value()));
-        if (!operands.back())
-          throw nb::cast_error();
-      } catch (nb::cast_error &err) {
+        operands.push_back(getOpResultOrValue(it.value()));
+      } catch (nb::builtin_exception &err) {
         throw nb::value_error((llvm::Twine("Operand ") +
                                llvm::Twine(it.index()) + " of operation \"" +
                                name + "\" must be a Value (" + err.what() + ")")
@@ -1815,29 +1971,31 @@ nb::object PyOpView::buildGeneric(
       int segmentSpec = std::get<1>(it.value());
       if (segmentSpec == 1 || segmentSpec == 0) {
         // Unpack unary element.
-        try {
-          auto *operandValue = nb::cast<PyValue *>(std::get<0>(it.value()));
-          if (operandValue) {
-            operands.push_back(operandValue);
-            operandSegmentLengths.push_back(1);
-          } else if (segmentSpec == 0) {
-            // Allowed to be optional.
-            operandSegmentLengths.push_back(0);
-          } else {
-            throw nb::value_error(
-                (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
-                 " of operation \"" + name +
-                 "\" must be a Value (was None and operand is not optional)")
-                    .str()
-                    .c_str());
+        auto &operand = std::get<0>(it.value());
+        if (!operand.is_none()) {
+          try {
+
+            operands.push_back(getOpResultOrValue(operand));
+          } catch (nb::builtin_exception &err) {
+            throw nb::value_error((llvm::Twine("Operand ") +
+                                   llvm::Twine(it.index()) +
+                                   " of operation \"" + name +
+                                   "\" must be a Value (" + err.what() + ")")
+                                      .str()
+                                      .c_str());
           }
-        } catch (nb::cast_error &err) {
-          throw nb::value_error((llvm::Twine("Operand ") +
-                                 llvm::Twine(it.index()) + " of operation \"" +
-                                 name + "\" must be a Value (" + err.what() +
-                                 ")")
-                                    .str()
-                                    .c_str());
+
+          operandSegmentLengths.push_back(1);
+        } else if (segmentSpec == 0) {
+          // Allowed to be optional.
+          operandSegmentLengths.push_back(0);
+        } else {
+          throw nb::value_error(
+              (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+               " of operation \"" + name +
+               "\" must be a Value (was None and operand is not optional)")
+                  .str()
+                  .c_str());
         }
       } else if (segmentSpec == -1) {
         // Unpack sequence by appending.
@@ -1849,10 +2007,7 @@ nb::object PyOpView::buildGeneric(
             // Unpack the list.
             auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
             for (nb::handle segmentItem : segment) {
-              operands.push_back(nb::cast<PyValue *>(segmentItem));
-              if (!operands.back()) {
-                throw nb::type_error("contained a None item");
-              }
+              operands.push_back(getOpResultOrValue(segmentItem));
             }
             operandSegmentLengths.push_back(nb::len(segment));
           }
@@ -2266,57 +2421,6 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
 }
 
 namespace {
-/// CRTP base class for Python MLIR values that subclass Value and should be
-/// castable from it. The value hierarchy is one level deep and is not supposed
-/// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
-public:
-  // Derived classes must define statics for:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  // and redefine bindDerived.
-  using ClassTy = nb::class_<DerivedTy, PyValue>;
-  using IsAFunctionTy = bool (*)(MlirValue);
-
-  PyConcreteValue() = default;
-  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
-      : PyValue(operationRef, value) {}
-  PyConcreteValue(PyValue &orig)
-      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
-
-  /// Attempts to cast the original value to the derived type and throws on
-  /// type mismatches.
-  static MlirValue castFrom(PyValue &orig) {
-    if (!DerivedTy::isaFunction(orig.get())) {
-      auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
-      throw nb::value_error((Twine("Cannot cast value to ") +
-                             DerivedTy::pyClassName + " (from " + origRepr +
-                             ")")
-                                .str()
-                                .c_str());
-    }
-    return orig.get();
-  }
-
-  /// Binds the Python module objects to functions of this class.
-  static void bind(nb::module_ &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
-    cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
-    cls.def_static(
-        "isinstance",
-        [](PyValue &otherValue) -> bool {
-          return DerivedTy::isaFunction(otherValue);
-        },
-        nb::arg("other_value"));
-    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
-            [](DerivedTy &self) { return self.maybeDownCast(); });
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Implemented by derived classes to add methods to the Python subclass.
-  static void bindDerived(ClassTy &m) {}
-};
 
 /// Python wrapper for MlirBlockArgument.
 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
@@ -2342,39 +2446,6 @@ class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
   }
 };
 
-/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
-  static constexpr const char *pyClassName = "OpResult";
-  using PyConcreteValue::PyConcreteValue;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_prop_ro("owner", [](PyOpResult &self) {
-      assert(
-          mlirOperationEqual(self.getParentOperation()->get(),
-                             mlirOpResultGetOwner(self.get())) &&
-          "expected the owner of the value in Python to match that in the IR");
-      return self.getParentOperation().getObject();
-    });
-    c.def_prop_ro("result_number", [](PyOpResult &self) {
-      return mlirOpResultGetResultNumber(self.get());
-    });
-  }
-};
-
-/// Returns the list of types of the values held by container.
-template <typename Container>
-static std::vector<MlirType> getValueTypes(Container &container,
-                                           PyMlirContextRef &context) {
-  std::vector<MlirType> result;
-  result.reserve(container.size());
-  for (int i = 0, e = container.size(); i < e; ++i) {
-    result.push_back(mlirValueGetType(container.getElement(i).get()));
-  }
-  return result;
-}
-
 /// A list of block arguments. Internally, these are stored as consecutive
 /// elements, random access is cheap. The argument list is associated with the
 /// operation that contains the block (detached blocks are not allowed in
@@ -2481,53 +2552,6 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
   PyOperationRef operation;
 };
 
-/// A list of operation results. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) result list is associated
-/// with the operation whose results these are, and thus extends the lifetime of
-/// this operation.
-class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
-public:
-  static constexpr const char *pyClassName = "OpResultList";
-  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
-
-  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
-                 intptr_t length = -1, intptr_t step = 1)
-      : Sliceable(startIndex,
-                  length == -1 ? mlirOperationGetNumResults(operation->get())
-                               : length,
-                  step),
-        operation(std::move(operation)) {}
-
-  static void bindDerived(ClassTy &c) {
-    c.def_prop_ro("types", [](PyOpResultList &self) {
-      return getValueTypes(self, self.operation->getContext());
-    });
-    c.def_prop_ro("owner", [](PyOpResultList &self) {
-      return self.operation->createOpView();
-    });
-  }
-
-private:
-  /// Give the parent CRTP class access to hook implementations below.
-  friend class Sliceable<PyOpResultList, PyOpResult>;
-
-  intptr_t getRawNumElements() {
-    operation->checkValid();
-    return mlirOperationGetNumResults(operation->get());
-  }
-
-  PyOpResult getRawElement(intptr_t index) {
-    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
-    return PyOpResult(value);
-  }
-
-  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
-    return PyOpResultList(operation, startIndex, length, step);
-  }
-
-  PyOperationRef operation;
-};
-
 /// A list of operation successors. Internally, these are stored as consecutive
 /// elements, random access is cheap. The (returned) successor list is
 /// associated with the operation whose successors these are, and thus extends
@@ -3108,20 +3132,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "result",
           [](PyOperationBase &self) {
             auto &operation = self.getOperation();
-            auto numResults = mlirOperationGetNumResults(operation);
-            if (numResults != 1) {
-              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
-              throw nb::value_error(
-                  (Twine("Cannot call .result on operation ") +
-                   StringRef(name.data, name.length) + " which has " +
-                   Twine(numResults) +
-                   " results (it is only valid for operations with a "
-                   "single result)")
-                      .str()
-                      .c_str());
-            }
-            return PyOpResult(operation.getRef(),
-                              mlirOperationGetResult(operation, 0))
+            return PyOpResult(operation.getRef(), getUniqueResult(operation))
                 .maybeDownCast();
           },
           "Shortcut to get an op result if it has only one (throws an error "
@@ -3218,14 +3229,36 @@ void mlir::python::populateIRCore(nb::module_ &m) {
            nb::arg("walk_order") = MlirWalkPostOrder);
 
   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
-      .def_static("create", &PyOperation::create, nb::arg("name"),
-                  nb::arg("results").none() = nb::none(),
-                  nb::arg("operands").none() = nb::none(),
-      ...
[truncated]

@github-actions
Copy link

github-actions bot commented Jan 22, 2025

✅ With the latest revision this PR passed the Python code formatter.

@github-actions
Copy link

github-actions bot commented Jan 22, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, thanks

This logic is in the critical path for constructing an operation from
Python. It is faster to compute this in C++ than it is in Python, and it is a
minor change to do this.

This change also alters the API contract of
_ods_common.get_op_results_or_values to avoid calling
get_op_result_or_value on each element of a sequence, since the C++ code
will now do this.

Most of the diff here is simply reordering the code in IRCore.cpp.
@jpienaar jpienaar merged commit acde3f7 into llvm:main Jan 24, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants