Skip to content

Commit e8ce925

Browse files
[SYCL-RTC] Use ToolInvocation directly instead of via ClangTool (intel#19922)
We only run SYCL JIT on a single TU at a time, so using `ClangTool` is a bit awkward, as it is primarily meant to run the same action across a set of files: https://github.com/intel/llvm/blob/357f96b7e19d8acb972eb2f1fb276dbc6aa2060b/clang/include/clang/Tooling/Tooling.h#L310-L317 Using `ToolInvocation` better matches our scenario of always doing a single clang invocation: https://github.com/intel/llvm/blob/357f96b7e19d8acb972eb2f1fb276dbc6aa2060b/clang/include/clang/Tooling/Tooling.h#L244-L245 Another benefit is that we get more control over the virtual file system. I'm planning to use that in a subsequent PR to distribute the SYCL toolchain headers inside `libsycl-jit.so` and load them into an `llvm::vfs::InMemoryFileSystem` once, so that they can be re-used across all compilation queries. I'm also simplifying the inheritance scheme around `clang::ToolAction`. Instead of having both the hashing and the compiling paths define their own `ToolAction` subclass, I'm providing a single helper that accepts a reference to a `FrontendAction`, which can be kept on the caller's stack. This reduces the amount of boilerplate helpers necessary, i.e. `RTCToolActionBase`/`GetSourceHashAction`/`GetLLVMModuleAction` before vs. a single `SYCLToolchain::Action` after.
1 parent 182779d commit e8ce925

File tree

1 file changed

+102
-138
lines changed

1 file changed

+102
-138
lines changed

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 102 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -177,101 +177,76 @@ class HashPreprocessedAction : public PreprocessorFrontendAction {
177177
bool Executed = false;
178178
};
179179

180-
class RTCToolActionBase : public ToolAction {
181-
public:
182-
// Code adapted from `FrontendActionFactory::runInvocation`.
183-
bool runInvocation(std::shared_ptr<CompilerInvocation> Invocation,
184-
FileManager *Files,
185-
std::shared_ptr<PCHContainerOperations> PCHContainerOps,
186-
DiagnosticConsumer *DiagConsumer) override {
187-
assert(!hasExecuted() && "Action should only be invoked on a single file");
188-
189-
// Create a compiler instance to handle the actual work.
190-
CompilerInstance Compiler(std::move(Invocation), std::move(PCHContainerOps));
191-
Compiler.setFileManager(Files);
192-
// Suppress summary with number of warnings and errors being printed to
193-
// stdout.
194-
Compiler.setVerboseOutputStream(std::make_unique<llvm::raw_null_ostream>());
195-
196-
// Create the compiler's actual diagnostics engine.
197-
Compiler.createDiagnostics(Files->getVirtualFileSystem(), DiagConsumer,
198-
/*ShouldOwnClient=*/false);
199-
if (!Compiler.hasDiagnostics()) {
200-
return false;
201-
}
202-
203-
Compiler.createSourceManager(*Files);
204-
205-
return executeAction(Compiler, Files);
206-
}
207-
208-
virtual ~RTCToolActionBase() = default;
209-
210-
protected:
211-
virtual bool hasExecuted() = 0;
212-
virtual bool executeAction(CompilerInstance &, FileManager *) = 0;
213-
};
214-
215-
class GetSourceHashAction : public RTCToolActionBase {
216-
protected:
217-
bool executeAction(CompilerInstance &CI, FileManager *Files) override {
218-
HashPreprocessedAction HPA;
219-
const bool Success = CI.ExecuteAction(HPA);
220-
Files->clearStatCache();
221-
if (!Success) {
222-
return false;
180+
class SYCLToolchain {
181+
SYCLToolchain() {}
182+
183+
// Similar to FrontendActionFactory, but we don't take ownership of
184+
// `FrontendAction`, nor do we create copies of it as we only perform a single
185+
// `ToolInvocation`.
186+
class Action : public ToolAction {
187+
FrontendAction &FEAction;
188+
189+
public:
190+
Action(FrontendAction &FEAction) : FEAction(FEAction) {}
191+
~Action() override = default;
192+
193+
// Code adapted from `FrontendActionFactory::runInvocation`:
194+
bool runInvocation(std::shared_ptr<CompilerInvocation> Invocation,
195+
FileManager *Files,
196+
std::shared_ptr<PCHContainerOperations> PCHContainerOps,
197+
DiagnosticConsumer *DiagConsumer) override {
198+
// Create a compiler instance to handle the actual work.
199+
CompilerInstance Compiler(std::move(Invocation),
200+
std::move(PCHContainerOps));
201+
Compiler.setFileManager(Files);
202+
// Suppress summary with number of warnings and errors being printed to
203+
// stdout.
204+
Compiler.setVerboseOutputStream(
205+
std::make_unique<llvm::raw_null_ostream>());
206+
207+
// Create the compiler's actual diagnostics engine.
208+
Compiler.createDiagnostics(Files->getVirtualFileSystem(), DiagConsumer,
209+
/*ShouldOwnClient=*/false);
210+
if (!Compiler.hasDiagnostics())
211+
return false;
212+
213+
Compiler.createSourceManager(*Files);
214+
215+
const bool Success = Compiler.ExecuteAction(FEAction);
216+
217+
Files->clearStatCache();
218+
return Success;
223219
}
224-
225-
Hash = HPA.takeHash();
226-
Executed = true;
227-
return true;
228-
}
229-
230-
bool hasExecuted() override { return Executed; }
220+
};
231221

232222
public:
233-
BLAKE3Result<> takeHash() {
234-
assert(Executed);
235-
Executed = false;
236-
return std::move(Hash);
223+
static SYCLToolchain &instance() {
224+
static SYCLToolchain Instance;
225+
return Instance;
237226
}
238227

239-
private:
240-
BLAKE3Result<> Hash;
241-
bool Executed = false;
242-
};
228+
bool run(const std::vector<std::string> &CommandLine,
229+
FrontendAction &FEAction,
230+
IntrusiveRefCntPtr<FileSystem> FSOverlay = nullptr,
231+
DiagnosticConsumer *DiagConsumer = nullptr) {
232+
auto FS = llvm::makeIntrusiveRefCnt<llvm::vfs::OverlayFileSystem>(
233+
llvm::vfs::getRealFileSystem());
234+
if (FSOverlay)
235+
FS->pushOverlay(FSOverlay);
243236

244-
struct GetLLVMModuleAction : public RTCToolActionBase {
245-
protected:
246-
bool executeAction(CompilerInstance &CI, FileManager *Files) override {
247-
// Ignore `Compiler.getFrontendOpts().ProgramAction` (would be `EmitBC`) and
248-
// create/execute an `EmitLLVMOnlyAction` (= codegen to LLVM module without
249-
// emitting anything) instead.
250-
EmitLLVMOnlyAction ELOA{&Context};
251-
const bool Success = CI.ExecuteAction(ELOA);
252-
Files->clearStatCache();
253-
if (!Success) {
254-
return false;
255-
}
237+
auto Files = llvm::makeIntrusiveRefCnt<clang::FileManager>(
238+
clang::FileSystemOptions{"." /* WorkingDir */}, FS);
256239

257-
// Take the module to extend its lifetime.
258-
Module = ELOA.takeModule();
240+
Action A{FEAction};
241+
ToolInvocation TI{CommandLine, &A, Files.get(),
242+
std::make_shared<PCHContainerOperations>()};
243+
TI.setDiagnosticConsumer(DiagConsumer ? DiagConsumer : &IgnoreDiag);
259244

260-
return true;
261-
}
262-
263-
bool hasExecuted() override { return static_cast<bool>(Module); }
264-
265-
public:
266-
GetLLVMModuleAction(LLVMContext &Context) : Context{Context}, Module{} {}
267-
ModuleUPtr takeModule() {
268-
assert(Module);
269-
return std::move(Module);
245+
return TI.run();
270246
}
271247

272248
private:
273-
LLVMContext &Context;
274-
ModuleUPtr Module;
249+
clang::IgnoringDiagConsumer IgnoreDiag;
275250
};
276251

277252
class ClangDiagnosticWrapper {
@@ -320,9 +295,9 @@ class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
320295

321296
} // anonymous namespace
322297

323-
static void adjustArgs(const InputArgList &UserArgList,
324-
const std::string &DPCPPRoot, BinaryFormat Format,
325-
SmallVectorImpl<std::string> &CommandLine) {
298+
static std::vector<std::string>
299+
createCommandLine(const InputArgList &UserArgList, std::string_view DPCPPRoot,
300+
BinaryFormat Format, std::string_view SourceFilePath) {
326301
DerivedArgList DAL{UserArgList};
327302
const auto &OptTable = getDriverOptTable();
328303
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_fsycl_device_only));
@@ -349,36 +324,30 @@ static void adjustArgs(const InputArgList &UserArgList,
349324
for_each(DAL, [&DAL, &ASL](Arg *A) { A->render(DAL, ASL); });
350325
for_each(UserArgList,
351326
[&UserArgList, &ASL](Arg *A) { A->render(UserArgList, ASL); });
327+
328+
std::vector<std::string> CommandLine;
329+
CommandLine.reserve(ASL.size() + 2);
330+
CommandLine.emplace_back((DPCPPRoot + "/bin/clang++").str());
352331
transform(ASL, std::back_inserter(CommandLine),
353332
[](const char *AS) { return std::string{AS}; });
333+
CommandLine.emplace_back(SourceFilePath);
334+
return CommandLine;
354335
}
355336

356-
static void setupTool(ClangTool &Tool, const std::string &DPCPPRoot,
357-
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
358-
DiagnosticConsumer *Consumer) {
359-
Tool.setDiagnosticConsumer(Consumer);
360-
// Suppress message "Error while processing" being printed to stdout.
361-
Tool.setPrintErrorMessage(false);
362-
363-
// Set up in-memory filesystem.
364-
Tool.mapVirtualFile(SourceFile.Path, SourceFile.Contents);
365-
for (const auto &IF : IncludeFiles) {
366-
Tool.mapVirtualFile(IF.Path, IF.Contents);
367-
}
337+
static llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem>
338+
getInMemoryFS(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles) {
339+
auto InMemoryFS = llvm::makeIntrusiveRefCnt<llvm::vfs::InMemoryFileSystem>();
340+
341+
InMemoryFS->setCurrentWorkingDirectory(
342+
*llvm::vfs::getRealFileSystem()->getCurrentWorkingDirectory());
368343

369-
// Reset argument adjusters to drop the `-fsyntax-only` flag which is added by
370-
// default by this API.
371-
Tool.clearArgumentsAdjusters();
372-
// Then, modify argv[0] so that the driver picks up the correct SYCL
373-
// environment. We've already set the resource directory above.
374-
Tool.appendArgumentsAdjuster(
375-
[&DPCPPRoot](const CommandLineArguments &Args,
376-
StringRef Filename) -> CommandLineArguments {
377-
(void)Filename;
378-
CommandLineArguments NewArgs = Args;
379-
NewArgs[0] = (Twine(DPCPPRoot) + "/bin/clang++").str();
380-
return NewArgs;
381-
});
344+
InMemoryFS->addFile(SourceFile.Path, 0,
345+
llvm::MemoryBuffer::getMemBuffer(SourceFile.Contents));
346+
for (InMemoryFile F : IncludeFiles)
347+
InMemoryFS->addFile(F.Path, 0,
348+
llvm::MemoryBuffer::getMemBuffer(F.Contents));
349+
350+
return InMemoryFS;
382351
}
383352

384353
Expected<std::string> jit_compiler::calculateHash(
@@ -391,20 +360,20 @@ Expected<std::string> jit_compiler::calculateHash(
391360
return createStringError("Could not locate DPCPP root directory");
392361
}
393362

394-
SmallVector<std::string> CommandLine;
395-
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
363+
std::vector<std::string> CommandLine =
364+
createCommandLine(UserArgList, DPCPPRoot, Format, SourceFile.Path);
396365

397-
FixedCompilationDatabase DB{".", CommandLine};
398-
ClangTool Tool{DB, {SourceFile.Path}};
366+
HashPreprocessedAction HashAction;
399367

400-
clang::IgnoringDiagConsumer DiagConsumer;
401-
setupTool(Tool, DPCPPRoot, SourceFile, IncludeFiles, &DiagConsumer);
368+
if (SYCLToolchain::instance().run(CommandLine, HashAction,
369+
getInMemoryFS(SourceFile, IncludeFiles))) {
370+
BLAKE3Result<> SourceHash = HashAction.takeHash();
371+
// Last argument is the source file in the format `rtc_N.cpp` which is
372+
// unique for each query, so drop it:
373+
CommandLine.pop_back();
402374

403-
GetSourceHashAction Action;
404-
if (!Tool.run(&Action)) {
405-
BLAKE3Result<> SourceHash = Action.takeHash();
406-
// The adjusted command line contains the DPCPP root and clang major
407-
// version.
375+
// The command line contains the DPCPP root and clang major version in
376+
// "-resource-dir=<...>" argument.
408377
BLAKE3Result<> CommandLineHash =
409378
BLAKE3::hash(arrayRefFromStringRef(join(CommandLine, ",")));
410379

@@ -413,9 +382,10 @@ Expected<std::string> jit_compiler::calculateHash(
413382
// Make the encoding filesystem-friendly.
414383
std::replace(EncodedHash.begin(), EncodedHash.end(), '/', '-');
415384
return std::move(EncodedHash);
416-
}
417385

418-
return createStringError("Calculating source hash failed");
386+
} else {
387+
return createStringError("Calculating source hash failed");
388+
}
419389
}
420390

421391
Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
@@ -429,23 +399,17 @@ Expected<ModuleUPtr> jit_compiler::compileDeviceCode(
429399
return createStringError("Could not locate DPCPP root directory");
430400
}
431401

432-
SmallVector<std::string> CommandLine;
433-
adjustArgs(UserArgList, DPCPPRoot, Format, CommandLine);
434-
435-
FixedCompilationDatabase DB{".", CommandLine};
436-
ClangTool Tool{DB, {SourceFile.Path}};
437-
402+
EmitLLVMOnlyAction ELOA{&Context};
438403
DiagnosticOptions DiagOpts;
439404
ClangDiagnosticWrapper Wrapper(BuildLog, &DiagOpts);
440405

441-
setupTool(Tool, DPCPPRoot, SourceFile, IncludeFiles, Wrapper.consumer());
442-
443-
GetLLVMModuleAction Action{Context};
444-
if (!Tool.run(&Action)) {
445-
return Action.takeModule();
406+
if (SYCLToolchain::instance().run(
407+
createCommandLine(UserArgList, DPCPPRoot, Format, SourceFile.Path),
408+
ELOA, getInMemoryFS(SourceFile, IncludeFiles), Wrapper.consumer())) {
409+
return ELOA.takeModule();
410+
} else {
411+
return createStringError(BuildLog);
446412
}
447-
448-
return createStringError(BuildLog);
449413
}
450414

451415
// This function is a simplified copy of the device library selection process

0 commit comments

Comments
 (0)