diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 26c4140757c3c..65b14254e4492 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -52,6 +52,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos); +/// Returns the return type of the function type. +MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type); + /// Returns `true` if the type is an LLVM dialect struct type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); diff --git a/mlir/include/mlir-c/Target/LLVMIR.h b/mlir/include/mlir-c/Target/LLVMIR.h index effa74b905ce6..b5f948961e898 100644 --- a/mlir/include/mlir-c/Target/LLVMIR.h +++ b/mlir/include/mlir-c/Target/LLVMIR.h @@ -16,6 +16,7 @@ #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "llvm-c/Core.h" #include "llvm-c/Support.h" #ifdef __cplusplus @@ -32,6 +33,48 @@ extern "C" { MLIR_CAPI_EXPORTED LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, LLVMContextRef context); +struct MlirTypeFromLLVMIRTranslator { + void *ptr; +}; + +typedef struct MlirTypeFromLLVMIRTranslator MlirTypeFromLLVMIRTranslator; + +/// Create an LLVM::TypeFromLLVMIRTranslator and transfer ownership to the +/// caller. +MLIR_CAPI_EXPORTED MlirTypeFromLLVMIRTranslator +mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx); + +/// Takes an LLVM::TypeFromLLVMIRTranslator owned by the caller and destroys it. +/// It is the responsibility of the user to only pass an +/// LLVM::TypeFromLLVMIRTranslator class. +MLIR_CAPI_EXPORTED void +mlirTypeFromLLVMIRTranslatorDestroy(MlirTypeFromLLVMIRTranslator translator); + +/// Translates the given LLVM IR type to the MLIR LLVM dialect. +MLIR_CAPI_EXPORTED MlirType mlirTypeFromLLVMIRTranslatorTranslateType( + MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType); + +struct MlirTypeToLLVMIRTranslator { + void *ptr; +}; + +typedef struct MlirTypeToLLVMIRTranslator MlirTypeToLLVMIRTranslator; + +/// Create an LLVM::TypeToLLVMIRTranslator and transfer ownership to the +/// caller. +MLIR_CAPI_EXPORTED MlirTypeToLLVMIRTranslator +mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx); + +/// Takes an LLVM::TypeToLLVMIRTranslator owned by the caller and destroys it. +/// It is the responsibility of the user to only pass an +/// LLVM::TypeToLLVMIRTranslator class. +MLIR_CAPI_EXPORTED void +mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator); + +/// Translates the given MLIR LLVM dialect to the LLVM IR type. +MLIR_CAPI_EXPORTED LLVMTypeRef mlirTypeToLLVMIRTranslatorTranslateType( + MlirTypeToLLVMIRTranslator translator, MlirType mlirType); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index da450dd3fd8a3..69c804b7667f3 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -65,6 +65,10 @@ MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos) { .getParamType(static_cast(pos))); } +MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type) { + return wrap(llvm::cast(unwrap(type)).getReturnType()); +} + bool mlirTypeIsALLVMStructType(MlirType type) { return isa(unwrap(type)); } diff --git a/mlir/lib/CAPI/Target/LLVMIR.cpp b/mlir/lib/CAPI/Target/LLVMIR.cpp index dc798372be746..5e2bba8be4562 100644 --- a/mlir/lib/CAPI/Target/LLVMIR.cpp +++ b/mlir/lib/CAPI/Target/LLVMIR.cpp @@ -8,16 +8,15 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Target/LLVMIR.h" -#include "llvm-c/Support.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include +#include "llvm/IR/Type.h" #include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Target/LLVMIR/TypeFromLLVM.h" using namespace mlir; @@ -34,3 +33,47 @@ LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, return moduleRef; } + +DEFINE_C_API_PTR_METHODS(MlirTypeFromLLVMIRTranslator, + mlir::LLVM::TypeFromLLVMIRTranslator); + +MlirTypeFromLLVMIRTranslator +mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx) { + MLIRContext *context = unwrap(ctx); + auto *translator = new LLVM::TypeFromLLVMIRTranslator(*context); + return wrap(translator); +} + +void mlirTypeFromLLVMIRTranslatorDestroy( + MlirTypeFromLLVMIRTranslator translator) { + delete static_cast(unwrap(translator)); +} + +MlirType mlirTypeFromLLVMIRTranslatorTranslateType( + MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType) { + LLVM::TypeFromLLVMIRTranslator *translator_ = unwrap(translator); + mlir::Type type = translator_->translateType(llvm::unwrap(llvmType)); + return wrap(type); +} + +DEFINE_C_API_PTR_METHODS(MlirTypeToLLVMIRTranslator, + mlir::LLVM::TypeToLLVMIRTranslator); + +MlirTypeToLLVMIRTranslator +mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx) { + llvm::LLVMContext *context = llvm::unwrap(ctx); + auto *translator = new LLVM::TypeToLLVMIRTranslator(*context); + return wrap(translator); +} + +void mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator) { + delete static_cast(unwrap(translator)); +} + +LLVMTypeRef +mlirTypeToLLVMIRTranslatorTranslateType(MlirTypeToLLVMIRTranslator translator, + MlirType mlirType) { + LLVM::TypeToLLVMIRTranslator *translator_ = unwrap(translator); + llvm::Type *type = translator_->translateType(unwrap(mlirType)); + return llvm::wrap(type); +} diff --git a/mlir/test/CAPI/translation.c b/mlir/test/CAPI/translation.c index c9233d95fd512..8891c2a559b7e 100644 --- a/mlir/test/CAPI/translation.c +++ b/mlir/test/CAPI/translation.c @@ -58,11 +58,38 @@ static void testToLLVMIR(MlirContext ctx) { LLVMContextDispose(llvmCtx); } +// CHECK-LABEL: testTypeToFromLLVMIRTranslator +static void testTypeToFromLLVMIRTranslator(MlirContext ctx) { + fprintf(stderr, "testTypeToFromLLVMIRTranslator\n"); + LLVMContextRef llvmCtx = LLVMContextCreate(); + + LLVMTypeRef llvmTy = LLVMInt32TypeInContext(llvmCtx); + MlirTypeFromLLVMIRTranslator fromLLVMTranslator = + mlirTypeFromLLVMIRTranslatorCreate(ctx); + MlirType mlirTy = + mlirTypeFromLLVMIRTranslatorTranslateType(fromLLVMTranslator, llvmTy); + // CHECK: i32 + mlirTypeDump(mlirTy); + + MlirTypeToLLVMIRTranslator toLLVMTranslator = + mlirTypeToLLVMIRTranslatorCreate(llvmCtx); + LLVMTypeRef llvmTy2 = + mlirTypeToLLVMIRTranslatorTranslateType(toLLVMTranslator, mlirTy); + // CHECK: i32 + LLVMDumpType(llvmTy2); + fprintf(stderr, "\n"); + + mlirTypeFromLLVMIRTranslatorDestroy(fromLLVMTranslator); + mlirTypeToLLVMIRTranslatorDestroy(toLLVMTranslator); + LLVMContextDispose(llvmCtx); +} + int main(void) { MlirContext ctx = mlirContextCreate(); mlirDialectHandleRegisterDialect(mlirGetDialectHandle__llvm__(), ctx); mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("llvm")); testToLLVMIR(ctx); + testTypeToFromLLVMIRTranslator(ctx); mlirContextDestroy(ctx); return 0; }