Skip to content

Commit 060fe70

Browse files
committed
add pointer
1 parent b8d0de7 commit 060fe70

File tree

4 files changed

+11
-1
lines changed

4 files changed

+11
-1
lines changed

mlir/include/mlir-c/Dialect/LLVM.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm);
2323
MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx,
2424
unsigned addressSpace);
2525

26+
MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMPointerTypeGetTypeID();
27+
2628
/// Returns `true` if the type is an LLVM dialect pointer type.
2729
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type);
2830

mlir/lib/Bindings/Python/DialectLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
137137
// PointerType
138138
//===--------------------------------------------------------------------===//
139139

140-
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
140+
mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType,
141+
mlirLLVMPointerTypeGetTypeID)
141142
.def_classmethod(
142143
"get",
143144
[](const nb::object &cls, std::optional<unsigned> addressSpace,

mlir/lib/CAPI/Dialect/LLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) {
2727
return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace));
2828
}
2929

30+
MlirTypeID mlirLLVMPointerTypeGetTypeID() {
31+
return wrap(LLVM::LLVMPointerType::getTypeID());
32+
}
33+
3034
bool mlirTypeIsALLVMPointerType(MlirType type) {
3135
return isa<LLVM::LLVMPointerType>(unwrap(type));
3236
}

mlir/test/python/dialects/llvm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ def testPointerType():
123123
# CHECK: !llvm.ptr<1>
124124
print(ptr_with_addr)
125125

126+
typ = Type.parse("!llvm.ptr<1>")
127+
assert isinstance(typ, llvm.PointerType)
128+
126129

127130
# CHECK-LABEL: testConstant
128131
@constructAndPrintInModule

0 commit comments

Comments
 (0)