Skip to content

Commit 00fe093

Browse files
committed
Use templated error wrapper directly
Signed-off-by: Julian Oppermann <[email protected]>
1 parent e93076c commit 00fe093

File tree

1 file changed

+19
-28
lines changed

1 file changed

+19
-28
lines changed

sycl-jit/jit-compiler/lib/KernelFusion.cpp

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using FusedFunction = helper::FusionHelper::FusedFunction;
2626
using FusedFunctionList = std::vector<FusedFunction>;
2727

2828
template <typename ResultType>
29-
static ResultType wrapError(llvm::Error &&Err, const std::string &Msg) {
29+
static ResultType errorTo(llvm::Error &&Err, const std::string &Msg) {
3030
std::stringstream ErrMsg;
3131
ErrMsg << Msg << "\nDetailed information:\n";
3232
llvm::handleAllErrors(std::move(Err),
@@ -38,15 +38,6 @@ static ResultType wrapError(llvm::Error &&Err, const std::string &Msg) {
3838
return ResultType{ErrMsg.str().c_str()};
3939
}
4040

41-
static JITResult errorToFusionResult(llvm::Error &&Err,
42-
const std::string &Msg) {
43-
return wrapError<JITResult>(std::move(Err), Msg);
44-
}
45-
46-
static RTCResult errorToRTCResult(llvm::Error &&Err, const std::string &Msg) {
47-
return wrapError<RTCResult>(std::move(Err), Msg);
48-
}
49-
5041
static std::vector<jit_compiler::NDRange>
5142
gatherNDRanges(llvm::ArrayRef<SYCLKernelInfo> KernelInformation) {
5243
std::vector<jit_compiler::NDRange> NDRanges;
@@ -104,7 +95,7 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
10495
translation::KernelTranslator::loadKernels(*JITCtx.getLLVMContext(),
10596
ModuleInfo.kernels());
10697
if (auto Error = ModOrError.takeError()) {
107-
return errorToFusionResult(std::move(Error), "Failed to load kernels");
98+
return errorTo<JITResult>(std::move(Error), "Failed to load kernels");
10899
}
109100
std::unique_ptr<llvm::Module> NewMod = std::move(*ModOrError);
110101
if (!fusion::FusionPipeline::runMaterializerPasses(
@@ -116,8 +107,8 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
116107
SYCLKernelInfo &MaterializerKernelInfo = *ModuleInfo.getKernelFor(KernelName);
117108
if (auto Error = translation::KernelTranslator::translateKernel(
118109
MaterializerKernelInfo, *NewMod, JITCtx, TargetFormat)) {
119-
return errorToFusionResult(std::move(Error),
120-
"Translation to output format failed");
110+
return errorTo<JITResult>(std::move(Error),
111+
"Translation to output format failed");
121112
}
122113

123114
return JITResult{MaterializerKernelInfo};
@@ -142,7 +133,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
142133
llvm::Expected<jit_compiler::FusedNDRange> FusedNDR =
143134
jit_compiler::FusedNDRange::get(NDRanges);
144135
if (llvm::Error Err = FusedNDR.takeError()) {
145-
return errorToFusionResult(std::move(Err), "Illegal ND-range combination");
136+
return errorTo<JITResult>(std::move(Err), "Illegal ND-range combination");
146137
}
147138

148139
if (!isTargetFormatSupported(TargetFormat)) {
@@ -189,7 +180,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
189180
translation::KernelTranslator::loadKernels(*JITCtx.getLLVMContext(),
190181
ModuleInfo.kernels());
191182
if (auto Error = ModOrError.takeError()) {
192-
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
183+
return errorTo<JITResult>(std::move(Error), "SPIR-V translation failed");
193184
}
194185
std::unique_ptr<llvm::Module> LLVMMod = std::move(*ModOrError);
195186

@@ -206,8 +197,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
206197
llvm::Expected<std::unique_ptr<llvm::Module>> NewModOrError =
207198
helper::FusionHelper::addFusedKernel(LLVMMod.get(), FusedKernelList);
208199
if (auto Error = NewModOrError.takeError()) {
209-
return errorToFusionResult(std::move(Error),
210-
"Insertion of fused kernel stub failed");
200+
return errorTo<JITResult>(std::move(Error),
201+
"Insertion of fused kernel stub failed");
211202
}
212203
std::unique_ptr<llvm::Module> NewMod = std::move(*NewModOrError);
213204

@@ -230,8 +221,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
230221

231222
if (auto Error = translation::KernelTranslator::translateKernel(
232223
FusedKernelInfo, *NewMod, JITCtx, TargetFormat)) {
233-
return errorToFusionResult(std::move(Error),
234-
"Translation to output format failed");
224+
return errorTo<JITResult>(std::move(Error),
225+
"Translation to output format failed");
235226
}
236227

237228
FusedKernelInfo.NDR = FusedNDR->getNDR();
@@ -248,38 +239,38 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
248239
View<const char *> UserArgs) {
249240
auto UserArgListOrErr = parseUserArgs(UserArgs);
250241
if (!UserArgListOrErr) {
251-
return errorToRTCResult(UserArgListOrErr.takeError(),
252-
"Parsing of user arguments failed");
242+
return errorTo<RTCResult>(UserArgListOrErr.takeError(),
243+
"Parsing of user arguments failed");
253244
}
254245
llvm::opt::InputArgList UserArgList = std::move(*UserArgListOrErr);
255246

256247
auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgList);
257248
if (!ModuleOrErr) {
258-
return errorToRTCResult(ModuleOrErr.takeError(),
259-
"Device compilation failed");
249+
return errorTo<RTCResult>(ModuleOrErr.takeError(),
250+
"Device compilation failed");
260251
}
261252

262253
std::unique_ptr<llvm::LLVMContext> Context;
263254
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);
264255
Context.reset(&Module->getContext());
265256

266257
if (auto Error = linkDeviceLibraries(*Module, UserArgList)) {
267-
return errorToRTCResult(std::move(Error), "Device linking failed");
258+
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
268259
}
269260

270261
auto BundleInfoOrError = performPostLink(*Module, UserArgList);
271262
if (!BundleInfoOrError) {
272-
return errorToRTCResult(BundleInfoOrError.takeError(),
273-
"Post-link phase failed");
263+
return errorTo<RTCResult>(BundleInfoOrError.takeError(),
264+
"Post-link phase failed");
274265
}
275266
auto BundleInfo = std::move(*BundleInfoOrError);
276267

277268
auto BinaryInfoOrError =
278269
translation::KernelTranslator::translateBundleToSPIRV(
279270
*Module, JITContext::getInstance());
280271
if (!BinaryInfoOrError) {
281-
return errorToRTCResult(BinaryInfoOrError.takeError(),
282-
"SPIR-V translation failed");
272+
return errorTo<RTCResult>(BinaryInfoOrError.takeError(),
273+
"SPIR-V translation failed");
283274
}
284275
BundleInfo.BinaryInfo = std::move(*BinaryInfoOrError);
285276

0 commit comments

Comments
 (0)