diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 7c8c84e55b962..1d0edf9ea809d 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -556,6 +556,8 @@ MLIR_CAPI_EXPORTED int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint64_t +mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED double diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index cc9532f4e33b2..f05f9f02e50fa 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -1172,13 +1172,19 @@ class PyDenseIntElementsAttribute MlirType type = mlirAttributeGetType(*this); type = mlirShapedTypeGetElementType(type); - assert(mlirTypeIsAInteger(type) && - "expected integer element type in dense int elements attribute"); + // Index type can also appear as a DenseIntElementsAttr and therefore can be + // casted to integer. + assert(mlirTypeIsAInteger(type) || + mlirTypeIsAIndex(type) && "expected integer/index element type in " + "dense int elements attribute"); // Dispatch element extraction to an appropriate C function based on the // elemental type of the attribute. py::int_ is implicitly constructible // from any C++ integral type and handles bitwidth correctly. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. + if (mlirTypeIsAIndex(type)) { + return mlirDenseElementsAttrGetIndexValue(*this, pos); + } unsigned width = mlirIntegerTypeGetWidth(type); bool isUnsigned = mlirIntegerTypeIsUnsigned(type); if (isUnsigned) { diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 11d1ade552f5a..8d57ab6b59e79 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -758,6 +758,9 @@ int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } +uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)).getValues()[pos]; +} float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py index 18ebba61e7fea..973a0eaeca2cd 100644 --- a/mlir/test/python/dialects/builtin.py +++ b/mlir/test/python/dialects/builtin.py @@ -246,3 +246,7 @@ def testDenseElementsAttr(): # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32))) # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32> + idx_values = np.arange(4, dtype=np.int64) + idx_type = IndexType.get() + print(DenseElementsAttr.get(idx_values, type=VectorType.get([4], idx_type))) + # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex> diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py index 256a69a939658..ef1d835fc6401 100644 --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -572,6 +572,10 @@ def testGetDenseElementsIndex(): print(arr) # CHECK: True print(arr.dtype == np.int64) + array = np.array([1, 2, 3], dtype=np.int64) + attr = DenseIntElementsAttr.get(array, type=VectorType.get([3], idx_type)) + # CHECK: [1, 2, 3] + print(list(DenseIntElementsAttr(attr))) # CHECK-LABEL: TEST: testGetDenseResourceElementsAttr diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py index 00c3e1b4decdb..2f3c4460d3f59 100644 --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -366,6 +366,10 @@ def testDenseIntAttr(): # CHECK: i1 print(ShapedType(a.type).element_type) + shape = Attribute.parse("dense<[0, 1, 2, 3]> : vector<4xindex>") + # CHECK: attr: dense<[0, 1, 2, 3]> + print("attr:", shape) + @run def testDenseArrayGetItem():