Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 61cc91f

Browse files
[mlir][python] Add bindings for mlirDenseElementsAttrGet (#91389)
This change adds bindings for `mlirDenseElementsAttrGet` which accepts a list of MLIR attributes and constructs a DenseElementsAttr. This allows for creating `DenseElementsAttr`s of types not natively supported by Python (e.g. BF16) without requiring other dependencies (e.g. `numpy` + `ml-dtypes`).
1 parent 29c36c5 commit 61cc91f

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

IRAttributes.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "PybindUtils.h"
1616

1717
#include "llvm/ADT/ScopeExit.h"
18+
#include "llvm/Support/raw_ostream.h"
1819

1920
#include "mlir-c/BuiltinAttributes.h"
2021
#include "mlir-c/BuiltinTypes.h"
@@ -72,6 +73,27 @@ or 255), then a splat will be created.
7273
type or if the buffer does not meet expectations.
7374
)";
7475

76+
static const char kDenseElementsAttrGetFromListDocstring[] =
77+
R"(Gets a DenseElementsAttr from a Python list of attributes.
78+
79+
Note that it can be expensive to construct attributes individually.
80+
For a large number of elements, consider using a Python buffer or array instead.
81+
82+
Args:
83+
attrs: A list of attributes.
84+
type: The desired shape and type of the resulting DenseElementsAttr.
85+
If not provided, the element type is determined based on the type
86+
of the 0th attribute and the shape is `[len(attrs)]`.
87+
context: Explicit context, if not from context manager.
88+
89+
Returns:
90+
DenseElementsAttr on success.
91+
92+
Raises:
93+
ValueError: If the type of the attributes does not match the type
94+
specified by `shaped_type`.
95+
)";
96+
7597
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
7698
R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
7799
@@ -647,6 +669,57 @@ class PyDenseElementsAttribute
647669
static constexpr const char *pyClassName = "DenseElementsAttr";
648670
using PyConcreteAttribute::PyConcreteAttribute;
649671

672+
static PyDenseElementsAttribute
673+
getFromList(py::list attributes, std::optional<PyType> explicitType,
674+
DefaultingPyMlirContext contextWrapper) {
675+
676+
const size_t numAttributes = py::len(attributes);
677+
if (numAttributes == 0)
678+
throw py::value_error("Attributes list must be non-empty.");
679+
680+
MlirType shapedType;
681+
if (explicitType) {
682+
if ((!mlirTypeIsAShaped(*explicitType) ||
683+
!mlirShapedTypeHasStaticShape(*explicitType))) {
684+
685+
std::string message;
686+
llvm::raw_string_ostream os(message);
687+
os << "Expected a static ShapedType for the shaped_type parameter: "
688+
<< py::repr(py::cast(*explicitType));
689+
throw py::value_error(os.str());
690+
}
691+
shapedType = *explicitType;
692+
} else {
693+
SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
694+
shapedType = mlirRankedTensorTypeGet(
695+
shape.size(), shape.data(),
696+
mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
697+
mlirAttributeGetNull());
698+
}
699+
700+
SmallVector<MlirAttribute> mlirAttributes;
701+
mlirAttributes.reserve(numAttributes);
702+
for (const py::handle &attribute : attributes) {
703+
MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
704+
MlirType attrType = mlirAttributeGetType(mlirAttribute);
705+
mlirAttributes.push_back(mlirAttribute);
706+
707+
if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
708+
std::string message;
709+
llvm::raw_string_ostream os(message);
710+
os << "All attributes must be of the same type and match "
711+
<< "the type parameter: expected=" << py::repr(py::cast(shapedType))
712+
<< ", but got=" << py::repr(py::cast(attrType));
713+
throw py::value_error(os.str());
714+
}
715+
}
716+
717+
MlirAttribute elements = mlirDenseElementsAttrGet(
718+
shapedType, mlirAttributes.size(), mlirAttributes.data());
719+
720+
return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
721+
}
722+
650723
static PyDenseElementsAttribute
651724
getFromBuffer(py::buffer array, bool signless,
652725
std::optional<PyType> explicitType,
@@ -883,6 +956,10 @@ class PyDenseElementsAttribute
883956
py::arg("type") = py::none(), py::arg("shape") = py::none(),
884957
py::arg("context") = py::none(),
885958
kDenseElementsAttrGetDocstring)
959+
.def_static("get", PyDenseElementsAttribute::getFromList,
960+
py::arg("attrs"), py::arg("type") = py::none(),
961+
py::arg("context") = py::none(),
962+
kDenseElementsAttrGetFromListDocstring)
886963
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
887964
py::arg("shaped_type"), py::arg("element_attr"),
888965
"Gets a DenseElementsAttr where all values are the same")

0 commit comments

Comments
 (0)