@@ -25,8 +25,8 @@ using namespace jit_compiler;
2525using FusedFunction = helper::FusionHelper::FusedFunction;
2626using FusedFunctionList = std::vector<FusedFunction>;
2727
28- static JITResult errorToFusionResult (llvm::Error &&Err,
29- const std::string &Msg) {
28+ template < typename ResultType>
29+ static ResultType errorTo (llvm::Error &&Err, const std::string &Msg) {
3030 std::stringstream ErrMsg;
3131 ErrMsg << Msg << " \n Detailed information:\n " ;
3232 llvm::handleAllErrors (std::move (Err),
@@ -35,7 +35,7 @@ static JITResult errorToFusionResult(llvm::Error &&Err,
3535 // compiled without exception support.
3636 ErrMsg << " \t " << StrErr.getMessage () << " \n " ;
3737 });
38- return JITResult {ErrMsg.str ().c_str ()};
38+ return ResultType {ErrMsg.str ().c_str ()};
3939}
4040
4141static std::vector<jit_compiler::NDRange>
@@ -95,7 +95,7 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
9595 translation::KernelTranslator::loadKernels (*JITCtx.getLLVMContext (),
9696 ModuleInfo.kernels ());
9797 if (auto Error = ModOrError.takeError ()) {
98- return errorToFusionResult (std::move (Error), " Failed to load kernels" );
98+ return errorTo<JITResult> (std::move (Error), " Failed to load kernels" );
9999 }
100100 std::unique_ptr<llvm::Module> NewMod = std::move (*ModOrError);
101101 if (!fusion::FusionPipeline::runMaterializerPasses (
@@ -107,8 +107,8 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
107107 SYCLKernelInfo &MaterializerKernelInfo = *ModuleInfo.getKernelFor (KernelName);
108108 if (auto Error = translation::KernelTranslator::translateKernel (
109109 MaterializerKernelInfo, *NewMod, JITCtx, TargetFormat)) {
110- return errorToFusionResult (std::move (Error),
111- " Translation to output format failed" );
110+ return errorTo<JITResult> (std::move (Error),
111+ " Translation to output format failed" );
112112 }
113113
114114 return JITResult{MaterializerKernelInfo};
@@ -133,7 +133,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
133133 llvm::Expected<jit_compiler::FusedNDRange> FusedNDR =
134134 jit_compiler::FusedNDRange::get (NDRanges);
135135 if (llvm::Error Err = FusedNDR.takeError ()) {
136- return errorToFusionResult (std::move (Err), " Illegal ND-range combination" );
136+ return errorTo<JITResult> (std::move (Err), " Illegal ND-range combination" );
137137 }
138138
139139 if (!isTargetFormatSupported (TargetFormat)) {
@@ -180,7 +180,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
180180 translation::KernelTranslator::loadKernels (*JITCtx.getLLVMContext (),
181181 ModuleInfo.kernels ());
182182 if (auto Error = ModOrError.takeError ()) {
183- return errorToFusionResult (std::move (Error), " SPIR-V translation failed" );
183+ return errorTo<JITResult> (std::move (Error), " SPIR-V translation failed" );
184184 }
185185 std::unique_ptr<llvm::Module> LLVMMod = std::move (*ModOrError);
186186
@@ -197,8 +197,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
197197 llvm::Expected<std::unique_ptr<llvm::Module>> NewModOrError =
198198 helper::FusionHelper::addFusedKernel (LLVMMod.get (), FusedKernelList);
199199 if (auto Error = NewModOrError.takeError ()) {
200- return errorToFusionResult (std::move (Error),
201- " Insertion of fused kernel stub failed" );
200+ return errorTo<JITResult> (std::move (Error),
201+ " Insertion of fused kernel stub failed" );
202202 }
203203 std::unique_ptr<llvm::Module> NewMod = std::move (*NewModOrError);
204204
@@ -221,8 +221,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
221221
222222 if (auto Error = translation::KernelTranslator::translateKernel (
223223 FusedKernelInfo, *NewMod, JITCtx, TargetFormat)) {
224- return errorToFusionResult (std::move (Error),
225- " Translation to output format failed" );
224+ return errorTo<JITResult> (std::move (Error),
225+ " Translation to output format failed" );
226226 }
227227
228228 FusedKernelInfo.NDR = FusedNDR->getNDR ();
@@ -234,37 +234,47 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
234234 return JITResult{FusedKernelInfo};
235235}
236236
237- extern " C" KF_EXPORT_SYMBOL JITResult
237+ extern " C" KF_EXPORT_SYMBOL RTCResult
238238compileSYCL (InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
239239 View<const char *> UserArgs) {
240240 auto UserArgListOrErr = parseUserArgs (UserArgs);
241241 if (!UserArgListOrErr) {
242- return errorToFusionResult (UserArgListOrErr.takeError (),
243- " Parsing of user arguments failed" );
242+ return errorTo<RTCResult> (UserArgListOrErr.takeError (),
243+ " Parsing of user arguments failed" );
244244 }
245245 llvm::opt::InputArgList UserArgList = std::move (*UserArgListOrErr);
246246
247247 auto ModuleOrErr = compileDeviceCode (SourceFile, IncludeFiles, UserArgList);
248248 if (!ModuleOrErr) {
249- return errorToFusionResult (ModuleOrErr.takeError (),
250- " Device compilation failed" );
249+ return errorTo<RTCResult> (ModuleOrErr.takeError (),
250+ " Device compilation failed" );
251251 }
252252
253253 std::unique_ptr<llvm::LLVMContext> Context;
254254 std::unique_ptr<llvm::Module> Module = std::move (*ModuleOrErr);
255255 Context.reset (&Module->getContext ());
256256
257257 if (auto Error = linkDeviceLibraries (*Module, UserArgList)) {
258- return errorToFusionResult (std::move (Error), " Device linking failed" );
258+ return errorTo<RTCResult> (std::move (Error), " Device linking failed" );
259259 }
260260
261- SYCLKernelInfo Kernel;
262- if (auto Error = translation::KernelTranslator::translateKernel (
263- Kernel, *Module, JITContext::getInstance (), BinaryFormat::SPIRV)) {
264- return errorToFusionResult (std::move (Error), " SPIR-V translation failed" );
261+ auto BundleInfoOrError = performPostLink (*Module, UserArgList);
262+ if (!BundleInfoOrError) {
263+ return errorTo<RTCResult>(BundleInfoOrError.takeError (),
264+ " Post-link phase failed" );
265+ }
266+ auto BundleInfo = std::move (*BundleInfoOrError);
267+
268+ auto BinaryInfoOrError =
269+ translation::KernelTranslator::translateBundleToSPIRV (
270+ *Module, JITContext::getInstance ());
271+ if (!BinaryInfoOrError) {
272+ return errorTo<RTCResult>(BinaryInfoOrError.takeError (),
273+ " SPIR-V translation failed" );
265274 }
275+ BundleInfo.BinaryInfo = std::move (*BinaryInfoOrError);
266276
267- return JITResult{Kernel };
277+ return RTCResult{ std::move (BundleInfo) };
268278}
269279
270280extern " C" KF_EXPORT_SYMBOL void resetJITConfiguration () {
0 commit comments