|
16 | 16 | #include "flang/Optimizer/Dialect/FIRType.h" |
17 | 17 | #include "flang/Runtime/main.h" |
18 | 18 | #include "flang/Runtime/stop.h" |
| 19 | +#ifdef FLANG_CUDA_SUPPORT |
| 20 | +#include "flang/Runtime/CUDA/init.h" |
| 21 | +#endif |
19 | 22 |
|
20 | 23 | using namespace Fortran::runtime; |
21 | 24 |
|
22 | 25 | /// Create a `int main(...)` that calls the Fortran entry point |
23 | 26 | void fir::runtime::genMain( |
24 | 27 | fir::FirOpBuilder &builder, mlir::Location loc, |
25 | | - const std::vector<Fortran::lower::EnvironmentDefault> &defs) { |
| 28 | + const std::vector<Fortran::lower::EnvironmentDefault> &defs, |
| 29 | + bool initCuda) { |
26 | 30 | auto *context = builder.getContext(); |
27 | 31 | auto argcTy = builder.getDefaultIntegerType(); |
28 | 32 | auto ptrTy = mlir::LLVM::LLVMPointerType::get(context); |
@@ -61,6 +65,15 @@ void fir::runtime::genMain( |
61 | 65 | args.push_back(env); |
62 | 66 |
|
63 | 67 | builder.create<fir::CallOp>(loc, startFn, args); |
| 68 | + |
| 69 | +#ifdef FLANG_CUDA_SUPPORT |
| 70 | + if (initCuda) { |
| 71 | + auto initFn = builder.createFunction( |
| 72 | + loc, RTNAME_STRING(CUFInit), mlir::FunctionType::get(context, {}, {})); |
| 73 | + builder.create<fir::CallOp>(loc, initFn); |
| 74 | + } |
| 75 | +#endif |
| 76 | + |
64 | 77 | builder.create<fir::CallOp>(loc, qqMainFn); |
65 | 78 | builder.create<fir::CallOp>(loc, stopFn); |
66 | 79 |
|
|
0 commit comments