diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h similarity index 99% rename from mlir/lib/Bindings/Python/Globals.h rename to mlir/include/mlir/Bindings/Python/Globals.h index 71a051cb3d9f5..9e3b48d7b2e68 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/include/mlir/Bindings/Python/Globals.h @@ -15,8 +15,8 @@ #include #include -#include "NanobindUtils.h" #include "mlir-c/IR.h" +#include "mlir/Bindings/Python/NanobindUtils.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h new file mode 100644 index 0000000000000..8892437ac3f95 --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h @@ -0,0 +1,470 @@ +//===- IRAttributes.h - Attribute Interfaces +//----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H +#define MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H + +#include "mlir/Bindings/Python/IRModule.h" + +namespace mlir::python { + +class PyAffineMapAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; + static constexpr const char *pyClassName = "AffineMapAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirAffineMapAttrGetTypeID; + + static void bindDerived(ClassTy &c); +}; + +class PyIntegerSetAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet; + static constexpr const char *pyClassName = "IntegerSetAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIntegerSetAttrGetTypeID; + + static void bindDerived(ClassTy &c); +}; + +/// A python-wrapped dense array attribute with an element type and a derived +/// implementation class. +template +class PyDenseArrayAttribute : public PyConcreteAttribute { +public: + using PyConcreteAttribute::PyConcreteAttribute; + + /// Iterator over the integer elements of a dense array. + class PyDenseArrayIterator { + public: + PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} + + /// Return a copy of the iterator. + PyDenseArrayIterator dunderIter(); + + /// Return the next element. + EltTy dunderNext(); + + /// Bind the iterator class. + static void bind(nanobind::module_ &m); + + private: + /// The referenced dense array attribute. + PyAttribute attr; + /// The next index to read. + int nextIndex = 0; + }; + + /// Get the element at the given index. + EltTy getItem(intptr_t i); + + /// Bind the attribute class. + static void bindDerived(typename PyConcreteAttribute::ClassTy &c); + +private: + static DerivedT getAttribute(const std::vector &values, + PyMlirContextRef ctx); +}; + +/// Instantiate the python dense array classes. +struct PyDenseBoolArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; + static constexpr auto getAttribute = mlirDenseBoolArrayGet; + static constexpr auto getElement = mlirDenseBoolArrayGetElement; + static constexpr const char *pyClassName = "DenseBoolArrayAttr"; + static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; + +struct PyDenseI8ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; + static constexpr auto getAttribute = mlirDenseI8ArrayGet; + static constexpr auto getElement = mlirDenseI8ArrayGetElement; + static constexpr const char *pyClassName = "DenseI8ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; + +struct PyDenseI16ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; + static constexpr auto getAttribute = mlirDenseI16ArrayGet; + static constexpr auto getElement = mlirDenseI16ArrayGetElement; + static constexpr const char *pyClassName = "DenseI16ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; + +struct PyDenseI32ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; + static constexpr auto getAttribute = mlirDenseI32ArrayGet; + static constexpr auto getElement = mlirDenseI32ArrayGetElement; + static constexpr const char *pyClassName = "DenseI32ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; + +struct PyDenseI64ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; + static constexpr auto getAttribute = mlirDenseI64ArrayGet; + static constexpr auto getElement = mlirDenseI64ArrayGetElement; + static constexpr const char *pyClassName = "DenseI64ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; + +struct PyDenseF32ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; + static constexpr auto getAttribute = mlirDenseF32ArrayGet; + static constexpr auto getElement = mlirDenseF32ArrayGetElement; + static constexpr const char *pyClassName = "DenseF32ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; + +struct PyDenseF64ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; + static constexpr auto getAttribute = mlirDenseF64ArrayGet; + static constexpr auto getElement = mlirDenseF64ArrayGetElement; + static constexpr const char *pyClassName = "DenseF64ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; + +class PyArrayAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; + static constexpr const char *pyClassName = "ArrayAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirArrayAttrGetTypeID; + + class PyArrayAttributeIterator { + public: + PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} + + PyArrayAttributeIterator &dunderIter(); + + MlirAttribute dunderNext(); + + static void bind(nanobind::module_ &m); + + private: + PyAttribute attr; + int nextIndex = 0; + }; + + MlirAttribute getItem(intptr_t i); + + static void bindDerived(ClassTy &c); +}; + +/// Float Point Attribute subclass - FloatAttr. +class PyFloatAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; + static constexpr const char *pyClassName = "FloatAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloatAttrGetTypeID; + + static void bindDerived(ClassTy &c); +}; + +/// Integer Attribute subclass - IntegerAttr. +class PyIntegerAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; + static constexpr const char *pyClassName = "IntegerAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c); + +private: + static int64_t toPyInt(PyIntegerAttribute &self); +}; + +/// Bool Attribute subclass - BoolAttr. +class PyBoolAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; + static constexpr const char *pyClassName = "BoolAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c); +}; + +class PySymbolRefAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; + static constexpr const char *pyClassName = "SymbolRefAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static MlirAttribute fromList(const std::vector &symbols, + PyMlirContext &context); + + static void bindDerived(ClassTy &c); +}; + +class PyFlatSymbolRefAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; + static constexpr const char *pyClassName = "FlatSymbolRefAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c); +}; + +class PyOpaqueAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; + static constexpr const char *pyClassName = "OpaqueAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueAttrGetTypeID; + + static void bindDerived(ClassTy &c); +}; + +class PyStringAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; + static constexpr const char *pyClassName = "StringAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStringAttrGetTypeID; + + static void bindDerived(ClassTy &c); +}; + +struct nb_buffer_info { + void *ptr = nullptr; + ssize_t itemsize = 0; + ssize_t size = 0; + const char *format = nullptr; + ssize_t ndim = 0; + SmallVector shape; + SmallVector strides; + bool readonly = false; + + nb_buffer_info( + void *ptr, ssize_t itemsize, const char *format, ssize_t ndim, + SmallVector shape_in, SmallVector strides_in, + bool readonly = false, + std::unique_ptr owned_view_in = + std::unique_ptr(nullptr, nullptr)) + : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim), + shape(std::move(shape_in)), strides(std::move(strides_in)), + readonly(readonly), owned_view(std::move(owned_view_in)) { + size = 1; + for (ssize_t i = 0; i < ndim; ++i) { + size *= shape[i]; + } + } + + explicit nb_buffer_info(Py_buffer *view) + : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim, + {view->shape, view->shape + view->ndim}, + // TODO(phawkins): check for null strides + {view->strides, view->strides + view->ndim}, + view->readonly != 0, + std::unique_ptr( + view, PyBuffer_Release)) {} + + nb_buffer_info(const nb_buffer_info &) = delete; + nb_buffer_info(nb_buffer_info &&) = default; + nb_buffer_info &operator=(const nb_buffer_info &) = delete; + nb_buffer_info &operator=(nb_buffer_info &&) = default; + +private: + std::unique_ptr owned_view; +}; + +class nb_buffer : public nanobind::object { + NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer); + + nb_buffer_info request() const; +}; + +// TODO: Support construction of string elements. +class PyDenseElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; + static constexpr const char *pyClassName = "DenseElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static PyDenseElementsAttribute + getFromList(const nanobind::list &attributes, + std::optional explicitType, + DefaultingPyMlirContext contextWrapper); + + static PyDenseElementsAttribute + getFromBuffer(const nb_buffer &array, bool signless, + const std::optional &explicitType, + std::optional> explicitShape, + DefaultingPyMlirContext contextWrapper); + + static PyDenseElementsAttribute getSplat(const PyType &shapedType, + PyAttribute &elementAttr); + + intptr_t dunderLen(); + + std::unique_ptr accessBuffer(); + + static void bindDerived(ClassTy &c); + + static PyType_Slot slots[]; + +private: + static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags); + static void bf_releasebuffer(PyObject *, Py_buffer *buffer); + + static bool isUnsignedIntegerFormat(std::string_view format); + + static bool isSignedIntegerFormat(std::string_view format); + + static MlirType + getShapedType(std::optional bulkLoadElementType, + std::optional> explicitShape, + Py_buffer &view); + + static MlirAttribute getAttributeFromBuffer( + Py_buffer &view, bool signless, std::optional explicitType, + const std::optional> &explicitShape, + MlirContext &context); + + // There is a complication for boolean numpy arrays, as numpy represents + // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 + // booleans per byte. + static MlirAttribute getBitpackedAttributeFromBooleanBuffer( + Py_buffer &view, std::optional> explicitShape, + MlirContext &context); + + // This does the opposite transformation of + // `getBitpackedAttributeFromBooleanBuffer` + std::unique_ptr getBooleanBufferFromBitpackedAttribute(); + + template + std::unique_ptr + bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr); +}; // namespace + +class PyDenseResourceElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = + mlirAttributeIsADenseResourceElements; + static constexpr const char *pyClassName = "DenseResourceElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static PyDenseResourceElementsAttribute + getFromBuffer(const nb_buffer &buffer, const std::string &name, + const PyType &type, std::optional alignment, + bool isMutable, DefaultingPyMlirContext contextWrapper); + + static void bindDerived(ClassTy &c); +}; + +class PyDictAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; + static constexpr const char *pyClassName = "DictAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirDictionaryAttrGetTypeID; + + intptr_t dunderLen(); + + bool dunderContains(const std::string &name); + + static void bindDerived(ClassTy &c); +}; + +/// Refinement of the PyDenseElementsAttribute for attributes containing +/// integer (and boolean) values. Supports element access. +class PyDenseIntElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; + static constexpr const char *pyClassName = "DenseIntElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + /// Returns the element at the given linear position. Asserts if the index + /// is out of range. + nanobind::object dunderGetItem(intptr_t pos); + + static void bindDerived(ClassTy &c); +}; + +/// Refinement of PyDenseElementsAttribute for attributes containing +/// floating-point values. Supports element access. +class PyDenseFPElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; + static constexpr const char *pyClassName = "DenseFPElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + nanobind::float_ dunderGetItem(intptr_t pos); + + static void bindDerived(ClassTy &c); +}; + +class PyTypeAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; + static constexpr const char *pyClassName = "TypeAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTypeAttrGetTypeID; + + static void bindDerived(ClassTy &c); +}; + +/// Unit Attribute subclass. Unit attributes don't have values. +class PyUnitAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; + static constexpr const char *pyClassName = "UnitAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnitAttrGetTypeID; + + static void bindDerived(ClassTy &c); +}; + +/// Strided layout attribute subclass. +class PyStridedLayoutAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; + static constexpr const char *pyClassName = "StridedLayoutAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStridedLayoutAttrGetTypeID; + + static void bindDerived(ClassTy &c); +}; +} // namespace mlir::python + +#endif // MLIR_BINDINGS_PYTHON_IRATTRIBUTES_H diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/include/mlir/Bindings/Python/IRModule.h similarity index 96% rename from mlir/lib/Bindings/Python/IRModule.h rename to mlir/include/mlir/Bindings/Python/IRModule.h index 0cc0459ebc9a0..b3690f3babfb4 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/include/mlir/Bindings/Python/IRModule.h @@ -15,16 +15,16 @@ #include #include -#include "Globals.h" -#include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Transforms.h" +#include "mlir/Bindings/Python/Globals.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/NanobindUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/ThreadPool.h" @@ -1268,6 +1268,56 @@ class PySymbolTable { MlirSymbolTable symbolTable; }; +/// CRTP base class for Python MLIR values that subclass Value and should be +/// castable from it. The value hierarchy is one level deep and is not supposed +/// to accommodate other levels unless core MLIR changes. +template +class PyConcreteValue : public PyValue { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + // and redefine bindDerived. + using ClassTy = nanobind::class_; + using IsAFunctionTy = bool (*)(MlirValue); + + PyConcreteValue() = default; + PyConcreteValue(PyOperationRef operationRef, MlirValue value) + : PyValue(operationRef, value) {} + PyConcreteValue(PyValue &orig) + : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} + + /// Attempts to cast the original value to the derived type and throws on + /// type mismatches. + static MlirValue castFrom(PyValue &orig); + + /// Binds the Python module objects to functions of this class. + static void bind(nanobind::module_ &m); + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m); +}; + +/// Python wrapper for MlirOpResult. +class PyOpResult : public PyConcreteValue { +public: + static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; + static constexpr const char *pyClassName = "OpResult"; + using PyConcreteValue::PyConcreteValue; + + static void bindDerived(ClassTy &c); +}; + +/// Python wrapper for MlirBlockArgument. +class PyBlockArgument : public PyConcreteValue { +public: + static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; + static constexpr const char *pyClassName = "BlockArgument"; + using PyConcreteValue::PyConcreteValue; + + static void bindDerived(ClassTy &c); +}; + /// Custom exception that allows access to error diagnostic information. This is /// converted to the `ir.MLIRError` python exception when thrown. struct MLIRError { diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h index ba9642cf2c6a2..60d21fd2f2fa0 100644 --- a/mlir/include/mlir/Bindings/Python/IRTypes.h +++ b/mlir/include/mlir/Bindings/Python/IRTypes.h @@ -9,12 +9,13 @@ #ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H #define MLIR_BINDINGS_PYTHON_IRTYPES_H -#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir-c/BuiltinTypes.h" +#include "mlir/Bindings/Python/IRModule.h" -namespace mlir { +namespace mlir::python { /// Shaped Type Interface - ShapedType -class PyShapedType : public python::PyConcreteType { +class PyShapedType : public PyConcreteType { public: static const IsAFunctionTy isaFunction; static constexpr const char *pyClassName = "ShapedType"; @@ -26,6 +27,367 @@ class PyShapedType : public python::PyConcreteType { void requireHasRank(); }; -} // namespace mlir +class PyIntegerType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIntegerTypeGetTypeID; + static constexpr const char *pyClassName = "IntegerType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Index Type subclass - IndexType. +class PyIndexType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIndexTypeGetTypeID; + static constexpr const char *pyClassName = "IndexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +class PyFloatType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; + static constexpr const char *pyClassName = "FloatType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float4E2M1FNType. +class PyFloat4E2M1FNType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat4E2M1FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float4E2M1FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float6E2M3FNType. +class PyFloat6E2M3FNType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat6E2M3FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float6E2M3FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float6E3M2FNType. +class PyFloat6E3M2FNType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat6E3M2FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float6E3M2FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E4M3FNType. +class PyFloat8E4M3FNType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E5M2Type. +class PyFloat8E5M2Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2TypeGetTypeID; + static constexpr const char *pyClassName = "Float8E5M2Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E4M3Type. +class PyFloat8E4M3Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3TypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E4M3FNUZ. +class PyFloat8E4M3FNUZType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNUZTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E4M3B11FNUZ. +class PyFloat8E4M3B11FNUZType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3B11FNUZTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E5M2FNUZ. +class PyFloat8E5M2FNUZType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2FNUZTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E5M2FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E3M4Type. +class PyFloat8E3M4Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E3M4TypeGetTypeID; + static constexpr const char *pyClassName = "Float8E3M4Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E8M0FNUType. +class PyFloat8E8M0FNUType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E8M0FNUTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E8M0FNUType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - BF16Type. +class PyBF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirBFloat16TypeGetTypeID; + static constexpr const char *pyClassName = "BF16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - F16Type. +class PyF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat16TypeGetTypeID; + static constexpr const char *pyClassName = "F16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - TF32Type. +class PyTF32Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloatTF32TypeGetTypeID; + static constexpr const char *pyClassName = "FloatTF32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - F32Type. +class PyF32Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat32TypeGetTypeID; + static constexpr const char *pyClassName = "F32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - F64Type. +class PyF64Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat64TypeGetTypeID; + static constexpr const char *pyClassName = "F64Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// None Type subclass - NoneType. +class PyNoneType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirNoneTypeGetTypeID; + static constexpr const char *pyClassName = "NoneType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Complex Type subclass - ComplexType. +class PyComplexType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirComplexTypeGetTypeID; + static constexpr const char *pyClassName = "ComplexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Vector Type subclass - VectorType. +class PyVectorType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirVectorTypeGetTypeID; + static constexpr const char *pyClassName = "VectorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); + +private: + static PyVectorType get(std::vector shape, PyType &elementType, + std::optional scalable, + std::optional> scalableDims, + DefaultingPyLocation loc); +}; + +/// Ranked Tensor Type subclass - RankedTensorType. +class PyRankedTensorType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirRankedTensorTypeGetTypeID; + static constexpr const char *pyClassName = "RankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Unranked Tensor Type subclass - UnrankedTensorType. +class PyUnrankedTensorType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedTensorTypeGetTypeID; + static constexpr const char *pyClassName = "UnrankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Ranked MemRef Type subclass - MemRefType. +class PyMemRefType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirMemRefTypeGetTypeID; + static constexpr const char *pyClassName = "MemRefType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Unranked MemRef Type subclass - UnrankedMemRefType. +class PyUnrankedMemRefType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedMemRefTypeGetTypeID; + static constexpr const char *pyClassName = "UnrankedMemRefType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Tuple Type subclass - TupleType. +class PyTupleType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTupleTypeGetTypeID; + static constexpr const char *pyClassName = "TupleType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Function type. +class PyFunctionType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFunctionTypeGetTypeID; + static constexpr const char *pyClassName = "FunctionType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Opaque Type subclass - OpaqueType. +class PyOpaqueType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueTypeGetTypeID; + static constexpr const char *pyClassName = "OpaqueType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +} // namespace mlir::python #endif // MLIR_BINDINGS_PYTHON_IRTYPES_H diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 1428d5ccf00f4..35cc52af3334f 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -24,11 +24,11 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" +#include "llvm/ADT/Twine.h" // clang-format off #include "mlir/Bindings/Python/Nanobind.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind. // clang-format on -#include "llvm/ADT/Twine.h" // Raw CAPI type casters need to be declared before use, so always include them // first. diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h similarity index 100% rename from mlir/lib/Bindings/Python/NanobindUtils.h rename to mlir/include/mlir/Bindings/Python/NanobindUtils.h diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp index cab4219fea72b..e0a2809d4a92c 100644 --- a/mlir/lib/Bindings/Python/DialectSMT.cpp +++ b/mlir/lib/Bindings/Python/DialectSMT.cpp @@ -6,8 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "NanobindUtils.h" - #include "mlir-c/Dialect/SMT.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" @@ -15,6 +13,7 @@ #include "mlir/Bindings/Python/Diagnostics.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/NanobindUtils.h" namespace nb = nanobind; diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index a6499c952df6e..7b7bec4df3d00 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -13,19 +13,22 @@ #include #include -#include "IRModule.h" -#include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. #include "mlir-c/IntegerSet.h" -#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/IRModule.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" +// clang-format off +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind. +// clang-format on + namespace nb = nanobind; using namespace mlir; using namespace mlir::python; @@ -707,25 +710,24 @@ void mlir::python::populateIRAffine(nb::module_ &m) { [](PyAffineMap &self) { return static_cast(llvm::hash_value(self.get().ptr)); }) - .def_static("compress_unused_symbols", - [](const nb::list &affineMaps, - DefaultingPyMlirContext context) { - SmallVector maps; - pyListToVector( - affineMaps, maps, "attempting to create an AffineMap"); - std::vector compressed(affineMaps.size()); - auto populate = [](void *result, intptr_t idx, - MlirAffineMap m) { - static_cast(result)[idx] = (m); - }; - mlirAffineMapCompressUnusedSymbols( - maps.data(), maps.size(), compressed.data(), populate); - std::vector res; - res.reserve(compressed.size()); - for (auto m : compressed) - res.emplace_back(context->getRef(), m); - return res; - }) + .def_static( + "compress_unused_symbols", + [](const nb::list &affineMaps, DefaultingPyMlirContext context) { + SmallVector maps; + pyListToVector( + affineMaps, maps, "attempting to create an AffineMap"); + std::vector compressed(affineMaps.size()); + auto populate = [](void *result, intptr_t idx, MlirAffineMap m) { + static_cast(result)[idx] = (m); + }; + mlirAffineMapCompressUnusedSymbols(maps.data(), maps.size(), + compressed.data(), populate); + std::vector res; + res.reserve(compressed.size()); + for (auto m : compressed) + res.emplace_back(context->getRef(), m); + return res; + }) .def_prop_ro( "context", [](PyAffineMap &self) { return self.getContext().getObject(); }, diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index f2eafa7c2d45c..d3370b4a5300f 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -12,12 +12,12 @@ #include #include -#include "IRModule.h" -#include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" +#include "mlir/Bindings/Python/IRAttributes.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/NanobindUtils.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/raw_ostream.h" @@ -122,64 +122,6 @@ subsequent processing. )"; namespace { - -struct nb_buffer_info { - void *ptr = nullptr; - ssize_t itemsize = 0; - ssize_t size = 0; - const char *format = nullptr; - ssize_t ndim = 0; - SmallVector shape; - SmallVector strides; - bool readonly = false; - - nb_buffer_info( - void *ptr, ssize_t itemsize, const char *format, ssize_t ndim, - SmallVector shape_in, SmallVector strides_in, - bool readonly = false, - std::unique_ptr owned_view_in = - std::unique_ptr(nullptr, nullptr)) - : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim), - shape(std::move(shape_in)), strides(std::move(strides_in)), - readonly(readonly), owned_view(std::move(owned_view_in)) { - size = 1; - for (ssize_t i = 0; i < ndim; ++i) { - size *= shape[i]; - } - } - - explicit nb_buffer_info(Py_buffer *view) - : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim, - {view->shape, view->shape + view->ndim}, - // TODO(phawkins): check for null strides - {view->strides, view->strides + view->ndim}, - view->readonly != 0, - std::unique_ptr( - view, PyBuffer_Release)) {} - - nb_buffer_info(const nb_buffer_info &) = delete; - nb_buffer_info(nb_buffer_info &&) = default; - nb_buffer_info &operator=(const nb_buffer_info &) = delete; - nb_buffer_info &operator=(nb_buffer_info &&) = default; - -private: - std::unique_ptr owned_view; -}; - -class nb_buffer : public nb::object { - NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer); - - nb_buffer_info request() const { - int flags = PyBUF_STRIDES | PyBUF_FORMAT; - auto *view = new Py_buffer(); - if (PyObject_GetBuffer(ptr(), view, flags) != 0) { - delete view; - throw nb::python_error(); - } - return nb_buffer_info(view); - } -}; - template struct nb_format_descriptor {}; @@ -236,47 +178,6 @@ static MlirStringRef toMlirStringRef(const nb::bytes &s) { return mlirStringRefCreate(static_cast(s.data()), s.size()); } -class PyAffineMapAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; - static constexpr const char *pyClassName = "AffineMapAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirAffineMapAttrGetTypeID; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyAffineMap &affineMap) { - MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); - return PyAffineMapAttribute(affineMap.getContext(), attr); - }, - nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); - c.def_prop_ro("value", mlirAffineMapAttrGetValue, - "Returns the value of the AffineMap attribute"); - } -}; - -class PyIntegerSetAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet; - static constexpr const char *pyClassName = "IntegerSetAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirIntegerSetAttrGetTypeID; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyIntegerSet &integerSet) { - MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); - return PyIntegerSetAttribute(integerSet.getContext(), attr); - }, - nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); - } -}; - template static T pyTryCast(nb::handle object) { try { @@ -294,1012 +195,879 @@ static T pyTryCast(nb::handle object) { } } -/// A python-wrapped dense array attribute with an element type and a derived -/// implementation class. -template -class PyDenseArrayAttribute : public PyConcreteAttribute { -public: - using PyConcreteAttribute::PyConcreteAttribute; - - /// Iterator over the integer elements of a dense array. - class PyDenseArrayIterator { - public: - PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} - - /// Return a copy of the iterator. - PyDenseArrayIterator dunderIter() { return *this; } - - /// Return the next element. - EltTy dunderNext() { - // Throw if the index has reached the end. - if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) - throw nb::stop_iteration(); - return DerivedT::getElement(attr.get(), nextIndex++); - } - - /// Bind the iterator class. - static void bind(nb::module_ &m) { - nb::class_(m, DerivedT::pyIteratorName) - .def("__iter__", &PyDenseArrayIterator::dunderIter) - .def("__next__", &PyDenseArrayIterator::dunderNext); - } - - private: - /// The referenced dense array attribute. - PyAttribute attr; - /// The next index to read. - int nextIndex = 0; - }; - - /// Get the element at the given index. - EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } - - /// Bind the attribute class. - static void bindDerived(typename PyConcreteAttribute::ClassTy &c) { - // Bind the constructor. - if constexpr (std::is_same_v) { - c.def_static( - "get", - [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) { - std::vector values; - for (nb::handle py_value : py_values) { - int is_true = PyObject_IsTrue(py_value.ptr()); - if (is_true < 0) { - throw nb::python_error(); - } - values.push_back(is_true); - } - return getAttribute(values, ctx->getRef()); - }, - nb::arg("values"), nb::arg("context").none() = nb::none(), - "Gets a uniqued dense array attribute"); - } else { - c.def_static( - "get", - [](const std::vector &values, DefaultingPyMlirContext ctx) { - return getAttribute(values, ctx->getRef()); - }, - nb::arg("values"), nb::arg("context").none() = nb::none(), - "Gets a uniqued dense array attribute"); - } - // Bind the array methods. - c.def("__getitem__", [](DerivedT &arr, intptr_t i) { - if (i >= mlirDenseArrayGetNumElements(arr)) - throw nb::index_error("DenseArray index out of range"); - return arr.getItem(i); - }); - c.def("__len__", [](const DerivedT &arr) { - return mlirDenseArrayGetNumElements(arr); - }); - c.def("__iter__", - [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); - c.def("__add__", [](DerivedT &arr, const nb::list &extras) { - std::vector values; - intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); - values.reserve(numOldElements + nb::len(extras)); - for (intptr_t i = 0; i < numOldElements; ++i) - values.push_back(arr.getItem(i)); - for (nb::handle attr : extras) - values.push_back(pyTryCast(attr)); - return getAttribute(values, arr.getContext()); - }); - } +} // namespace -private: - static DerivedT getAttribute(const std::vector &values, - PyMlirContextRef ctx) { - if constexpr (std::is_same_v) { - std::vector intValues(values.begin(), values.end()); - MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(), - intValues.data()); - return DerivedT(ctx, attr); - } else { - MlirAttribute attr = - DerivedT::getAttribute(ctx->get(), values.size(), values.data()); - return DerivedT(ctx, attr); - } - } -}; +namespace mlir::python { +void PyAffineMapAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyAffineMap &affineMap) { + MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); + return PyAffineMapAttribute(affineMap.getContext(), attr); + }, + nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); + c.def_prop_ro("value", mlirAffineMapAttrGetValue, + "Returns the value of the AffineMap attribute"); +} -/// Instantiate the python dense array classes. -struct PyDenseBoolArrayAttribute - : public PyDenseArrayAttribute { - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; - static constexpr auto getAttribute = mlirDenseBoolArrayGet; - static constexpr auto getElement = mlirDenseBoolArrayGetElement; - static constexpr const char *pyClassName = "DenseBoolArrayAttr"; - static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; - using PyDenseArrayAttribute::PyDenseArrayAttribute; -}; -struct PyDenseI8ArrayAttribute - : public PyDenseArrayAttribute { - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; - static constexpr auto getAttribute = mlirDenseI8ArrayGet; - static constexpr auto getElement = mlirDenseI8ArrayGetElement; - static constexpr const char *pyClassName = "DenseI8ArrayAttr"; - static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; - using PyDenseArrayAttribute::PyDenseArrayAttribute; -}; -struct PyDenseI16ArrayAttribute - : public PyDenseArrayAttribute { - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; - static constexpr auto getAttribute = mlirDenseI16ArrayGet; - static constexpr auto getElement = mlirDenseI16ArrayGetElement; - static constexpr const char *pyClassName = "DenseI16ArrayAttr"; - static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; - using PyDenseArrayAttribute::PyDenseArrayAttribute; -}; -struct PyDenseI32ArrayAttribute - : public PyDenseArrayAttribute { - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; - static constexpr auto getAttribute = mlirDenseI32ArrayGet; - static constexpr auto getElement = mlirDenseI32ArrayGetElement; - static constexpr const char *pyClassName = "DenseI32ArrayAttr"; - static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; - using PyDenseArrayAttribute::PyDenseArrayAttribute; -}; -struct PyDenseI64ArrayAttribute - : public PyDenseArrayAttribute { - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; - static constexpr auto getAttribute = mlirDenseI64ArrayGet; - static constexpr auto getElement = mlirDenseI64ArrayGetElement; - static constexpr const char *pyClassName = "DenseI64ArrayAttr"; - static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; - using PyDenseArrayAttribute::PyDenseArrayAttribute; -}; -struct PyDenseF32ArrayAttribute - : public PyDenseArrayAttribute { - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; - static constexpr auto getAttribute = mlirDenseF32ArrayGet; - static constexpr auto getElement = mlirDenseF32ArrayGetElement; - static constexpr const char *pyClassName = "DenseF32ArrayAttr"; - static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; - using PyDenseArrayAttribute::PyDenseArrayAttribute; -}; -struct PyDenseF64ArrayAttribute - : public PyDenseArrayAttribute { - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; - static constexpr auto getAttribute = mlirDenseF64ArrayGet; - static constexpr auto getElement = mlirDenseF64ArrayGetElement; - static constexpr const char *pyClassName = "DenseF64ArrayAttr"; - static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; - using PyDenseArrayAttribute::PyDenseArrayAttribute; -}; +void PyIntegerSetAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyIntegerSet &integerSet) { + MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); + return PyIntegerSetAttribute(integerSet.getContext(), attr); + }, + nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); +} -class PyArrayAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; - static constexpr const char *pyClassName = "ArrayAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirArrayAttrGetTypeID; - - class PyArrayAttributeIterator { - public: - PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} - - PyArrayAttributeIterator &dunderIter() { return *this; } - - MlirAttribute dunderNext() { - // TODO: Throw is an inefficient way to stop iteration. - if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) - throw nb::stop_iteration(); - return mlirArrayAttrGetElement(attr.get(), nextIndex++); - } +template +typename PyDenseArrayAttribute::PyDenseArrayIterator +PyDenseArrayAttribute::PyDenseArrayIterator::dunderIter() { + return *this; +} - static void bind(nb::module_ &m) { - nb::class_(m, "ArrayAttributeIterator") - .def("__iter__", &PyArrayAttributeIterator::dunderIter) - .def("__next__", &PyArrayAttributeIterator::dunderNext); - } +template +EltTy PyDenseArrayAttribute::PyDenseArrayIterator::dunderNext() { + // Throw if the index has reached the end. + if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) + throw nb::stop_iteration(); + return DerivedT::getElement(attr.get(), nextIndex++); +} - private: - PyAttribute attr; - int nextIndex = 0; - }; +template +void PyDenseArrayAttribute::PyDenseArrayIterator::bind( + nb::module_ &m) { + nb::class_(m, DerivedT::pyIteratorName) + .def("__iter__", &PyDenseArrayIterator::dunderIter) + .def("__next__", &PyDenseArrayIterator::dunderNext); +} - MlirAttribute getItem(intptr_t i) { - return mlirArrayAttrGetElement(*this, i); - } +template +EltTy PyDenseArrayAttribute::getItem(intptr_t i) { + return DerivedT::getElement(*this, i); +} - static void bindDerived(ClassTy &c) { +template +void PyDenseArrayAttribute::bindDerived( + typename PyConcreteAttribute::ClassTy &c) { + // Bind the constructor. + if constexpr (std::is_same_v) { c.def_static( "get", - [](const nb::list &attributes, DefaultingPyMlirContext context) { - SmallVector mlirAttributes; - mlirAttributes.reserve(nb::len(attributes)); - for (auto attribute : attributes) { - mlirAttributes.push_back(pyTryCast(attribute)); + [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) { + std::vector values; + for (nb::handle py_value : py_values) { + int is_true = PyObject_IsTrue(py_value.ptr()); + if (is_true < 0) { + throw nb::python_error(); + } + values.push_back(is_true); } - MlirAttribute attr = mlirArrayAttrGet( - context->get(), mlirAttributes.size(), mlirAttributes.data()); - return PyArrayAttribute(context->getRef(), attr); - }, - nb::arg("attributes"), nb::arg("context").none() = nb::none(), - "Gets a uniqued Array attribute"); - c.def("__getitem__", - [](PyArrayAttribute &arr, intptr_t i) { - if (i >= mlirArrayAttrGetNumElements(arr)) - throw nb::index_error("ArrayAttribute index out of range"); - return arr.getItem(i); - }) - .def("__len__", - [](const PyArrayAttribute &arr) { - return mlirArrayAttrGetNumElements(arr); - }) - .def("__iter__", [](const PyArrayAttribute &arr) { - return PyArrayAttributeIterator(arr); - }); - c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) { - std::vector attributes; - intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); - attributes.reserve(numOldElements + nb::len(extras)); - for (intptr_t i = 0; i < numOldElements; ++i) - attributes.push_back(arr.getItem(i)); - for (nb::handle attr : extras) - attributes.push_back(pyTryCast(attr)); - MlirAttribute arrayAttr = mlirArrayAttrGet( - arr.getContext()->get(), attributes.size(), attributes.data()); - return PyArrayAttribute(arr.getContext(), arrayAttr); - }); - } -}; - -/// Float Point Attribute subclass - FloatAttr. -class PyFloatAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; - static constexpr const char *pyClassName = "FloatAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloatAttrGetTypeID; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &type, double value, DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); - if (mlirAttributeIsNull(attr)) - throw MLIRError("Invalid attribute", errors.take()); - return PyFloatAttribute(type.getContext(), attr); - }, - nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(), - "Gets an uniqued float point attribute associated to a type"); - c.def_static( - "get_f32", - [](double value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirFloatAttrDoubleGet( - context->get(), mlirF32TypeGet(context->get()), value); - return PyFloatAttribute(context->getRef(), attr); + return getAttribute(values, ctx->getRef()); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), - "Gets an uniqued float point attribute associated to a f32 type"); - c.def_static( - "get_f64", - [](double value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirFloatAttrDoubleGet( - context->get(), mlirF64TypeGet(context->get()), value); - return PyFloatAttribute(context->getRef(), attr); - }, - nb::arg("value"), nb::arg("context").none() = nb::none(), - "Gets an uniqued float point attribute associated to a f64 type"); - c.def_prop_ro("value", mlirFloatAttrGetValueDouble, - "Returns the value of the float attribute"); - c.def("__float__", mlirFloatAttrGetValueDouble, - "Converts the value of the float attribute to a Python float"); - } -}; - -/// Integer Attribute subclass - IntegerAttr. -class PyIntegerAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; - static constexpr const char *pyClassName = "IntegerAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { + nb::arg("values"), nb::arg("context").none() = nb::none(), + "Gets a uniqued dense array attribute"); + } else { c.def_static( "get", - [](PyType &type, int64_t value) { - MlirAttribute attr = mlirIntegerAttrGet(type, value); - return PyIntegerAttribute(type.getContext(), attr); + [](const std::vector &values, DefaultingPyMlirContext ctx) { + return getAttribute(values, ctx->getRef()); }, - nb::arg("type"), nb::arg("value"), - "Gets an uniqued integer attribute associated to a type"); - c.def_prop_ro("value", toPyInt, - "Returns the value of the integer attribute"); - c.def("__int__", toPyInt, - "Converts the value of the integer attribute to a Python int"); - c.def_prop_ro_static("static_typeid", - [](nb::object & /*class*/) -> MlirTypeID { - return mlirIntegerAttrGetTypeID(); - }); + nb::arg("values"), nb::arg("context").none() = nb::none(), + "Gets a uniqued dense array attribute"); } + // Bind the array methods. + c.def("__getitem__", [](DerivedT &arr, intptr_t i) { + if (i >= mlirDenseArrayGetNumElements(arr)) + throw nb::index_error("DenseArray index out of range"); + return arr.getItem(i); + }); + c.def("__len__", + [](const DerivedT &arr) { return mlirDenseArrayGetNumElements(arr); }); + c.def("__iter__", + [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); + c.def("__add__", [](DerivedT &arr, const nb::list &extras) { + std::vector values; + intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); + values.reserve(numOldElements + nb::len(extras)); + for (intptr_t i = 0; i < numOldElements; ++i) + values.push_back(arr.getItem(i)); + for (nb::handle attr : extras) + values.push_back(pyTryCast(attr)); + return getAttribute(values, arr.getContext()); + }); +} -private: - static int64_t toPyInt(PyIntegerAttribute &self) { - MlirType type = mlirAttributeGetType(self); - if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) - return mlirIntegerAttrGetValueInt(self); - if (mlirIntegerTypeIsSigned(type)) - return mlirIntegerAttrGetValueSInt(self); - return mlirIntegerAttrGetValueUInt(self); +template +DerivedT PyDenseArrayAttribute::getAttribute( + const std::vector &values, PyMlirContextRef ctx) { + if constexpr (std::is_same_v) { + std::vector intValues(values.begin(), values.end()); + MlirAttribute attr = + DerivedT::getAttribute(ctx->get(), intValues.size(), intValues.data()); + return DerivedT(ctx, attr); + } else { + MlirAttribute attr = + DerivedT::getAttribute(ctx->get(), values.size(), values.data()); + return DerivedT(ctx, attr); } -}; +} -/// Bool Attribute subclass - BoolAttr. -class PyBoolAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; - static constexpr const char *pyClassName = "BoolAttr"; - using PyConcreteAttribute::PyConcreteAttribute; +PyArrayAttribute::PyArrayAttributeIterator & +PyArrayAttribute::PyArrayAttributeIterator::dunderIter() { + return *this; +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](bool value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirBoolAttrGet(context->get(), value); - return PyBoolAttribute(context->getRef(), attr); - }, - nb::arg("value"), nb::arg("context").none() = nb::none(), - "Gets an uniqued bool attribute"); - c.def_prop_ro("value", mlirBoolAttrGetValue, - "Returns the value of the bool attribute"); - c.def("__bool__", mlirBoolAttrGetValue, - "Converts the value of the bool attribute to a Python bool"); - } -}; +MlirAttribute PyArrayAttribute::PyArrayAttributeIterator::dunderNext() { + // TODO: Throw is an inefficient way to stop iteration. + if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) + throw nb::stop_iteration(); + return mlirArrayAttrGetElement(attr.get(), nextIndex++); +} -class PySymbolRefAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; - static constexpr const char *pyClassName = "SymbolRefAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static MlirAttribute fromList(const std::vector &symbols, - PyMlirContext &context) { - if (symbols.empty()) - throw std::runtime_error("SymbolRefAttr must be composed of at least " - "one symbol."); - MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); - SmallVector referenceAttrs; - for (size_t i = 1; i < symbols.size(); ++i) { - referenceAttrs.push_back( - mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); - } - return mlirSymbolRefAttrGet(context.get(), rootSymbol, - referenceAttrs.size(), referenceAttrs.data()); - } +void PyArrayAttribute::PyArrayAttributeIterator::bind(nb::module_ &m) { + nb::class_(m, "ArrayAttributeIterator") + .def("__iter__", &PyArrayAttributeIterator::dunderIter) + .def("__next__", &PyArrayAttributeIterator::dunderNext); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](const std::vector &symbols, - DefaultingPyMlirContext context) { - return PySymbolRefAttribute::fromList(symbols, context.resolve()); - }, - nb::arg("symbols"), nb::arg("context").none() = nb::none(), - "Gets a uniqued SymbolRef attribute from a list of symbol names"); - c.def_prop_ro( - "value", - [](PySymbolRefAttribute &self) { - std::vector symbols = { - unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; - for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); - ++i) - symbols.push_back( - unwrap(mlirSymbolRefAttrGetRootReference( - mlirSymbolRefAttrGetNestedReference(self, i))) - .str()); - return symbols; - }, - "Returns the value of the SymbolRef attribute as a list[str]"); - } -}; +MlirAttribute PyArrayAttribute::getItem(intptr_t i) { + return mlirArrayAttrGetElement(*this, i); +} -class PyFlatSymbolRefAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; - static constexpr const char *pyClassName = "FlatSymbolRefAttr"; - using PyConcreteAttribute::PyConcreteAttribute; +void PyArrayAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const nb::list &attributes, DefaultingPyMlirContext context) { + SmallVector mlirAttributes; + mlirAttributes.reserve(nb::len(attributes)); + for (auto attribute : attributes) { + mlirAttributes.push_back(pyTryCast(attribute)); + } + MlirAttribute attr = mlirArrayAttrGet( + context->get(), mlirAttributes.size(), mlirAttributes.data()); + return PyArrayAttribute(context->getRef(), attr); + }, + nb::arg("attributes"), nb::arg("context").none() = nb::none(), + "Gets a uniqued Array attribute"); + c.def("__getitem__", + [](PyArrayAttribute &arr, intptr_t i) { + if (i >= mlirArrayAttrGetNumElements(arr)) + throw nb::index_error("ArrayAttribute index out of range"); + return arr.getItem(i); + }) + .def("__len__", + [](const PyArrayAttribute &arr) { + return mlirArrayAttrGetNumElements(arr); + }) + .def("__iter__", [](const PyArrayAttribute &arr) { + return PyArrayAttributeIterator(arr); + }); + c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) { + std::vector attributes; + intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); + attributes.reserve(numOldElements + nb::len(extras)); + for (intptr_t i = 0; i < numOldElements; ++i) + attributes.push_back(arr.getItem(i)); + for (nb::handle attr : extras) + attributes.push_back(pyTryCast(attr)); + MlirAttribute arrayAttr = mlirArrayAttrGet( + arr.getContext()->get(), attributes.size(), attributes.data()); + return PyArrayAttribute(arr.getContext(), arrayAttr); + }); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](const std::string &value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); - return PyFlatSymbolRefAttribute(context->getRef(), attr); - }, - nb::arg("value"), nb::arg("context").none() = nb::none(), - "Gets a uniqued FlatSymbolRef attribute"); - c.def_prop_ro( - "value", - [](PyFlatSymbolRefAttribute &self) { - MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); - return nb::str(stringRef.data, stringRef.length); - }, - "Returns the value of the FlatSymbolRef attribute as a string"); - } -}; +void PyFloatAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &type, double value, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); + if (mlirAttributeIsNull(attr)) + throw MLIRError("Invalid attribute", errors.take()); + return PyFloatAttribute(type.getContext(), attr); + }, + nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(), + "Gets an uniqued float point attribute associated to a type"); + c.def_static( + "get_f32", + [](double value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirFloatAttrDoubleGet( + context->get(), mlirF32TypeGet(context->get()), value); + return PyFloatAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets an uniqued float point attribute associated to a f32 type"); + c.def_static( + "get_f64", + [](double value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirFloatAttrDoubleGet( + context->get(), mlirF64TypeGet(context->get()), value); + return PyFloatAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets an uniqued float point attribute associated to a f64 type"); + c.def_prop_ro("value", mlirFloatAttrGetValueDouble, + "Returns the value of the float attribute"); + c.def("__float__", mlirFloatAttrGetValueDouble, + "Converts the value of the float attribute to a Python float"); +} -class PyOpaqueAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; - static constexpr const char *pyClassName = "OpaqueAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirOpaqueAttrGetTypeID; +void PyIntegerAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &type, int64_t value) { + MlirAttribute attr = mlirIntegerAttrGet(type, value); + return PyIntegerAttribute(type.getContext(), attr); + }, + nb::arg("type"), nb::arg("value"), + "Gets an uniqued integer attribute associated to a type"); + c.def_prop_ro("value", toPyInt, "Returns the value of the integer attribute"); + c.def("__int__", toPyInt, + "Converts the value of the integer attribute to a Python int"); + c.def_prop_ro_static("static_typeid", + [](nb::object & /*class*/) -> MlirTypeID { + return mlirIntegerAttrGetTypeID(); + }); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](const std::string &dialectNamespace, const nb_buffer &buffer, - PyType &type, DefaultingPyMlirContext context) { - const nb_buffer_info bufferInfo = buffer.request(); - intptr_t bufferSize = bufferInfo.size; - MlirAttribute attr = mlirOpaqueAttrGet( - context->get(), toMlirStringRef(dialectNamespace), bufferSize, - static_cast(bufferInfo.ptr), type); - return PyOpaqueAttribute(context->getRef(), attr); - }, - nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"), - nb::arg("context").none() = nb::none(), "Gets an Opaque attribute."); - c.def_prop_ro( - "dialect_namespace", - [](PyOpaqueAttribute &self) { - MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); - return nb::str(stringRef.data, stringRef.length); - }, - "Returns the dialect namespace for the Opaque attribute as a string"); - c.def_prop_ro( - "data", - [](PyOpaqueAttribute &self) { - MlirStringRef stringRef = mlirOpaqueAttrGetData(self); - return nb::bytes(stringRef.data, stringRef.length); - }, - "Returns the data for the Opaqued attributes as `bytes`"); - } -}; +int64_t PyIntegerAttribute::toPyInt(PyIntegerAttribute &self) { + MlirType type = mlirAttributeGetType(self); + if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) + return mlirIntegerAttrGetValueInt(self); + if (mlirIntegerTypeIsSigned(type)) + return mlirIntegerAttrGetValueSInt(self); + return mlirIntegerAttrGetValueUInt(self); +} -class PyStringAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; - static constexpr const char *pyClassName = "StringAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirStringAttrGetTypeID; +void PyBoolAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](bool value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirBoolAttrGet(context->get(), value); + return PyBoolAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets an uniqued bool attribute"); + c.def_prop_ro("value", mlirBoolAttrGetValue, + "Returns the value of the bool attribute"); + c.def("__bool__", mlirBoolAttrGetValue, + "Converts the value of the bool attribute to a Python bool"); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](const std::string &value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirStringAttrGet(context->get(), toMlirStringRef(value)); - return PyStringAttribute(context->getRef(), attr); - }, - nb::arg("value"), nb::arg("context").none() = nb::none(), - "Gets a uniqued string attribute"); - c.def_static( - "get", - [](const nb::bytes &value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirStringAttrGet(context->get(), toMlirStringRef(value)); - return PyStringAttribute(context->getRef(), attr); - }, - nb::arg("value"), nb::arg("context").none() = nb::none(), - "Gets a uniqued string attribute"); - c.def_static( - "get_typed", - [](PyType &type, const std::string &value) { - MlirAttribute attr = - mlirStringAttrTypedGet(type, toMlirStringRef(value)); - return PyStringAttribute(type.getContext(), attr); - }, - nb::arg("type"), nb::arg("value"), - "Gets a uniqued string attribute associated to a type"); - c.def_prop_ro( - "value", - [](PyStringAttribute &self) { - MlirStringRef stringRef = mlirStringAttrGetValue(self); - return nb::str(stringRef.data, stringRef.length); - }, - "Returns the value of the string attribute"); - c.def_prop_ro( - "value_bytes", - [](PyStringAttribute &self) { - MlirStringRef stringRef = mlirStringAttrGetValue(self); - return nb::bytes(stringRef.data, stringRef.length); - }, - "Returns the value of the string attribute as `bytes`"); +MlirAttribute +PySymbolRefAttribute::fromList(const std::vector &symbols, + PyMlirContext &context) { + if (symbols.empty()) + throw std::runtime_error("SymbolRefAttr must be composed of at least " + "one symbol."); + MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); + SmallVector referenceAttrs; + for (size_t i = 1; i < symbols.size(); ++i) { + referenceAttrs.push_back( + mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); } -}; - -// TODO: Support construction of string elements. -class PyDenseElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; - static constexpr const char *pyClassName = "DenseElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static PyDenseElementsAttribute - getFromList(const nb::list &attributes, std::optional explicitType, - DefaultingPyMlirContext contextWrapper) { - const size_t numAttributes = nb::len(attributes); - if (numAttributes == 0) - throw nb::value_error("Attributes list must be non-empty."); - - MlirType shapedType; - if (explicitType) { - if ((!mlirTypeIsAShaped(*explicitType) || - !mlirShapedTypeHasStaticShape(*explicitType))) { - - std::string message; - llvm::raw_string_ostream os(message); - os << "Expected a static ShapedType for the shaped_type parameter: " - << nb::cast(nb::repr(nb::cast(*explicitType))); - throw nb::value_error(message.c_str()); - } - shapedType = *explicitType; - } else { - SmallVector shape = {static_cast(numAttributes)}; - shapedType = mlirRankedTensorTypeGet( - shape.size(), shape.data(), - mlirAttributeGetType(pyTryCast(attributes[0])), - mlirAttributeGetNull()); - } + return mlirSymbolRefAttrGet(context.get(), rootSymbol, referenceAttrs.size(), + referenceAttrs.data()); +} - SmallVector mlirAttributes; - mlirAttributes.reserve(numAttributes); - for (const nb::handle &attribute : attributes) { - MlirAttribute mlirAttribute = pyTryCast(attribute); - MlirType attrType = mlirAttributeGetType(mlirAttribute); - mlirAttributes.push_back(mlirAttribute); - - if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { - std::string message; - llvm::raw_string_ostream os(message); - os << "All attributes must be of the same type and match " - << "the type parameter: expected=" - << nb::cast(nb::repr(nb::cast(shapedType))) - << ", but got=" - << nb::cast(nb::repr(nb::cast(attrType))); - throw nb::value_error(message.c_str()); - } - } +void PySymbolRefAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::vector &symbols, + DefaultingPyMlirContext context) { + return PySymbolRefAttribute::fromList(symbols, context.resolve()); + }, + nb::arg("symbols"), nb::arg("context").none() = nb::none(), + "Gets a uniqued SymbolRef attribute from a list of symbol names"); + c.def_prop_ro( + "value", + [](PySymbolRefAttribute &self) { + std::vector symbols = { + unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; + for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); ++i) + symbols.push_back( + unwrap(mlirSymbolRefAttrGetRootReference( + mlirSymbolRefAttrGetNestedReference(self, i))) + .str()); + return symbols; + }, + "Returns the value of the SymbolRef attribute as a list[str]"); +} - MlirAttribute elements = mlirDenseElementsAttrGet( - shapedType, mlirAttributes.size(), mlirAttributes.data()); +void PyFlatSymbolRefAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::string &value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); + return PyFlatSymbolRefAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets a uniqued FlatSymbolRef attribute"); + c.def_prop_ro( + "value", + [](PyFlatSymbolRefAttribute &self) { + MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the value of the FlatSymbolRef attribute as a string"); +} - return PyDenseElementsAttribute(contextWrapper->getRef(), elements); +nb_buffer_info nb_buffer::request() const { + int flags = PyBUF_STRIDES | PyBUF_FORMAT; + auto *view = new Py_buffer(); + if (PyObject_GetBuffer(ptr(), view, flags) != 0) { + delete view; + throw nb::python_error(); } + return nb_buffer_info(view); +} - static PyDenseElementsAttribute - getFromBuffer(const nb_buffer &array, bool signless, - const std::optional &explicitType, - std::optional> explicitShape, - DefaultingPyMlirContext contextWrapper) { - // Request a contiguous view. In exotic cases, this will cause a copy. - int flags = PyBUF_ND; - if (!explicitType) { - flags |= PyBUF_FORMAT; - } - Py_buffer view; - if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { - throw nb::python_error(); - } - auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); +void PyOpaqueAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::string &dialectNamespace, const nb_buffer &buffer, + PyType &type, DefaultingPyMlirContext context) { + const nb_buffer_info bufferInfo = buffer.request(); + intptr_t bufferSize = bufferInfo.size; + MlirAttribute attr = mlirOpaqueAttrGet( + context->get(), toMlirStringRef(dialectNamespace), bufferSize, + static_cast(bufferInfo.ptr), type); + return PyOpaqueAttribute(context->getRef(), attr); + }, + nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"), + nb::arg("context").none() = nb::none(), "Gets an Opaque attribute."); + c.def_prop_ro( + "dialect_namespace", + [](PyOpaqueAttribute &self) { + MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the dialect namespace for the Opaque attribute as a string"); + c.def_prop_ro( + "data", + [](PyOpaqueAttribute &self) { + MlirStringRef stringRef = mlirOpaqueAttrGetData(self); + return nb::bytes(stringRef.data, stringRef.length); + }, + "Returns the data for the Opaqued attributes as `bytes`"); +} - MlirContext context = contextWrapper->get(); - MlirAttribute attr = getAttributeFromBuffer( - view, signless, explicitType, std::move(explicitShape), context); - if (mlirAttributeIsNull(attr)) { - throw std::invalid_argument( - "DenseElementsAttr could not be constructed from the given buffer. " - "This may mean that the Python buffer layout does not match that " - "MLIR expected layout and is a bug."); - } - return PyDenseElementsAttribute(contextWrapper->getRef(), attr); - } +void PyStringAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::string &value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirStringAttrGet(context->get(), toMlirStringRef(value)); + return PyStringAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets a uniqued string attribute"); + c.def_static( + "get", + [](const nb::bytes &value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirStringAttrGet(context->get(), toMlirStringRef(value)); + return PyStringAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets a uniqued string attribute"); + c.def_static( + "get_typed", + [](PyType &type, const std::string &value) { + MlirAttribute attr = + mlirStringAttrTypedGet(type, toMlirStringRef(value)); + return PyStringAttribute(type.getContext(), attr); + }, + nb::arg("type"), nb::arg("value"), + "Gets a uniqued string attribute associated to a type"); + c.def_prop_ro( + "value", + [](PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the value of the string attribute"); + c.def_prop_ro( + "value_bytes", + [](PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return nb::bytes(stringRef.data, stringRef.length); + }, + "Returns the value of the string attribute as `bytes`"); +} - static PyDenseElementsAttribute getSplat(const PyType &shapedType, - PyAttribute &elementAttr) { - auto contextWrapper = - PyMlirContext::forContext(mlirTypeGetContext(shapedType)); - if (!mlirAttributeIsAInteger(elementAttr) && - !mlirAttributeIsAFloat(elementAttr)) { - std::string message = "Illegal element type for DenseElementsAttr: "; - message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); - throw nb::value_error(message.c_str()); - } - if (!mlirTypeIsAShaped(shapedType) || - !mlirShapedTypeHasStaticShape(shapedType)) { - std::string message = - "Expected a static ShapedType for the shaped_type parameter: "; - message.append(nb::cast(nb::repr(nb::cast(shapedType)))); +PyDenseElementsAttribute +PyDenseElementsAttribute::getFromList(const nb::list &attributes, + std::optional explicitType, + DefaultingPyMlirContext contextWrapper) { + const size_t numAttributes = nb::len(attributes); + if (numAttributes == 0) + throw nb::value_error("Attributes list must be non-empty."); + + MlirType shapedType; + if (explicitType) { + if ((!mlirTypeIsAShaped(*explicitType) || + !mlirShapedTypeHasStaticShape(*explicitType))) { + + std::string message; + llvm::raw_string_ostream os(message); + os << "Expected a static ShapedType for the shaped_type parameter: " + << nb::cast(nb::repr(nb::cast(*explicitType))); throw nb::value_error(message.c_str()); } - MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); - MlirType attrType = mlirAttributeGetType(elementAttr); - if (!mlirTypeEqual(shapedElementType, attrType)) { - std::string message = - "Shaped element type and attribute type must be equal: shaped="; - message.append(nb::cast(nb::repr(nb::cast(shapedType)))); - message.append(", element="); - message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); + shapedType = *explicitType; + } else { + SmallVector shape = {static_cast(numAttributes)}; + shapedType = mlirRankedTensorTypeGet( + shape.size(), shape.data(), + mlirAttributeGetType(pyTryCast(attributes[0])), + mlirAttributeGetNull()); + } + + SmallVector mlirAttributes; + mlirAttributes.reserve(numAttributes); + for (const nb::handle &attribute : attributes) { + MlirAttribute mlirAttribute = pyTryCast(attribute); + MlirType attrType = mlirAttributeGetType(mlirAttribute); + mlirAttributes.push_back(mlirAttribute); + + if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { + std::string message; + llvm::raw_string_ostream os(message); + os << "All attributes must be of the same type and match " + << "the type parameter: expected=" + << nb::cast(nb::repr(nb::cast(shapedType))) + << ", but got=" << nb::cast(nb::repr(nb::cast(attrType))); throw nb::value_error(message.c_str()); } + } + + MlirAttribute elements = mlirDenseElementsAttrGet( + shapedType, mlirAttributes.size(), mlirAttributes.data()); + + return PyDenseElementsAttribute(contextWrapper->getRef(), elements); +} - MlirAttribute elements = - mlirDenseElementsAttrSplatGet(shapedType, elementAttr); - return PyDenseElementsAttribute(contextWrapper->getRef(), elements); +PyDenseElementsAttribute PyDenseElementsAttribute::getFromBuffer( + const nb_buffer &array, bool signless, + const std::optional &explicitType, + std::optional> explicitShape, + DefaultingPyMlirContext contextWrapper) { + // Request a contiguous view. In exotic cases, this will cause a copy. + int flags = PyBUF_ND; + if (!explicitType) { + flags |= PyBUF_FORMAT; + } + Py_buffer view; + if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { + throw nb::python_error(); } + auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); - intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } + MlirContext context = contextWrapper->get(); + MlirAttribute attr = getAttributeFromBuffer( + view, signless, explicitType, std::move(explicitShape), context); + if (mlirAttributeIsNull(attr)) { + throw std::invalid_argument( + "DenseElementsAttr could not be constructed from the given buffer. " + "This may mean that the Python buffer layout does not match that " + "MLIR expected layout and is a bug."); + } + return PyDenseElementsAttribute(contextWrapper->getRef(), attr); +} - std::unique_ptr accessBuffer() { - MlirType shapedType = mlirAttributeGetType(*this); - MlirType elementType = mlirShapedTypeGetElementType(shapedType); - std::string format; +PyDenseElementsAttribute +PyDenseElementsAttribute::getSplat(const PyType &shapedType, + PyAttribute &elementAttr) { + auto contextWrapper = + PyMlirContext::forContext(mlirTypeGetContext(shapedType)); + if (!mlirAttributeIsAInteger(elementAttr) && + !mlirAttributeIsAFloat(elementAttr)) { + std::string message = "Illegal element type for DenseElementsAttr: "; + message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); + throw nb::value_error(message.c_str()); + } + if (!mlirTypeIsAShaped(shapedType) || + !mlirShapedTypeHasStaticShape(shapedType)) { + std::string message = + "Expected a static ShapedType for the shaped_type parameter: "; + message.append(nb::cast(nb::repr(nb::cast(shapedType)))); + throw nb::value_error(message.c_str()); + } + MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); + MlirType attrType = mlirAttributeGetType(elementAttr); + if (!mlirTypeEqual(shapedElementType, attrType)) { + std::string message = + "Shaped element type and attribute type must be equal: shaped="; + message.append(nb::cast(nb::repr(nb::cast(shapedType)))); + message.append(", element="); + message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); + throw nb::value_error(message.c_str()); + } - if (mlirTypeIsAF32(elementType)) { - // f32 - return bufferInfo(shapedType); - } - if (mlirTypeIsAF64(elementType)) { - // f64 - return bufferInfo(shapedType); + MlirAttribute elements = + mlirDenseElementsAttrSplatGet(shapedType, elementAttr); + return PyDenseElementsAttribute(contextWrapper->getRef(), elements); +} + +intptr_t PyDenseElementsAttribute::dunderLen() { + return mlirElementsAttrGetNumElements(*this); +} + +std::unique_ptr PyDenseElementsAttribute::accessBuffer() { + MlirType shapedType = mlirAttributeGetType(*this); + MlirType elementType = mlirShapedTypeGetElementType(shapedType); + std::string format; + + if (mlirTypeIsAF32(elementType)) { + // f32 + return bufferInfo(shapedType); + } + if (mlirTypeIsAF64(elementType)) { + // f64 + return bufferInfo(shapedType); + } + if (mlirTypeIsAF16(elementType)) { + // f16 + return bufferInfo(shapedType, "e"); + } + if (mlirTypeIsAIndex(elementType)) { + // Same as IndexType::kInternalStorageBitWidth + return bufferInfo(shapedType); + } + if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 32) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i32 + return bufferInfo(shapedType); } - if (mlirTypeIsAF16(elementType)) { - // f16 - return bufferInfo(shapedType, "e"); + if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i32 + return bufferInfo(shapedType); } - if (mlirTypeIsAIndex(elementType)) { - // Same as IndexType::kInternalStorageBitWidth + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 64) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i64 return bufferInfo(shapedType); } - if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 32) { - if (mlirIntegerTypeIsSignless(elementType) || - mlirIntegerTypeIsSigned(elementType)) { - // i32 - return bufferInfo(shapedType); - } - if (mlirIntegerTypeIsUnsigned(elementType)) { - // unsigned i32 - return bufferInfo(shapedType); - } - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 64) { - if (mlirIntegerTypeIsSignless(elementType) || - mlirIntegerTypeIsSigned(elementType)) { - // i64 - return bufferInfo(shapedType); - } - if (mlirIntegerTypeIsUnsigned(elementType)) { - // unsigned i64 - return bufferInfo(shapedType); - } - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 8) { - if (mlirIntegerTypeIsSignless(elementType) || - mlirIntegerTypeIsSigned(elementType)) { - // i8 - return bufferInfo(shapedType); - } - if (mlirIntegerTypeIsUnsigned(elementType)) { - // unsigned i8 - return bufferInfo(shapedType); - } - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 16) { - if (mlirIntegerTypeIsSignless(elementType) || - mlirIntegerTypeIsSigned(elementType)) { - // i16 - return bufferInfo(shapedType); - } - if (mlirIntegerTypeIsUnsigned(elementType)) { - // unsigned i16 - return bufferInfo(shapedType); - } - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 1) { - // i1 / bool - // We can not send the buffer directly back to Python, because the i1 - // values are bitpacked within MLIR. We call numpy's unpackbits function - // to convert the bytes. - return getBooleanBufferFromBitpackedAttribute(); + if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i64 + return bufferInfo(shapedType); } - - // TODO: Currently crashes the program. - // Reported as https://github.com/pybind/pybind11/issues/3336 - throw std::invalid_argument( - "unsupported data type for conversion to Python buffer"); + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 8) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i8 + return bufferInfo(shapedType); + } + if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i8 + return bufferInfo(shapedType); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 16) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i16 + return bufferInfo(shapedType); + } + if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i16 + return bufferInfo(shapedType); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 1) { + // i1 / bool + // We can not send the buffer directly back to Python, because the i1 + // values are bitpacked within MLIR. We call numpy's unpackbits function + // to convert the bytes. + return getBooleanBufferFromBitpackedAttribute(); } - static void bindDerived(ClassTy &c) { + // TODO: Currently crashes the program. + // Reported as https://github.com/pybind/pybind11/issues/3336 + throw std::invalid_argument( + "unsupported data type for conversion to Python buffer"); +} + +void PyDenseElementsAttribute::bindDerived(ClassTy &c) { #if PY_VERSION_HEX < 0x03090000 - PyTypeObject *tp = reinterpret_cast(c.ptr()); - tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer; - tp->tp_as_buffer->bf_releasebuffer = - PyDenseElementsAttribute::bf_releasebuffer; + PyTypeObject *tp = reinterpret_cast(c.ptr()); + tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer; + tp->tp_as_buffer->bf_releasebuffer = + PyDenseElementsAttribute::bf_releasebuffer; #endif - c.def("__len__", &PyDenseElementsAttribute::dunderLen) - .def_static("get", PyDenseElementsAttribute::getFromBuffer, - nb::arg("array"), nb::arg("signless") = true, - nb::arg("type").none() = nb::none(), - nb::arg("shape").none() = nb::none(), - nb::arg("context").none() = nb::none(), - kDenseElementsAttrGetDocstring) - .def_static("get", PyDenseElementsAttribute::getFromList, - nb::arg("attrs"), nb::arg("type").none() = nb::none(), - nb::arg("context").none() = nb::none(), - kDenseElementsAttrGetFromListDocstring) - .def_static("get_splat", PyDenseElementsAttribute::getSplat, - nb::arg("shaped_type"), nb::arg("element_attr"), - "Gets a DenseElementsAttr where all values are the same") - .def_prop_ro("is_splat", - [](PyDenseElementsAttribute &self) -> bool { - return mlirDenseElementsAttrIsSplat(self); - }) - .def("get_splat_value", [](PyDenseElementsAttribute &self) { - if (!mlirDenseElementsAttrIsSplat(self)) - throw nb::value_error( - "get_splat_value called on a non-splat attribute"); - return mlirDenseElementsAttrGetSplatValue(self); - }); - } - - static PyType_Slot slots[]; + c.def("__len__", &PyDenseElementsAttribute::dunderLen) + .def_static("get", PyDenseElementsAttribute::getFromBuffer, + nb::arg("array"), nb::arg("signless") = true, + nb::arg("type").none() = nb::none(), + nb::arg("shape").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kDenseElementsAttrGetDocstring) + .def_static("get", PyDenseElementsAttribute::getFromList, + nb::arg("attrs"), nb::arg("type").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kDenseElementsAttrGetFromListDocstring) + .def_static("get_splat", PyDenseElementsAttribute::getSplat, + nb::arg("shaped_type"), nb::arg("element_attr"), + "Gets a DenseElementsAttr where all values are the same") + .def_prop_ro("is_splat", + [](PyDenseElementsAttribute &self) -> bool { + return mlirDenseElementsAttrIsSplat(self); + }) + .def("get_splat_value", [](PyDenseElementsAttribute &self) { + if (!mlirDenseElementsAttrIsSplat(self)) + throw nb::value_error( + "get_splat_value called on a non-splat attribute"); + return mlirDenseElementsAttrGetSplatValue(self); + }); +} -private: - static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags); - static void bf_releasebuffer(PyObject *, Py_buffer *buffer); +bool PyDenseElementsAttribute::isUnsignedIntegerFormat( + std::string_view format) { + if (format.empty()) + return false; + char code = format[0]; + return code == 'I' || code == 'B' || code == 'H' || code == 'L' || + code == 'Q'; +} - static bool isUnsignedIntegerFormat(std::string_view format) { - if (format.empty()) - return false; - char code = format[0]; - return code == 'I' || code == 'B' || code == 'H' || code == 'L' || - code == 'Q'; - } +bool PyDenseElementsAttribute::isSignedIntegerFormat(std::string_view format) { + if (format.empty()) + return false; + char code = format[0]; + return code == 'i' || code == 'b' || code == 'h' || code == 'l' || + code == 'q'; +} - static bool isSignedIntegerFormat(std::string_view format) { - if (format.empty()) - return false; - char code = format[0]; - return code == 'i' || code == 'b' || code == 'h' || code == 'l' || - code == 'q'; +MlirType PyDenseElementsAttribute::getShapedType( + std::optional bulkLoadElementType, + std::optional> explicitShape, Py_buffer &view) { + SmallVector shape; + if (explicitShape) { + shape.append(explicitShape->begin(), explicitShape->end()); + } else { + shape.append(view.shape, view.shape + view.ndim); } - static MlirType - getShapedType(std::optional bulkLoadElementType, - std::optional> explicitShape, - Py_buffer &view) { - SmallVector shape; + if (mlirTypeIsAShaped(*bulkLoadElementType)) { if (explicitShape) { - shape.append(explicitShape->begin(), explicitShape->end()); - } else { - shape.append(view.shape, view.shape + view.ndim); + throw std::invalid_argument("Shape can only be specified explicitly " + "when the type is not a shaped type."); } - - if (mlirTypeIsAShaped(*bulkLoadElementType)) { - if (explicitShape) { - throw std::invalid_argument("Shape can only be specified explicitly " - "when the type is not a shaped type."); - } - return *bulkLoadElementType; - } - MlirAttribute encodingAttr = mlirAttributeGetNull(); - return mlirRankedTensorTypeGet(shape.size(), shape.data(), - *bulkLoadElementType, encodingAttr); + return *bulkLoadElementType; } + MlirAttribute encodingAttr = mlirAttributeGetNull(); + return mlirRankedTensorTypeGet(shape.size(), shape.data(), + *bulkLoadElementType, encodingAttr); +} - static MlirAttribute getAttributeFromBuffer( - Py_buffer &view, bool signless, std::optional explicitType, - const std::optional> &explicitShape, - MlirContext &context) { - // Detect format codes that are suitable for bulk loading. This includes - // all byte aligned integer and floating point types up to 8 bytes. - // Notably, this excludes exotics types which do not have a direct - // representation in the buffer protocol (i.e. complex, etc). - std::optional bulkLoadElementType; - if (explicitType) { - bulkLoadElementType = *explicitType; - } else { - std::string_view format(view.format); - if (format == "f") { - // f32 - assert(view.itemsize == 4 && "mismatched array itemsize"); - bulkLoadElementType = mlirF32TypeGet(context); - } else if (format == "d") { - // f64 - assert(view.itemsize == 8 && "mismatched array itemsize"); - bulkLoadElementType = mlirF64TypeGet(context); - } else if (format == "e") { - // f16 - assert(view.itemsize == 2 && "mismatched array itemsize"); - bulkLoadElementType = mlirF16TypeGet(context); - } else if (format == "?") { - // i1 - // The i1 type needs to be bit-packed, so we will handle it separately - return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, - context); - } else if (isSignedIntegerFormat(format)) { - if (view.itemsize == 4) { - // i32 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeSignedGet(context, 32); - } else if (view.itemsize == 8) { - // i64 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeSignedGet(context, 64); - } else if (view.itemsize == 1) { - // i8 - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeSignedGet(context, 8); - } else if (view.itemsize == 2) { - // i16 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeSignedGet(context, 16); - } - } else if (isUnsignedIntegerFormat(format)) { - if (view.itemsize == 4) { - // unsigned i32 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeUnsignedGet(context, 32); - } else if (view.itemsize == 8) { - // unsigned i64 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeUnsignedGet(context, 64); - } else if (view.itemsize == 1) { - // i8 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeUnsignedGet(context, 8); - } else if (view.itemsize == 2) { - // i16 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeUnsignedGet(context, 16); - } +MlirAttribute PyDenseElementsAttribute::getAttributeFromBuffer( + Py_buffer &view, bool signless, std::optional explicitType, + const std::optional> &explicitShape, + MlirContext &context) { + // Detect format codes that are suitable for bulk loading. This includes + // all byte aligned integer and floating point types up to 8 bytes. + // Notably, this excludes exotics types which do not have a direct + // representation in the buffer protocol (i.e. complex, etc). + std::optional bulkLoadElementType; + if (explicitType) { + bulkLoadElementType = *explicitType; + } else { + std::string_view format(view.format); + if (format == "f") { + // f32 + assert(view.itemsize == 4 && "mismatched array itemsize"); + bulkLoadElementType = mlirF32TypeGet(context); + } else if (format == "d") { + // f64 + assert(view.itemsize == 8 && "mismatched array itemsize"); + bulkLoadElementType = mlirF64TypeGet(context); + } else if (format == "e") { + // f16 + assert(view.itemsize == 2 && "mismatched array itemsize"); + bulkLoadElementType = mlirF16TypeGet(context); + } else if (format == "?") { + // i1 + // The i1 type needs to be bit-packed, so we will handle it separately + return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, + context); + } else if (isSignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // i32 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); + } else if (view.itemsize == 8) { + // i64 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeSignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeSignedGet(context, 16); } - if (!bulkLoadElementType) { - throw std::invalid_argument( - std::string("unimplemented array format conversion from format: ") + - std::string(format)); + } else if (isUnsignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // unsigned i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); + } else if (view.itemsize == 8) { + // unsigned i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeUnsignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeUnsignedGet(context, 16); } } - - MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); - return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); - } - - // There is a complication for boolean numpy arrays, as numpy represents - // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 - // booleans per byte. - static MlirAttribute getBitpackedAttributeFromBooleanBuffer( - Py_buffer &view, std::optional> explicitShape, - MlirContext &context) { - if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian - // systems we will throw - throw nb::type_error("Constructing a bit-packed MLIR attribute is " - "unsupported on big-endian systems"); + if (!bulkLoadElementType) { + throw std::invalid_argument( + std::string("unimplemented array format conversion from format: ") + + std::string(format)); } - nb::ndarray, nb::c_contig> unpackedArray( - /*data=*/static_cast(view.buf), - /*shape=*/{static_cast(view.len)}); - - nb::module_ numpy = nb::module_::import_("numpy"); - nb::object packbitsFunc = numpy.attr("packbits"); - nb::object packedBooleans = - packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little"); - nb_buffer_info pythonBuffer = nb::cast(packedBooleans).request(); - - MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1), - std::move(explicitShape), view); - assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8"); - // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of - // packedBooleans, hence the MlirAttribute will remain valid even when - // packedBooleans get reclaimed by the end of the function. - return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, - pythonBuffer.ptr); } - // This does the opposite transformation of - // `getBitpackedAttributeFromBooleanBuffer` - std::unique_ptr getBooleanBufferFromBitpackedAttribute() { - if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian - // systems we will throw - throw nb::type_error("Constructing a numpy array from a MLIR attribute " - "is unsupported on big-endian systems"); - } + MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); + return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); +} - int64_t numBooleans = mlirElementsAttrGetNumElements(*this); - int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); - uint8_t *bitpackedData = static_cast( - const_cast(mlirDenseElementsAttrGetRawData(*this))); - nb::ndarray, nb::c_contig> packedArray( - /*data=*/bitpackedData, - /*shape=*/{static_cast(numBitpackedBytes)}); - - nb::module_ numpy = nb::module_::import_("numpy"); - nb::object unpackbitsFunc = numpy.attr("unpackbits"); - nb::object equalFunc = numpy.attr("equal"); - nb::object reshapeFunc = numpy.attr("reshape"); - nb::object unpackedBooleans = - unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little"); - - // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. - // We need to: - // 1. Slice away the padded bits - // 2. Make the boolean array have the correct shape - // 3. Convert the array to a boolean array - unpackedBooleans = unpackedBooleans[nb::slice( - nb::int_(0), nb::int_(numBooleans), nb::int_(1))]; - unpackedBooleans = equalFunc(unpackedBooleans, 1); - - MlirType shapedType = mlirAttributeGetType(*this); - intptr_t rank = mlirShapedTypeGetRank(shapedType); - std::vector shape(rank); - for (intptr_t i = 0; i < rank; ++i) { - shape[i] = mlirShapedTypeGetDimSize(shapedType, i); - } - unpackedBooleans = reshapeFunc(unpackedBooleans, shape); +// There is a complication for boolean numpy arrays, as numpy represents +// them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 +// booleans per byte. +MlirAttribute PyDenseElementsAttribute::getBitpackedAttributeFromBooleanBuffer( + Py_buffer &view, std::optional> explicitShape, + MlirContext &context) { + if (llvm::endianness::native != llvm::endianness::little) { + // Given we have no good way of testing the behavior on big-endian + // systems we will throw + throw nb::type_error("Constructing a bit-packed MLIR attribute is " + "unsupported on big-endian systems"); + } + nb::ndarray, nb::c_contig> unpackedArray( + /*data=*/static_cast(view.buf), + /*shape=*/{static_cast(view.len)}); + + nb::module_ numpy = nb::module_::import_("numpy"); + nb::object packbitsFunc = numpy.attr("packbits"); + nb::object packedBooleans = + packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little"); + nb_buffer_info pythonBuffer = nb::cast(packedBooleans).request(); + + MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1), + std::move(explicitShape), view); + assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8"); + // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of + // packedBooleans, hence the MlirAttribute will remain valid even when + // packedBooleans get reclaimed by the end of the function. + return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, + pythonBuffer.ptr); +} - // Make sure the returned nb::buffer_view claims ownership of the data in - // `pythonBuffer` so it remains valid when Python reads it - nb_buffer pythonBuffer = nb::cast(unpackedBooleans); - return std::make_unique(pythonBuffer.request()); +// This does the opposite transformation of +// `getBitpackedAttributeFromBooleanBuffer` +std::unique_ptr +PyDenseElementsAttribute::getBooleanBufferFromBitpackedAttribute() { + if (llvm::endianness::native != llvm::endianness::little) { + // Given we have no good way of testing the behavior on big-endian + // systems we will throw + throw nb::type_error("Constructing a numpy array from a MLIR attribute " + "is unsupported on big-endian systems"); } - template - std::unique_ptr - bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { - intptr_t rank = mlirShapedTypeGetRank(shapedType); - // Prepare the data for the buffer_info. - // Buffer is configured for read-only access below. - Type *data = static_cast( - const_cast(mlirDenseElementsAttrGetRawData(*this))); - // Prepare the shape for the buffer_info. - SmallVector shape; - for (intptr_t i = 0; i < rank; ++i) - shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); - // Prepare the strides for the buffer_info. - SmallVector strides; - if (mlirDenseElementsAttrIsSplat(*this)) { - // Splats are special, only the single value is stored. - strides.assign(rank, 0); - } else { - for (intptr_t i = 1; i < rank; ++i) { - intptr_t strideFactor = 1; - for (intptr_t j = i; j < rank; ++j) - strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); - strides.push_back(sizeof(Type) * strideFactor); - } - strides.push_back(sizeof(Type)); - } - const char *format; - if (explicitFormat) { - format = explicitFormat; - } else { - format = nb_format_descriptor::format(); + int64_t numBooleans = mlirElementsAttrGetNumElements(*this); + int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); + uint8_t *bitpackedData = static_cast( + const_cast(mlirDenseElementsAttrGetRawData(*this))); + nb::ndarray, nb::c_contig> packedArray( + /*data=*/bitpackedData, + /*shape=*/{static_cast(numBitpackedBytes)}); + + nb::module_ numpy = nb::module_::import_("numpy"); + nb::object unpackbitsFunc = numpy.attr("unpackbits"); + nb::object equalFunc = numpy.attr("equal"); + nb::object reshapeFunc = numpy.attr("reshape"); + nb::object unpackedBooleans = + unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little"); + + // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. + // We need to: + // 1. Slice away the padded bits + // 2. Make the boolean array have the correct shape + // 3. Convert the array to a boolean array + unpackedBooleans = unpackedBooleans[nb::slice( + nb::int_(0), nb::int_(numBooleans), nb::int_(1))]; + unpackedBooleans = equalFunc(unpackedBooleans, 1); + + MlirType shapedType = mlirAttributeGetType(*this); + intptr_t rank = mlirShapedTypeGetRank(shapedType); + std::vector shape(rank); + for (intptr_t i = 0; i < rank; ++i) { + shape[i] = mlirShapedTypeGetDimSize(shapedType, i); + } + unpackedBooleans = reshapeFunc(unpackedBooleans, shape); + + // Make sure the returned nb::buffer_view claims ownership of the data in + // `pythonBuffer` so it remains valid when Python reads it + nb_buffer pythonBuffer = nb::cast(unpackedBooleans); + return std::make_unique(pythonBuffer.request()); +} + +template +std::unique_ptr +PyDenseElementsAttribute::bufferInfo(MlirType shapedType, + const char *explicitFormat) { + intptr_t rank = mlirShapedTypeGetRank(shapedType); + // Prepare the data for the buffer_info. + // Buffer is configured for read-only access below. + Type *data = static_cast( + const_cast(mlirDenseElementsAttrGetRawData(*this))); + // Prepare the shape for the buffer_info. + SmallVector shape; + for (intptr_t i = 0; i < rank; ++i) + shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); + // Prepare the strides for the buffer_info. + SmallVector strides; + if (mlirDenseElementsAttrIsSplat(*this)) { + // Splats are special, only the single value is stored. + strides.assign(rank, 0); + } else { + for (intptr_t i = 1; i < rank; ++i) { + intptr_t strideFactor = 1; + for (intptr_t j = i; j < rank; ++j) + strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); + strides.push_back(sizeof(Type) * strideFactor); } - return std::make_unique( - data, sizeof(Type), format, rank, std::move(shape), std::move(strides), - /*readonly=*/true); + strides.push_back(sizeof(Type)); } -}; // namespace + const char *format; + if (explicitFormat) { + format = explicitFormat; + } else { + format = nb_format_descriptor::format(); + } + return std::make_unique(data, sizeof(Type), format, rank, + std::move(shape), std::move(strides), + /*readonly=*/true); +} PyType_Slot PyDenseElementsAttribute::slots[] = { // Python 3.8 doesn't allow setting the buffer protocol slots from a type spec. @@ -1312,9 +1080,8 @@ PyType_Slot PyDenseElementsAttribute::slots[] = { {0, nullptr}, }; -/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, - Py_buffer *view, - int flags) { +int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, Py_buffer *view, + int flags) { view->obj = nullptr; std::unique_ptr info; try { @@ -1348,85 +1115,71 @@ PyType_Slot PyDenseElementsAttribute::slots[] = { return 0; } -/*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, - Py_buffer *view) { +void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, Py_buffer *view) { delete reinterpret_cast(view->internal); } -/// Refinement of the PyDenseElementsAttribute for attributes containing -/// integer (and boolean) values. Supports element access. -class PyDenseIntElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; - static constexpr const char *pyClassName = "DenseIntElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - /// Returns the element at the given linear position. Asserts if the index - /// is out of range. - nb::object dunderGetItem(intptr_t pos) { - if (pos < 0 || pos >= dunderLen()) { - throw nb::index_error("attempt to access out of bounds element"); - } +nb::object PyDenseIntElementsAttribute::dunderGetItem(intptr_t pos) { + if (pos < 0 || pos >= dunderLen()) { + throw nb::index_error("attempt to access out of bounds element"); + } - MlirType type = mlirAttributeGetType(*this); - type = mlirShapedTypeGetElementType(type); - // Index type can also appear as a DenseIntElementsAttr and therefore can be - // casted to integer. - assert(mlirTypeIsAInteger(type) || - mlirTypeIsAIndex(type) && "expected integer/index element type in " - "dense int elements attribute"); - // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. nb::int_ is implicitly constructible - // from any C++ integral type and handles bitwidth correctly. - // TODO: consider caching the type properties in the constructor to avoid - // querying them on each element access. - if (mlirTypeIsAIndex(type)) { - return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos)); + MlirType type = mlirAttributeGetType(*this); + type = mlirShapedTypeGetElementType(type); + // Index type can also appear as a DenseIntElementsAttr and therefore can be + // casted to integer. + assert(mlirTypeIsAInteger(type) || + mlirTypeIsAIndex(type) && "expected integer/index element type in " + "dense int elements attribute"); + // Dispatch element extraction to an appropriate C function based on the + // elemental type of the attribute. nb::int_ is implicitly constructible + // from any C++ integral type and handles bitwidth correctly. + // TODO: consider caching the type properties in the constructor to avoid + // querying them on each element access. + if (mlirTypeIsAIndex(type)) { + return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos)); + } + unsigned width = mlirIntegerTypeGetWidth(type); + bool isUnsigned = mlirIntegerTypeIsUnsigned(type); + if (isUnsigned) { + if (width == 1) { + return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); } - unsigned width = mlirIntegerTypeGetWidth(type); - bool isUnsigned = mlirIntegerTypeIsUnsigned(type); - if (isUnsigned) { - if (width == 1) { - return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); - } - if (width == 8) { - return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos)); - } - if (width == 16) { - return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos)); - } - if (width == 32) { - return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos)); - } - if (width == 64) { - return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos)); - } - } else { - if (width == 1) { - return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); - } - if (width == 8) { - return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos)); - } - if (width == 16) { - return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos)); - } - if (width == 32) { - return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos)); - } - if (width == 64) { - return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos)); - } + if (width == 8) { + return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos)); + } + if (width == 16) { + return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos)); + } + if (width == 32) { + return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos)); + } + if (width == 64) { + return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos)); + } + } else { + if (width == 1) { + return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); + } + if (width == 8) { + return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos)); + } + if (width == 16) { + return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos)); + } + if (width == 32) { + return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos)); + } + if (width == 64) { + return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos)); } - throw nb::type_error("Unsupported integer type"); } + throw nb::type_error("Unsupported integer type"); +} - static void bindDerived(ClassTy &c) { - c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); - } -}; +void PyDenseIntElementsAttribute::bindDerived(ClassTy &c) { + c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); +} // Check if the python version is less than 3.13. Py_IsFinalizing is a part // of stable ABI since 3.13 and before it was available as _Py_IsFinalizing. @@ -1434,279 +1187,223 @@ class PyDenseIntElementsAttribute #define Py_IsFinalizing _Py_IsFinalizing #endif -class PyDenseResourceElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = - mlirAttributeIsADenseResourceElements; - static constexpr const char *pyClassName = "DenseResourceElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static PyDenseResourceElementsAttribute - getFromBuffer(const nb_buffer &buffer, const std::string &name, - const PyType &type, std::optional alignment, - bool isMutable, DefaultingPyMlirContext contextWrapper) { - if (!mlirTypeIsAShaped(type)) { - throw std::invalid_argument( - "Constructing a DenseResourceElementsAttr requires a ShapedType."); - } - - // Do not request any conversions as we must ensure to use caller - // managed memory. - int flags = PyBUF_STRIDES; - std::unique_ptr view = std::make_unique(); - if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { - throw nb::python_error(); - } +PyDenseResourceElementsAttribute +PyDenseResourceElementsAttribute::getFromBuffer( + const nb_buffer &buffer, const std::string &name, const PyType &type, + std::optional alignment, bool isMutable, + DefaultingPyMlirContext contextWrapper) { + if (!mlirTypeIsAShaped(type)) { + throw std::invalid_argument( + "Constructing a DenseResourceElementsAttr requires a ShapedType."); + } - // This scope releaser will only release if we haven't yet transferred - // ownership. - auto freeBuffer = llvm::make_scope_exit([&]() { - if (view) - PyBuffer_Release(view.get()); - }); + // Do not request any conversions as we must ensure to use caller + // managed memory. + int flags = PyBUF_STRIDES; + std::unique_ptr view = std::make_unique(); + if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { + throw nb::python_error(); + } - if (!PyBuffer_IsContiguous(view.get(), 'A')) { - throw std::invalid_argument("Contiguous buffer is required."); - } + // This scope releaser will only release if we haven't yet transferred + // ownership. + auto freeBuffer = llvm::make_scope_exit([&]() { + if (view) + PyBuffer_Release(view.get()); + }); - // Infer alignment to be the stride of one element if not explicit. - size_t inferredAlignment; - if (alignment) - inferredAlignment = *alignment; - else - inferredAlignment = view->strides[view->ndim - 1]; - - // The userData is a Py_buffer* that the deleter owns. - auto deleter = [](void *userData, const void *data, size_t size, - size_t align) { - if (Py_IsFinalizing()) - return; - assert(Py_IsInitialized() && "expected interpreter to be initialized"); - Py_buffer *ownedView = static_cast(userData); - nb::gil_scoped_acquire gil; - PyBuffer_Release(ownedView); - delete ownedView; - }; - - size_t rawBufferSize = view->len; - MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( - type, toMlirStringRef(name), view->buf, rawBufferSize, - inferredAlignment, isMutable, deleter, static_cast(view.get())); - if (mlirAttributeIsNull(attr)) { - throw std::invalid_argument( - "DenseResourceElementsAttr could not be constructed from the given " - "buffer. " - "This may mean that the Python buffer layout does not match that " - "MLIR expected layout and is a bug."); - } - view.release(); - return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); + if (!PyBuffer_IsContiguous(view.get(), 'A')) { + throw std::invalid_argument("Contiguous buffer is required."); } - static void bindDerived(ClassTy &c) { - c.def_static( - "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer, - nb::arg("array"), nb::arg("name"), nb::arg("type"), - nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false, - nb::arg("context").none() = nb::none(), - kDenseResourceElementsAttrGetFromBufferDocstring); - } -}; + // Infer alignment to be the stride of one element if not explicit. + size_t inferredAlignment; + if (alignment) + inferredAlignment = *alignment; + else + inferredAlignment = view->strides[view->ndim - 1]; + + // The userData is a Py_buffer* that the deleter owns. + auto deleter = [](void *userData, const void *data, size_t size, + size_t align) { + if (Py_IsFinalizing()) + return; + assert(Py_IsInitialized() && "expected interpreter to be initialized"); + Py_buffer *ownedView = static_cast(userData); + nb::gil_scoped_acquire gil; + PyBuffer_Release(ownedView); + delete ownedView; + }; -class PyDictAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; - static constexpr const char *pyClassName = "DictAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirDictionaryAttrGetTypeID; + size_t rawBufferSize = view->len; + MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( + type, toMlirStringRef(name), view->buf, rawBufferSize, inferredAlignment, + isMutable, deleter, static_cast(view.get())); + if (mlirAttributeIsNull(attr)) { + throw std::invalid_argument( + "DenseResourceElementsAttr could not be constructed from the given " + "buffer. " + "This may mean that the Python buffer layout does not match that " + "MLIR expected layout and is a bug."); + } + view.release(); + return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); +} - intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } +void PyDenseResourceElementsAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer, + nb::arg("array"), nb::arg("name"), nb::arg("type"), + nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false, + nb::arg("context").none() = nb::none(), + kDenseResourceElementsAttrGetFromBufferDocstring); +} - bool dunderContains(const std::string &name) { - return !mlirAttributeIsNull( - mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); - } +intptr_t PyDictAttribute::dunderLen() { + return mlirDictionaryAttrGetNumElements(*this); +} - static void bindDerived(ClassTy &c) { - c.def("__contains__", &PyDictAttribute::dunderContains); - c.def("__len__", &PyDictAttribute::dunderLen); - c.def_static( - "get", - [](const nb::dict &attributes, DefaultingPyMlirContext context) { - SmallVector mlirNamedAttributes; - mlirNamedAttributes.reserve(attributes.size()); - for (std::pair it : attributes) { - auto &mlirAttr = nb::cast(it.second); - auto name = nb::cast(it.first); - mlirNamedAttributes.push_back(mlirNamedAttributeGet( - mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), - toMlirStringRef(name)), - mlirAttr)); - } - MlirAttribute attr = - mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), - mlirNamedAttributes.data()); - return PyDictAttribute(context->getRef(), attr); - }, - nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(), - "Gets an uniqued dict attribute"); - c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { - MlirAttribute attr = - mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); - if (mlirAttributeIsNull(attr)) - throw nb::key_error("attempt to access a non-existent attribute"); - return attr; - }); - c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { - if (index < 0 || index >= self.dunderLen()) { - throw nb::index_error("attempt to access out of bounds attribute"); - } - MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); - return PyNamedAttribute( - namedAttr.attribute, - std::string(mlirIdentifierStr(namedAttr.name).data)); - }); - } -}; +bool PyDictAttribute::dunderContains(const std::string &name) { + return !mlirAttributeIsNull( + mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); +} -/// Refinement of PyDenseElementsAttribute for attributes containing -/// floating-point values. Supports element access. -class PyDenseFPElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; - static constexpr const char *pyClassName = "DenseFPElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - nb::float_ dunderGetItem(intptr_t pos) { - if (pos < 0 || pos >= dunderLen()) { - throw nb::index_error("attempt to access out of bounds element"); +void PyDictAttribute::bindDerived(ClassTy &c) { + c.def("__contains__", &PyDictAttribute::dunderContains); + c.def("__len__", &PyDictAttribute::dunderLen); + c.def_static( + "get", + [](const nb::dict &attributes, DefaultingPyMlirContext context) { + SmallVector mlirNamedAttributes; + mlirNamedAttributes.reserve(attributes.size()); + for (std::pair it : attributes) { + auto &mlirAttr = nb::cast(it.second); + auto name = nb::cast(it.first); + mlirNamedAttributes.push_back(mlirNamedAttributeGet( + mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), + toMlirStringRef(name)), + mlirAttr)); + } + MlirAttribute attr = + mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), + mlirNamedAttributes.data()); + return PyDictAttribute(context->getRef(), attr); + }, + nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(), + "Gets an uniqued dict attribute"); + c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { + MlirAttribute attr = + mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); + if (mlirAttributeIsNull(attr)) + throw nb::key_error("attempt to access a non-existent attribute"); + return attr; + }); + c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { + if (index < 0 || index >= self.dunderLen()) { + throw nb::index_error("attempt to access out of bounds attribute"); } + MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); + return PyNamedAttribute( + namedAttr.attribute, + std::string(mlirIdentifierStr(namedAttr.name).data)); + }); +} - MlirType type = mlirAttributeGetType(*this); - type = mlirShapedTypeGetElementType(type); - // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. nb::float_ is implicitly constructible - // from float and double. - // TODO: consider caching the type properties in the constructor to avoid - // querying them on each element access. - if (mlirTypeIsAF32(type)) { - return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos)); - } - if (mlirTypeIsAF64(type)) { - return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos)); - } - throw nb::type_error("Unsupported floating-point type"); +nb::float_ PyDenseFPElementsAttribute::dunderGetItem(intptr_t pos) { + if (pos < 0 || pos >= dunderLen()) { + throw nb::index_error("attempt to access out of bounds element"); } - static void bindDerived(ClassTy &c) { - c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); + MlirType type = mlirAttributeGetType(*this); + type = mlirShapedTypeGetElementType(type); + // Dispatch element extraction to an appropriate C function based on the + // elemental type of the attribute. nb::float_ is implicitly constructible + // from float and double. + // TODO: consider caching the type properties in the constructor to avoid + // querying them on each element access. + if (mlirTypeIsAF32(type)) { + return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos)); } -}; - -class PyTypeAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; - static constexpr const char *pyClassName = "TypeAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirTypeAttrGetTypeID; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](const PyType &value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirTypeAttrGet(value.get()); - return PyTypeAttribute(context->getRef(), attr); - }, - nb::arg("value"), nb::arg("context").none() = nb::none(), - "Gets a uniqued Type attribute"); - c.def_prop_ro("value", [](PyTypeAttribute &self) { - return mlirTypeAttrGetValue(self.get()); - }); + if (mlirTypeIsAF64(type)) { + return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos)); } -}; + throw nb::type_error("Unsupported floating-point type"); +} -/// Unit Attribute subclass. Unit attributes don't have values. -class PyUnitAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; - static constexpr const char *pyClassName = "UnitAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirUnitAttrGetTypeID; +void PyDenseFPElementsAttribute::bindDerived(ClassTy &c) { + c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - return PyUnitAttribute(context->getRef(), - mlirUnitAttrGet(context->get())); - }, - nb::arg("context").none() = nb::none(), "Create a Unit attribute."); - } -}; +void PyTypeAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const PyType &value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirTypeAttrGet(value.get()); + return PyTypeAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets a uniqued Type attribute"); + c.def_prop_ro("value", [](PyTypeAttribute &self) { + return mlirTypeAttrGetValue(self.get()); + }); +} -/// Strided layout attribute subclass. -class PyStridedLayoutAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; - static constexpr const char *pyClassName = "StridedLayoutAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirStridedLayoutAttrGetTypeID; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](int64_t offset, const std::vector &strides, - DefaultingPyMlirContext ctx) { - MlirAttribute attr = mlirStridedLayoutAttrGet( - ctx->get(), offset, strides.size(), strides.data()); - return PyStridedLayoutAttribute(ctx->getRef(), attr); - }, - nb::arg("offset"), nb::arg("strides"), - nb::arg("context").none() = nb::none(), - "Gets a strided layout attribute."); - c.def_static( - "get_fully_dynamic", - [](int64_t rank, DefaultingPyMlirContext ctx) { - auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); - std::vector strides(rank); - llvm::fill(strides, dynamic); - MlirAttribute attr = mlirStridedLayoutAttrGet( - ctx->get(), dynamic, strides.size(), strides.data()); - return PyStridedLayoutAttribute(ctx->getRef(), attr); - }, - nb::arg("rank"), nb::arg("context").none() = nb::none(), - "Gets a strided layout attribute with dynamic offset and strides of " - "a " - "given rank."); - c.def_prop_ro( - "offset", - [](PyStridedLayoutAttribute &self) { - return mlirStridedLayoutAttrGetOffset(self); - }, - "Returns the value of the float point attribute"); - c.def_prop_ro( - "strides", - [](PyStridedLayoutAttribute &self) { - intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); - std::vector strides(size); - for (intptr_t i = 0; i < size; i++) { - strides[i] = mlirStridedLayoutAttrGetStride(self, i); - } - return strides; - }, - "Returns the value of the float point attribute"); - } -}; +void PyUnitAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return PyUnitAttribute(context->getRef(), + mlirUnitAttrGet(context->get())); + }, + nb::arg("context").none() = nb::none(), "Create a Unit attribute."); +} -nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { +void PyStridedLayoutAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](int64_t offset, const std::vector &strides, + DefaultingPyMlirContext ctx) { + MlirAttribute attr = mlirStridedLayoutAttrGet( + ctx->get(), offset, strides.size(), strides.data()); + return PyStridedLayoutAttribute(ctx->getRef(), attr); + }, + nb::arg("offset"), nb::arg("strides"), + nb::arg("context").none() = nb::none(), + "Gets a strided layout attribute."); + c.def_static( + "get_fully_dynamic", + [](int64_t rank, DefaultingPyMlirContext ctx) { + auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); + std::vector strides(rank); + llvm::fill(strides, dynamic); + MlirAttribute attr = mlirStridedLayoutAttrGet( + ctx->get(), dynamic, strides.size(), strides.data()); + return PyStridedLayoutAttribute(ctx->getRef(), attr); + }, + nb::arg("rank"), nb::arg("context").none() = nb::none(), + "Gets a strided layout attribute with dynamic offset and strides of " + "a " + "given rank."); + c.def_prop_ro( + "offset", + [](PyStridedLayoutAttribute &self) { + return mlirStridedLayoutAttrGetOffset(self); + }, + "Returns the value of the float point attribute"); + c.def_prop_ro( + "strides", + [](PyStridedLayoutAttribute &self) { + intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); + std::vector strides(size); + for (intptr_t i = 0; i < size; i++) { + strides[i] = mlirStridedLayoutAttrGetStride(self, i); + } + return strides; + }, + "Returns the value of the float point attribute"); +} + +static nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) return nb::cast(PyDenseBoolArrayAttribute(pyAttribute)); if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) @@ -1727,7 +1424,8 @@ nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { throw nb::type_error(msg.c_str()); } -nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { +static nb::object +denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) return nb::cast(PyDenseFPElementsAttribute(pyAttribute)); if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) @@ -1739,7 +1437,7 @@ nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { throw nb::type_error(msg.c_str()); } -nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { +static nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { if (PyBoolAttribute::isaFunction(pyAttribute)) return nb::cast(PyBoolAttribute(pyAttribute)); if (PyIntegerAttribute::isaFunction(pyAttribute)) @@ -1750,7 +1448,8 @@ nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { throw nb::type_error(msg.c_str()); } -nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { +static nb::object +symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) return nb::cast(PyFlatSymbolRefAttribute(pyAttribute)); if (PySymbolRefAttribute::isaFunction(pyAttribute)) @@ -1761,9 +1460,7 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { throw nb::type_error(msg.c_str()); } -} // namespace - -void mlir::python::populateIRAttributes(nb::module_ &m) { +void populateIRAttributes(nb::module_ &m) { PyAffineMapAttribute::bind(m); PyDenseBoolArrayAttribute::bind(m); PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); @@ -1816,3 +1513,4 @@ void mlir::python::populateIRAttributes(nb::module_ &m) { PyStridedLayoutAttribute::bind(m); } +} // namespace mlir::python diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 2df2a73fd88ff..67f1f755e2125 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,21 +6,24 @@ // //===----------------------------------------------------------------------===// -#include "Globals.h" -#include "IRModule.h" -#include "NanobindUtils.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/Nanobind.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Globals.h" +#include "mlir/Bindings/Python/IRModule.h" #include "nanobind/nanobind.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +// clang-format off +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind. +// clang-format on + #include namespace nb = nanobind; @@ -1545,81 +1548,47 @@ void PyOperation::erase() { mlirOperationDestroy(operation); } -namespace { -/// CRTP base class for Python MLIR values that subclass Value and should be -/// castable from it. The value hierarchy is one level deep and is not supposed -/// to accommodate other levels unless core MLIR changes. template -class PyConcreteValue : public PyValue { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - // and redefine bindDerived. - using ClassTy = nb::class_; - using IsAFunctionTy = bool (*)(MlirValue); - - PyConcreteValue() = default; - PyConcreteValue(PyOperationRef operationRef, MlirValue value) - : PyValue(operationRef, value) {} - PyConcreteValue(PyValue &orig) - : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} - - /// Attempts to cast the original value to the derived type and throws on - /// type mismatches. - static MlirValue castFrom(PyValue &orig) { - if (!DerivedTy::isaFunction(orig.get())) { - auto origRepr = nb::cast(nb::repr(nb::cast(orig))); - throw nb::value_error((Twine("Cannot cast value to ") + - DerivedTy::pyClassName + " (from " + origRepr + - ")") - .str() - .c_str()); - } - return orig.get(); - } - - /// Binds the Python module objects to functions of this class. - static void bind(nb::module_ &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); - cls.def_static( - "isinstance", - [](PyValue &otherValue) -> bool { - return DerivedTy::isaFunction(otherValue); - }, - nb::arg("other_value")); - cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](DerivedTy &self) { return self.maybeDownCast(); }); - DerivedTy::bindDerived(cls); +MlirValue PyConcreteValue::castFrom(PyValue &orig) { + if (!DerivedTy::isaFunction(orig.get())) { + auto origRepr = nb::cast(nb::repr(nb::cast(orig))); + throw nb::value_error((Twine("Cannot cast value to ") + + DerivedTy::pyClassName + " (from " + origRepr + ")") + .str() + .c_str()); } + return orig.get(); +} - /// Implemented by derived classes to add methods to the Python subclass. - static void bindDerived(ClassTy &m) {} -}; - -} // namespace - -/// Python wrapper for MlirOpResult. -class PyOpResult : public PyConcreteValue { -public: - static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; - static constexpr const char *pyClassName = "OpResult"; - using PyConcreteValue::PyConcreteValue; +template +void PyConcreteValue::bind(nb::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); + cls.def_static( + "isinstance", + [](PyValue &otherValue) -> bool { + return DerivedTy::isaFunction(otherValue); + }, + nb::arg("other_value")); + cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](DerivedTy &self) { return self.maybeDownCast(); }); + DerivedTy::bindDerived(cls); +} - static void bindDerived(ClassTy &c) { - c.def_prop_ro("owner", [](PyOpResult &self) { - assert( - mlirOperationEqual(self.getParentOperation()->get(), - mlirOpResultGetOwner(self.get())) && - "expected the owner of the value in Python to match that in the IR"); - return self.getParentOperation().getObject(); - }); - c.def_prop_ro("result_number", [](PyOpResult &self) { - return mlirOpResultGetResultNumber(self.get()); - }); - } -}; +template +void PyConcreteValue::bindDerived(ClassTy &m) {} + +void PyOpResult::bindDerived(ClassTy &c) { + c.def_prop_ro("owner", [](PyOpResult &self) { + assert(mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match that in the IR"); + return self.getParentOperation().getObject(); + }); + c.def_prop_ro("result_number", [](PyOpResult &self) { + return mlirOpResultGetResultNumber(self.get()); + }); +} /// Returns the list of types of the values held by container. template @@ -2349,32 +2318,23 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from, } } -namespace { - -/// Python wrapper for MlirBlockArgument. -class PyBlockArgument : public PyConcreteValue { -public: - static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; - static constexpr const char *pyClassName = "BlockArgument"; - using PyConcreteValue::PyConcreteValue; - - static void bindDerived(ClassTy &c) { - c.def_prop_ro("owner", [](PyBlockArgument &self) { - return PyBlock(self.getParentOperation(), - mlirBlockArgumentGetOwner(self.get())); - }); - c.def_prop_ro("arg_number", [](PyBlockArgument &self) { - return mlirBlockArgumentGetArgNumber(self.get()); - }); - c.def( - "set_type", - [](PyBlockArgument &self, PyType type) { - return mlirBlockArgumentSetType(self.get(), type); - }, - nb::arg("type")); - } -}; +void PyBlockArgument::bindDerived(ClassTy &c) { + c.def_prop_ro("owner", [](PyBlockArgument &self) { + return PyBlock(self.getParentOperation(), + mlirBlockArgumentGetOwner(self.get())); + }); + c.def_prop_ro("arg_number", [](PyBlockArgument &self) { + return mlirBlockArgumentGetArgNumber(self.get()); + }); + c.def( + "set_type", + [](PyBlockArgument &self, PyType type) { + return mlirBlockArgumentSetType(self.get(), type); + }, + nb::arg("type")); +} +namespace { /// A list of block arguments. Internally, these are stored as consecutive /// elements, random access is cheap. The argument list is associated with the /// operation that contains the block (detached blocks are not allowed in diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 9e1fedaab5235..2b11513885e32 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -12,11 +12,11 @@ #include #include -#include "IRModule.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/IR.h" #include "mlir-c/Interfaces.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/IRModule.h" #include "mlir/Bindings/Python/Nanobind.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 0de2f1711829b..c77e37da3ffd4 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -6,16 +6,18 @@ // //===----------------------------------------------------------------------===// -#include "IRModule.h" +#include "mlir/Bindings/Python/IRModule.h" #include #include -#include "Globals.h" -#include "NanobindUtils.h" -#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Globals.h" +// clang-format off #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind. +// clang-format on namespace nb = nanobind; using namespace mlir; diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index a9b12590188f8..df34bfd6f8ab5 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -7,513 +7,290 @@ //===----------------------------------------------------------------------===// // clang-format off -#include "IRModule.h" +#include "mlir/Bindings/Python/IRModule.h" #include "mlir/Bindings/Python/IRTypes.h" // clang-format on #include -#include "IRModule.h" -#include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/NanobindUtils.h" namespace nb = nanobind; -using namespace mlir; -using namespace mlir::python; using llvm::SmallVector; using llvm::Twine; namespace { - /// Checks whether the given type is an integer or float type. -static int mlirTypeIsAIntegerOrFloat(MlirType type) { +int mlirTypeIsAIntegerOrFloat(MlirType type) { return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); } +} // namespace -class PyIntegerType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirIntegerTypeGetTypeID; - static constexpr const char *pyClassName = "IntegerType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_signless", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context").none() = nb::none(), - "Create a signless integer type"); - c.def_static( - "get_signed", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeSignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context").none() = nb::none(), - "Create a signed integer type"); - c.def_static( - "get_unsigned", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context").none() = nb::none(), - "Create an unsigned integer type"); - c.def_prop_ro( - "width", - [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, - "Returns the width of the integer type"); - c.def_prop_ro( - "is_signless", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSignless(self); - }, - "Returns whether this is a signless integer"); - c.def_prop_ro( - "is_signed", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSigned(self); - }, - "Returns whether this is a signed integer"); - c.def_prop_ro( - "is_unsigned", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsUnsigned(self); - }, - "Returns whether this is an unsigned integer"); - } -}; - -/// Index Type subclass - IndexType. -class PyIndexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirIndexTypeGetTypeID; - static constexpr const char *pyClassName = "IndexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirIndexTypeGet(context->get()); - return PyIndexType(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a index type."); - } -}; - -class PyFloatType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; - static constexpr const char *pyClassName = "FloatType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_prop_ro( - "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, - "Returns the width of the floating-point type"); - } -}; - -/// Floating Point Type subclass - Float4E2M1FNType. -class PyFloat4E2M1FNType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat4E2M1FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float4E2M1FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); - return PyFloat4E2M1FNType(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type."); - } -}; - -/// Floating Point Type subclass - Float6E2M3FNType. -class PyFloat6E2M3FNType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat6E2M3FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float6E2M3FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); - return PyFloat6E2M3FNType(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type."); - } -}; - -/// Floating Point Type subclass - Float6E3M2FNType. -class PyFloat6E3M2FNType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat6E3M2FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float6E3M2FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); - return PyFloat6E3M2FNType(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNType. -class PyFloat8E4M3FNType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); - return PyFloat8E4M3FNType(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type."); - } -}; - -/// Floating Point Type subclass - Float8E5M2Type. -class PyFloat8E5M2Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E5M2TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E5M2Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2TypeGet(context->get()); - return PyFloat8E5M2Type(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3Type. -class PyFloat8E4M3Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3TypeGet(context->get()); - return PyFloat8E4M3Type(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNUZ. -class PyFloat8E4M3FNUZType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); - return PyFloat8E4M3FNUZType(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), - "Create a float8_e4m3fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3B11FNUZ. -class PyFloat8E4M3B11FNUZType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3B11FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); - return PyFloat8E4M3B11FNUZType(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), - "Create a float8_e4m3b11fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E5M2FNUZ. -class PyFloat8E5M2FNUZType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E5M2FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E5M2FNUZType"; - using PyConcreteType::PyConcreteType; +namespace mlir::python { +void PyIntegerType::bindDerived(ClassTy &c) { + c.def_static( + "get_signless", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + nb::arg("width"), nb::arg("context").none() = nb::none(), + "Create a signless integer type"); + c.def_static( + "get_signed", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeSignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + nb::arg("width"), nb::arg("context").none() = nb::none(), + "Create a signed integer type"); + c.def_static( + "get_unsigned", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + nb::arg("width"), nb::arg("context").none() = nb::none(), + "Create an unsigned integer type"); + c.def_prop_ro( + "width", + [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, + "Returns the width of the integer type"); + c.def_prop_ro( + "is_signless", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsSignless(self); + }, + "Returns whether this is a signless integer"); + c.def_prop_ro( + "is_signed", + [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); }, + "Returns whether this is a signed integer"); + c.def_prop_ro( + "is_unsigned", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsUnsigned(self); + }, + "Returns whether this is an unsigned integer"); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); - return PyFloat8E5M2FNUZType(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), - "Create a float8_e5m2fnuz type."); - } -}; +void PyIndexType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirIndexTypeGet(context->get()); + return PyIndexType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a index type."); +} -/// Floating Point Type subclass - Float8E3M4Type. -class PyFloat8E3M4Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E3M4TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E3M4Type"; - using PyConcreteType::PyConcreteType; +void PyFloatType::bindDerived(ClassTy &c) { + c.def_prop_ro( + "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, + "Returns the width of the floating-point type"); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E3M4TypeGet(context->get()); - return PyFloat8E3M4Type(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type."); - } -}; +void PyFloat4E2M1FNType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); + return PyFloat4E2M1FNType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type."); +} -/// Floating Point Type subclass - Float8E8M0FNUType. -class PyFloat8E8M0FNUType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E8M0FNUTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E8M0FNUType"; - using PyConcreteType::PyConcreteType; +void PyFloat6E2M3FNType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); + return PyFloat6E2M3FNType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); - return PyFloat8E8M0FNUType(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), - "Create a float8_e8m0fnu type."); - } -}; +void PyFloat6E3M2FNType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); + return PyFloat6E3M2FNType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type."); +} -/// Floating Point Type subclass - BF16Type. -class PyBF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirBFloat16TypeGetTypeID; - static constexpr const char *pyClassName = "BF16Type"; - using PyConcreteType::PyConcreteType; +void PyFloat8E4M3FNType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); + return PyFloat8E4M3FNType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirBF16TypeGet(context->get()); - return PyBF16Type(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a bf16 type."); - } -}; +void PyFloat8E5M2Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2TypeGet(context->get()); + return PyFloat8E5M2Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type."); +} -/// Floating Point Type subclass - F16Type. -class PyF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat16TypeGetTypeID; - static constexpr const char *pyClassName = "F16Type"; - using PyConcreteType::PyConcreteType; +void PyFloat8E4M3Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3TypeGet(context->get()); + return PyFloat8E4M3Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF16TypeGet(context->get()); - return PyF16Type(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a f16 type."); - } -}; +void PyFloat8E4M3FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); + return PyFloat8E4M3FNUZType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float8_e4m3fnuz type."); +} -/// Floating Point Type subclass - TF32Type. -class PyTF32Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloatTF32TypeGetTypeID; - static constexpr const char *pyClassName = "FloatTF32Type"; - using PyConcreteType::PyConcreteType; +void PyFloat8E4M3B11FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); + return PyFloat8E4M3B11FNUZType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), + "Create a float8_e4m3b11fnuz type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirTF32TypeGet(context->get()); - return PyTF32Type(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a tf32 type."); - } -}; +void PyFloat8E5M2FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); + return PyFloat8E5M2FNUZType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float8_e5m2fnuz type."); +} -/// Floating Point Type subclass - F32Type. -class PyF32Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat32TypeGetTypeID; - static constexpr const char *pyClassName = "F32Type"; - using PyConcreteType::PyConcreteType; +void PyFloat8E3M4Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E3M4TypeGet(context->get()); + return PyFloat8E3M4Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF32TypeGet(context->get()); - return PyF32Type(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a f32 type."); - } -}; +void PyFloat8E8M0FNUType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); + return PyFloat8E8M0FNUType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float8_e8m0fnu type."); +} -/// Floating Point Type subclass - F64Type. -class PyF64Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat64TypeGetTypeID; - static constexpr const char *pyClassName = "F64Type"; - using PyConcreteType::PyConcreteType; +void PyBF16Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirBF16TypeGet(context->get()); + return PyBF16Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a bf16 type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF64TypeGet(context->get()); - return PyF64Type(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a f64 type."); - } -}; +void PyF16Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF16TypeGet(context->get()); + return PyF16Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a f16 type."); +} -/// None Type subclass - NoneType. -class PyNoneType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirNoneTypeGetTypeID; - static constexpr const char *pyClassName = "NoneType"; - using PyConcreteType::PyConcreteType; +void PyTF32Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirTF32TypeGet(context->get()); + return PyTF32Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a tf32 type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirNoneTypeGet(context->get()); - return PyNoneType(context->getRef(), t); - }, - nb::arg("context").none() = nb::none(), "Create a none type."); - } -}; +void PyF32Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF32TypeGet(context->get()); + return PyF32Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a f32 type."); +} -/// Complex Type subclass - ComplexType. -class PyComplexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirComplexTypeGetTypeID; - static constexpr const char *pyClassName = "ComplexType"; - using PyConcreteType::PyConcreteType; +void PyF64Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF64TypeGet(context->get()); + return PyF64Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a f64 type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType) { - // The element must be a floating point or integer scalar type. - if (mlirTypeIsAIntegerOrFloat(elementType)) { - MlirType t = mlirComplexTypeGet(elementType); - return PyComplexType(elementType.getContext(), t); - } - throw nb::value_error( - (Twine("invalid '") + - nb::cast(nb::repr(nb::cast(elementType))) + - "' and expected floating point or integer type.") - .str() - .c_str()); - }, - "Create a complex type"); - c.def_prop_ro( - "element_type", - [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, - "Returns element type."); - } -}; +void PyNoneType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirNoneTypeGet(context->get()); + return PyNoneType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a none type."); +} -} // namespace +void PyComplexType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType) { + // The element must be a floating point or integer scalar type. + if (mlirTypeIsAIntegerOrFloat(elementType)) { + MlirType t = mlirComplexTypeGet(elementType); + return PyComplexType(elementType.getContext(), t); + } + throw nb::value_error( + (Twine("invalid '") + + nb::cast(nb::repr(nb::cast(elementType))) + + "' and expected floating point or integer type.") + .str() + .c_str()); + }, + "Create a complex type"); + c.def_prop_ro( + "element_type", + [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, + "Returns element type."); +} // Shaped Type Interface - ShapedType -void mlir::PyShapedType::bindDerived(ClassTy &c) { +void PyShapedType::bindDerived(ClassTy &c) { c.def_prop_ro( "element_type", [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, @@ -534,7 +311,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { [](PyShapedType &self) -> bool { return mlirShapedTypeHasStaticShape(self); }, - "Returns whether the given shaped type has a static shape."); + "Returns whether the given shaped type has a shape."); c.def( "is_dynamic_dim", [](PyShapedType &self, intptr_t dim) -> bool { @@ -571,7 +348,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { "is_static_size", [](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); }, nb::arg("dim_size"), - "Returns whether the given dimension size indicates a static " + "Returns whether the given dimension size indicates a " "dimension."); c.def( "is_dynamic_stride_or_offset", @@ -615,383 +392,294 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { "shaped types."); } -void mlir::PyShapedType::requireHasRank() { +void PyShapedType::requireHasRank() { if (!mlirShapedTypeHasRank(*this)) { throw nb::value_error( "calling this method requires that the type has a rank."); } } -const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction = - mlirTypeIsAShaped; - -namespace { - -/// Vector Type subclass - VectorType. -class PyVectorType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirVectorTypeGetTypeID; - static constexpr const char *pyClassName = "VectorType"; - using PyConcreteType::PyConcreteType; +const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped; + +void PyVectorType::bindDerived(ClassTy &c) { + c.def_static("get", &PyVectorType::get, nb::arg("shape"), + nb::arg("element_type"), nb::kw_only(), + nb::arg("scalable").none() = nb::none(), + nb::arg("scalable_dims").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a vector type") + .def_prop_ro("scalable", + [](MlirType self) { return mlirVectorTypeIsScalable(self); }) + .def_prop_ro("scalable_dims", [](MlirType self) { + std::vector scalableDims; + size_t rank = static_cast(mlirShapedTypeGetRank(self)); + scalableDims.reserve(rank); + for (size_t i = 0; i < rank; ++i) + scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i)); + return scalableDims; + }); +} - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyVectorType::get, nb::arg("shape"), - nb::arg("element_type"), nb::kw_only(), - nb::arg("scalable").none() = nb::none(), - nb::arg("scalable_dims").none() = nb::none(), - nb::arg("loc").none() = nb::none(), "Create a vector type") - .def_prop_ro( - "scalable", - [](MlirType self) { return mlirVectorTypeIsScalable(self); }) - .def_prop_ro("scalable_dims", [](MlirType self) { - std::vector scalableDims; - size_t rank = static_cast(mlirShapedTypeGetRank(self)); - scalableDims.reserve(rank); - for (size_t i = 0; i < rank; ++i) - scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i)); - return scalableDims; - }); +PyVectorType PyVectorType::get(std::vector shape, PyType &elementType, + std::optional scalable, + std::optional> scalableDims, + DefaultingPyLocation loc) { + if (scalable && scalableDims) { + throw nb::value_error("'scalable' and 'scalable_dims' kwargs " + "are mutually exclusive."); } -private: - static PyVectorType get(std::vector shape, PyType &elementType, - std::optional scalable, - std::optional> scalableDims, - DefaultingPyLocation loc) { - if (scalable && scalableDims) { - throw nb::value_error("'scalable' and 'scalable_dims' kwargs " - "are mutually exclusive."); + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType type; + if (scalable) { + if (scalable->size() != shape.size()) + throw nb::value_error("Expected len(scalable) == len(shape)."); + + SmallVector scalableDimFlags = llvm::to_vector(llvm::map_range( + *scalable, [](const nb::handle &h) { return nb::cast(h); })); + type = mlirVectorTypeGetScalableChecked( + loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType); + } else if (scalableDims) { + SmallVector scalableDimFlags(shape.size(), false); + for (int64_t dim : *scalableDims) { + if (static_cast(dim) >= scalableDimFlags.size() || dim < 0) + throw nb::value_error("Scalable dimension index out of bounds."); + scalableDimFlags[dim] = true; } - - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType type; - if (scalable) { - if (scalable->size() != shape.size()) - throw nb::value_error("Expected len(scalable) == len(shape)."); - - SmallVector scalableDimFlags = llvm::to_vector(llvm::map_range( - *scalable, [](const nb::handle &h) { return nb::cast(h); })); - type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), - scalableDimFlags.data(), - elementType); - } else if (scalableDims) { - SmallVector scalableDimFlags(shape.size(), false); - for (int64_t dim : *scalableDims) { - if (static_cast(dim) >= scalableDimFlags.size() || dim < 0) - throw nb::value_error("Scalable dimension index out of bounds."); - scalableDimFlags[dim] = true; - } - type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), - scalableDimFlags.data(), - elementType); - } else { - type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), - elementType); - } - if (mlirTypeIsNull(type)) - throw MLIRError("Invalid type", errors.take()); - return PyVectorType(elementType.getContext(), type); - } -}; - -/// Ranked Tensor Type subclass - RankedTensorType. -class PyRankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirRankedTensorTypeGetTypeID; - static constexpr const char *pyClassName = "RankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - std::optional &encodingAttr, DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType t = mlirRankedTensorTypeGetChecked( - loc, shape.size(), shape.data(), elementType, - encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyRankedTensorType(elementType.getContext(), t); - }, - nb::arg("shape"), nb::arg("element_type"), - nb::arg("encoding").none() = nb::none(), - nb::arg("loc").none() = nb::none(), "Create a ranked tensor type"); - c.def_prop_ro("encoding", - [](PyRankedTensorType &self) -> std::optional { - MlirAttribute encoding = - mlirRankedTensorTypeGetEncoding(self.get()); - if (mlirAttributeIsNull(encoding)) - return std::nullopt; - return encoding; - }); + type = mlirVectorTypeGetScalableChecked( + loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType); + } else { + type = + mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), elementType); } -}; - -/// Unranked Tensor Type subclass - UnrankedTensorType. -class PyUnrankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirUnrankedTensorTypeGetTypeID; - static constexpr const char *pyClassName = "UnrankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyUnrankedTensorType(elementType.getContext(), t); - }, - nb::arg("element_type"), nb::arg("loc").none() = nb::none(), - "Create a unranked tensor type"); - } -}; - -/// Ranked MemRef Type subclass - MemRefType. -class PyMemRefType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirMemRefTypeGetTypeID; - static constexpr const char *pyClassName = "MemRefType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - PyAttribute *layout, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); - MlirAttribute memSpaceAttr = - memorySpace ? *memorySpace : mlirAttributeGetNull(); - MlirType t = - mlirMemRefTypeGetChecked(loc, elementType, shape.size(), - shape.data(), layoutAttr, memSpaceAttr); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyMemRefType(elementType.getContext(), t); - }, - nb::arg("shape"), nb::arg("element_type"), - nb::arg("layout").none() = nb::none(), - nb::arg("memory_space").none() = nb::none(), - nb::arg("loc").none() = nb::none(), "Create a memref type") - .def_prop_ro( - "layout", - [](PyMemRefType &self) -> MlirAttribute { - return mlirMemRefTypeGetLayout(self); - }, - "The layout of the MemRef type.") - .def( - "get_strides_and_offset", - [](PyMemRefType &self) -> std::pair, int64_t> { - std::vector strides(mlirShapedTypeGetRank(self)); - int64_t offset; - if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset( - self, strides.data(), &offset))) - throw std::runtime_error( - "Failed to extract strides and offset from memref."); - return {strides, offset}; - }, - "The strides and offset of the MemRef type.") - .def_prop_ro( - "affine_map", - [](PyMemRefType &self) -> PyAffineMap { - MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); - return PyAffineMap(self.getContext(), map); - }, - "The layout of the MemRef type as an affine map.") - .def_prop_ro( - "memory_space", - [](PyMemRefType &self) -> std::optional { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - if (mlirAttributeIsNull(a)) - return std::nullopt; - return a; - }, - "Returns the memory space of the given MemRef type."); - } -}; - -/// Unranked MemRef Type subclass - UnrankedMemRefType. -class PyUnrankedMemRefType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirUnrankedMemRefTypeGetTypeID; - static constexpr const char *pyClassName = "UnrankedMemRefType"; - using PyConcreteType::PyConcreteType; + if (mlirTypeIsNull(type)) + throw MLIRError("Invalid type", errors.take()); + return PyVectorType(elementType.getContext(), type); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirAttribute memSpaceAttr = {}; - if (memorySpace) - memSpaceAttr = *memorySpace; +void PyRankedTensorType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + std::optional &encodingAttr, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirRankedTensorTypeGetChecked( + loc, shape.size(), shape.data(), elementType, + encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyRankedTensorType(elementType.getContext(), t); + }, + nb::arg("shape"), nb::arg("element_type"), + nb::arg("encoding").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a ranked tensor type"); + c.def_prop_ro( + "encoding", [](PyRankedTensorType &self) -> std::optional { + MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); + if (mlirAttributeIsNull(encoding)) + return std::nullopt; + return encoding; + }); +} - MlirType t = - mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyUnrankedMemRefType(elementType.getContext(), t); - }, - nb::arg("element_type"), nb::arg("memory_space").none(), - nb::arg("loc").none() = nb::none(), "Create a unranked memref type") - .def_prop_ro( - "memory_space", - [](PyUnrankedMemRefType &self) -> std::optional { - MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); - if (mlirAttributeIsNull(a)) - return std::nullopt; - return a; - }, - "Returns the memory space of the given Unranked MemRef type."); - } -}; +void PyUnrankedTensorType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedTensorType(elementType.getContext(), t); + }, + nb::arg("element_type"), nb::arg("loc").none() = nb::none(), + "Create a unranked tensor type"); +} -/// Tuple Type subclass - TupleType. -class PyTupleType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirTupleTypeGetTypeID; - static constexpr const char *pyClassName = "TupleType"; - using PyConcreteType::PyConcreteType; +void PyMemRefType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, PyAttribute *layout, + PyAttribute *memorySpace, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); + MlirAttribute memSpaceAttr = + memorySpace ? *memorySpace : mlirAttributeGetNull(); + MlirType t = + mlirMemRefTypeGetChecked(loc, elementType, shape.size(), + shape.data(), layoutAttr, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyMemRefType(elementType.getContext(), t); + }, + nb::arg("shape"), nb::arg("element_type"), + nb::arg("layout").none() = nb::none(), + nb::arg("memory_space").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a memref type") + .def_prop_ro( + "layout", + [](PyMemRefType &self) -> MlirAttribute { + return mlirMemRefTypeGetLayout(self); + }, + "The layout of the MemRef type.") + .def( + "get_strides_and_offset", + [](PyMemRefType &self) -> std::pair, int64_t> { + std::vector strides(mlirShapedTypeGetRank(self)); + int64_t offset; + if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset( + self, strides.data(), &offset))) + throw std::runtime_error( + "Failed to extract strides and offset from memref."); + return {strides, offset}; + }, + "The strides and offset of the MemRef type.") + .def_prop_ro( + "affine_map", + [](PyMemRefType &self) -> PyAffineMap { + MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); + return PyAffineMap(self.getContext(), map); + }, + "The layout of the MemRef type as an affine map.") + .def_prop_ro( + "memory_space", + [](PyMemRefType &self) -> std::optional { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + if (mlirAttributeIsNull(a)) + return std::nullopt; + return a; + }, + "Returns the memory space of the given MemRef type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get_tuple", - [](std::vector elements, DefaultingPyMlirContext context) { - MlirType t = mlirTupleTypeGet(context->get(), elements.size(), - elements.data()); - return PyTupleType(context->getRef(), t); - }, - nb::arg("elements"), nb::arg("context").none() = nb::none(), - "Create a tuple type"); - c.def( - "get_type", - [](PyTupleType &self, intptr_t pos) { - return mlirTupleTypeGetType(self, pos); - }, - nb::arg("pos"), "Returns the pos-th type in the tuple type."); - c.def_prop_ro( - "num_types", - [](PyTupleType &self) -> intptr_t { - return mlirTupleTypeGetNumTypes(self); - }, - "Returns the number of types contained in a tuple."); - } -}; +void PyUnrankedMemRefType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, PyAttribute *memorySpace, + DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = + mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedMemRefType(elementType.getContext(), t); + }, + nb::arg("element_type"), nb::arg("memory_space").none(), + nb::arg("loc").none() = nb::none(), "Create a unranked memref type") + .def_prop_ro( + "memory_space", + [](PyUnrankedMemRefType &self) -> std::optional { + MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); + if (mlirAttributeIsNull(a)) + return std::nullopt; + return a; + }, + "Returns the memory space of the given Unranked MemRef type."); +} -/// Function type. -class PyFunctionType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFunctionTypeGetTypeID; - static constexpr const char *pyClassName = "FunctionType"; - using PyConcreteType::PyConcreteType; +void PyTupleType::bindDerived(ClassTy &c) { + c.def_static( + "get_tuple", + [](std::vector elements, DefaultingPyMlirContext context) { + MlirType t = + mlirTupleTypeGet(context->get(), elements.size(), elements.data()); + return PyTupleType(context->getRef(), t); + }, + nb::arg("elements"), nb::arg("context").none() = nb::none(), + "Create a tuple type"); + c.def( + "get_type", + [](PyTupleType &self, intptr_t pos) { + return mlirTupleTypeGetType(self, pos); + }, + nb::arg("pos"), "Returns the pos-th type in the tuple type."); + c.def_prop_ro( + "num_types", + [](PyTupleType &self) -> intptr_t { + return mlirTupleTypeGetNumTypes(self); + }, + "Returns the number of types contained in a tuple."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector inputs, std::vector results, - DefaultingPyMlirContext context) { - MlirType t = - mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), - results.size(), results.data()); - return PyFunctionType(context->getRef(), t); - }, - nb::arg("inputs"), nb::arg("results"), - nb::arg("context").none() = nb::none(), - "Gets a FunctionType from a list of input and result types"); - c.def_prop_ro( - "inputs", - [](PyFunctionType &self) { - MlirType t = self; - nb::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; - ++i) { - types.append(mlirFunctionTypeGetInput(t, i)); - } - return types; - }, - "Returns the list of input types in the FunctionType."); - c.def_prop_ro( - "results", - [](PyFunctionType &self) { - nb::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; - ++i) { - types.append(mlirFunctionTypeGetResult(self, i)); - } - return types; - }, - "Returns the list of result types in the FunctionType."); - } -}; +void PyFunctionType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector inputs, std::vector results, + DefaultingPyMlirContext context) { + MlirType t = + mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), + results.size(), results.data()); + return PyFunctionType(context->getRef(), t); + }, + nb::arg("inputs"), nb::arg("results"), + nb::arg("context").none() = nb::none(), + "Gets a FunctionType from a list of input and result types"); + c.def_prop_ro( + "inputs", + [](PyFunctionType &self) { + MlirType t = self; + nb::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; + ++i) { + types.append(mlirFunctionTypeGetInput(t, i)); + } + return types; + }, + "Returns the list of input types in the FunctionType."); + c.def_prop_ro( + "results", + [](PyFunctionType &self) { + nb::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; + ++i) { + types.append(mlirFunctionTypeGetResult(self, i)); + } + return types; + }, + "Returns the list of result types in the FunctionType."); +} +}; // namespace mlir::python -static MlirStringRef toMlirStringRef(const std::string &s) { +namespace { +MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } - -/// Opaque Type subclass - OpaqueType. -class PyOpaqueType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirOpaqueTypeGetTypeID; - static constexpr const char *pyClassName = "OpaqueType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](const std::string &dialectNamespace, const std::string &typeData, - DefaultingPyMlirContext context) { - MlirType type = mlirOpaqueTypeGet(context->get(), - toMlirStringRef(dialectNamespace), - toMlirStringRef(typeData)); - return PyOpaqueType(context->getRef(), type); - }, - nb::arg("dialect_namespace"), nb::arg("buffer"), - nb::arg("context").none() = nb::none(), - "Create an unregistered (opaque) dialect type."); - c.def_prop_ro( - "dialect_namespace", - [](PyOpaqueType &self) { - MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); - return nb::str(stringRef.data, stringRef.length); - }, - "Returns the dialect namespace for the Opaque type as a string."); - c.def_prop_ro( - "data", - [](PyOpaqueType &self) { - MlirStringRef stringRef = mlirOpaqueTypeGetData(self); - return nb::str(stringRef.data, stringRef.length); - }, - "Returns the data for the Opaque type as a string."); - } -}; - } // namespace +namespace mlir::python { +void PyOpaqueType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::string &dialectNamespace, const std::string &typeData, + DefaultingPyMlirContext context) { + MlirType type = + mlirOpaqueTypeGet(context->get(), toMlirStringRef(dialectNamespace), + toMlirStringRef(typeData)); + return PyOpaqueType(context->getRef(), type); + }, + nb::arg("dialect_namespace"), nb::arg("buffer"), + nb::arg("context").none() = nb::none(), + "Create an unregistered (opaque) dialect type."); + c.def_prop_ro( + "dialect_namespace", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the dialect namespace for the Opaque type as a string."); + c.def_prop_ro( + "data", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetData(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the data for the Opaque type as a string."); +} +} // namespace mlir::python + void mlir::python::populateIRTypes(nb::module_ &m) { PyIntegerType::bind(m); PyFloatType::bind(m); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 278847e7ac7f5..0be68e730e186 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// -#include "Globals.h" -#include "IRModule.h" -#include "NanobindUtils.h" #include "Pass.h" #include "Rewrite.h" +#include "mlir/Bindings/Python/Globals.h" +#include "mlir/Bindings/Python/IRModule.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindUtils.h" namespace nb = nanobind; using namespace mlir; diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 88e28dca76bb9..d9f053f0cf3ee 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,10 +8,12 @@ #include "Pass.h" -#include "IRModule.h" #include "mlir-c/Pass.h" +#include "mlir/Bindings/Python/IRModule.h" +// clang-format off #include "mlir/Bindings/Python/Nanobind.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind. +// clang-format on namespace nb = nanobind; using namespace nb::literals; diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h index bc40943521829..0221bd10e723e 100644 --- a/mlir/lib/Bindings/Python/Pass.h +++ b/mlir/lib/Bindings/Python/Pass.h @@ -9,7 +9,7 @@ #ifndef MLIR_BINDINGS_PYTHON_PASS_H #define MLIR_BINDINGS_PYTHON_PASS_H -#include "NanobindUtils.h" +#include "mlir/Bindings/Python/NanobindUtils.h" namespace mlir { namespace python { diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 0373f9c7affe9..28f050bc05562 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -8,11 +8,13 @@ #include "Rewrite.h" -#include "IRModule.h" #include "mlir-c/Rewrite.h" -#include "mlir/Bindings/Python/Nanobind.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +#include "mlir/Bindings/Python/IRModule.h" #include "mlir/Config/mlir-config.h" +// clang-format off +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir-c/Bindings/Python/Interop.h" // ON WINDOWS This is expected after nanobind. +// clang-format on namespace nb = nanobind; using namespace mlir; diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h index ae89e2b9589f1..f8ffdc7bdc458 100644 --- a/mlir/lib/Bindings/Python/Rewrite.h +++ b/mlir/lib/Bindings/Python/Rewrite.h @@ -9,7 +9,7 @@ #ifndef MLIR_BINDINGS_PYTHON_REWRITE_H #define MLIR_BINDINGS_PYTHON_REWRITE_H -#include "NanobindUtils.h" +#include "mlir/Bindings/Python/NanobindUtils.h" namespace mlir { namespace python { diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 7a0c95ebb8200..56327cbe4a463 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -488,10 +488,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core Rewrite.cpp # Headers must be included explicitly so they are installed. - Globals.h - IRModule.h Pass.h - NanobindUtils.h Rewrite.h PRIVATE_LINK_LIBS LLVMSupport @@ -698,8 +695,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectSMT.cpp - # Headers must be included explicitly so they are installed. - NanobindUtils.h PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS