Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit e708617

Browse files
mgehre-amdcmcgirr-amdttjost
authored
[mlir][python] allow DenseIntElementsAttr for index type (#118947)
Model the `IndexType` as `uint64_t` when converting to a python integer. With the python bindings, ```python DenseIntElementsAttr(op.attributes["attr"]) ``` used to `assert` when `attr` had `index` type like `dense<[1, 2, 3, 4]> : vector<4xindex>`. --------- Co-authored-by: Christopher McGirr <[email protected]> Co-authored-by: Tiago Trevisan Jost <[email protected]>
1 parent a2056b6 commit e708617

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

mlir/include/mlir-c/BuiltinAttributes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,8 @@ MLIR_CAPI_EXPORTED int64_t
556556
mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos);
557557
MLIR_CAPI_EXPORTED uint64_t
558558
mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos);
559+
MLIR_CAPI_EXPORTED uint64_t
560+
mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos);
559561
MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr,
560562
intptr_t pos);
561563
MLIR_CAPI_EXPORTED double

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,13 +1372,19 @@ class PyDenseIntElementsAttribute
13721372

13731373
MlirType type = mlirAttributeGetType(*this);
13741374
type = mlirShapedTypeGetElementType(type);
1375-
assert(mlirTypeIsAInteger(type) &&
1376-
"expected integer element type in dense int elements attribute");
1375+
// Index type can also appear as a DenseIntElementsAttr and therefore can be
1376+
// casted to integer.
1377+
assert(mlirTypeIsAInteger(type) ||
1378+
mlirTypeIsAIndex(type) && "expected integer/index element type in "
1379+
"dense int elements attribute");
13771380
// Dispatch element extraction to an appropriate C function based on the
13781381
// elemental type of the attribute. nb::int_ is implicitly constructible
13791382
// from any C++ integral type and handles bitwidth correctly.
13801383
// TODO: consider caching the type properties in the constructor to avoid
13811384
// querying them on each element access.
1385+
if (mlirTypeIsAIndex(type)) {
1386+
return mlirDenseElementsAttrGetIndexValue(*this, pos);
1387+
}
13821388
unsigned width = mlirIntegerTypeGetWidth(type);
13831389
bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
13841390
if (isUnsigned) {

mlir/lib/CAPI/IR/BuiltinAttributes.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,9 @@ int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
758758
uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
759759
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos];
760760
}
761+
uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) {
762+
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos];
763+
}
761764
float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
762765
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<float>()[pos];
763766
}

0 commit comments

Comments
 (0)