Skip to content

Commit e2f79b5

Browse files
maksleventalaadeshps-mcw
authored andcommitted
[MLIR][Python] add GetTypeID for llvm.struct_type and llvm.ptr and enable downcasting (llvm#169383)
1 parent 72a91fe commit e2f79b5

File tree

4 files changed

+22
-3
lines changed

4 files changed

+22
-3
lines changed

mlir/include/mlir-c/Dialect/LLVM.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm);
2323
MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx,
2424
unsigned addressSpace);
2525

26+
MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMPointerTypeGetTypeID(void);
27+
2628
/// Returns `true` if the type is an LLVM dialect pointer type.
2729
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type);
2830

@@ -58,6 +60,8 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type);
5860
/// Returns `true` if the type is an LLVM dialect struct type.
5961
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type);
6062

63+
MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMStructTypeGetTypeID(void);
64+
6165
/// Returns `true` if the type is a literal (unnamed) LLVM struct type.
6266
MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type);
6367

mlir/lib/Bindings/Python/DialectLLVM.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
3131
// StructType
3232
//===--------------------------------------------------------------------===//
3333

34-
auto llvmStructType =
35-
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
34+
auto llvmStructType = mlir_type_subclass(
35+
m, "StructType", mlirTypeIsALLVMStructType, mlirLLVMStructTypeGetTypeID);
3636

3737
llvmStructType
3838
.def_classmethod(
@@ -137,7 +137,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
137137
// PointerType
138138
//===--------------------------------------------------------------------===//
139139

140-
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
140+
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType,
141+
mlirLLVMPointerTypeGetTypeID)
141142
.def_classmethod(
142143
"get",
143144
[](const nb::object &cls, std::optional<unsigned> addressSpace,

mlir/lib/CAPI/Dialect/LLVM.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) {
2727
return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace));
2828
}
2929

30+
MlirTypeID mlirLLVMPointerTypeGetTypeID() {
31+
return wrap(LLVM::LLVMPointerType::getTypeID());
32+
}
33+
3034
bool mlirTypeIsALLVMPointerType(MlirType type) {
3135
return isa<LLVM::LLVMPointerType>(unwrap(type));
3236
}
@@ -73,6 +77,10 @@ bool mlirTypeIsALLVMStructType(MlirType type) {
7377
return isa<LLVM::LLVMStructType>(unwrap(type));
7478
}
7579

80+
MlirTypeID mlirLLVMStructTypeGetTypeID() {
81+
return wrap(LLVM::LLVMStructType::getTypeID());
82+
}
83+
7684
bool mlirLLVMStructTypeIsLiteral(MlirType type) {
7785
return !cast<LLVM::LLVMStructType>(unwrap(type)).isIdentified();
7886
}

mlir/test/python/dialects/llvm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ def testStructType():
9898
assert opaque.opaque
9999
# CHECK: !llvm.struct<"opaque", opaque>
100100

101+
typ = Type.parse('!llvm.struct<"zoo", (i32, i64)>')
102+
assert isinstance(typ, llvm.StructType)
103+
101104

102105
# CHECK-LABEL: testSmoke
103106
@constructAndPrintInModule
@@ -120,6 +123,9 @@ def testPointerType():
120123
# CHECK: !llvm.ptr<1>
121124
print(ptr_with_addr)
122125

126+
typ = Type.parse("!llvm.ptr<1>")
127+
assert isinstance(typ, llvm.PointerType)
128+
123129

124130
# CHECK-LABEL: testConstant
125131
@constructAndPrintInModule

0 commit comments

Comments
 (0)