diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 0992285f997ea..26c4140757c3c 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -45,6 +45,13 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, MlirType const *argumentTypes, bool isVarArg); +/// Returns the number of input types. +MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type); + +/// Returns the pos-th input type. +MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type, + intptr_t pos); + /// Returns `true` if the type is an LLVM dialect struct type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 6ed82ba1a0250..da450dd3fd8a3 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -55,6 +55,16 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg)); } +intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type) { + return llvm::cast(unwrap(type)).getNumParams(); +} + +MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos) { + assert(pos >= 0 && "pos in array must be positive"); + return wrap(llvm::cast(unwrap(type)) + .getParamType(static_cast(pos))); +} + bool mlirTypeIsALLVMStructType(MlirType type) { return isa(unwrap(type)); }