Skip to content

Commit 01837c4

Browse files
allightcopybara-github
authored andcommitted
Allow aot_compiler_main to avoid compiling anything if only the protobuf descriptor is requested.
This lets us speed up compilation somewhat and enable LSPs to build jit wrappers without having to fully LLVM optimize/codegen the module. PiperOrigin-RevId: 860324414
1 parent d3e2acd commit 01837c4

File tree

10 files changed

+135
-36
lines changed

10 files changed

+135
-36
lines changed

xls/build_rules/xls_internal_aot_rules.bzl

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,31 @@ def _xls_aot_generate_impl(ctx):
102102
proto_file = ctx.actions.declare_file(out_proto_filename)
103103
obj_file = ctx.actions.declare_file(out_obj_filename)
104104
args = ctx.actions.args()
105-
args.add("-input", src.ir_file.path)
105+
skeleton_args = ctx.actions.args()
106+
107+
def common_add(*va, **kwargs):
108+
args.add(*va, **kwargs)
109+
skeleton_args.add(*va, **kwargs)
110+
111+
common_add("-input", src.ir_file.path)
106112
if (ctx.attr.salt_symbols):
107-
args.add("-symbol_salt", str(ctx.label))
108-
args.add("-top", ctx.attr.top)
113+
common_add("-symbol_salt", str(ctx.label))
114+
common_add("-aot_target", ctx.attr.aot_target)
115+
common_add("-top", ctx.attr.top)
116+
other_linking_contexts = []
117+
if ctx.attr.with_msan:
118+
common_add("--include_msan=true")
119+
120+
# With msan we need the TLS implementation.
121+
other_linking_contexts = [ctx.attr._jit_emulated_tls[CcInfo].linking_context]
122+
else:
123+
common_add("--include_msan=false")
124+
109125
args.add("-output_object", obj_file.path)
110-
args.add("-output_proto", proto_file.path)
111126
args.add("-llvm_opt_level", ctx.attr.llvm_opt_level)
112-
args.add("--aot_target", ctx.attr.aot_target)
127+
128+
skeleton_args.add("-output_proto", proto_file.path)
129+
113130
extra_files = []
114131
aot_direct_request = ctx.attr._emit_aot_intermediates[BuildSettingInfo].value
115132
save_temps_reqest = ctx.attr._save_temps_is_requested[BoolConfigSettingInfo].value
@@ -124,23 +141,28 @@ def _xls_aot_generate_impl(ctx):
124141
args.add("-output_llvm_opt_ir", llvm_opt_ir_file.path)
125142
args.add("-output_asm", asm_file.path)
126143

127-
other_linking_contexts = []
128-
if ctx.attr.with_msan:
129-
args.add("--include_msan=true")
130-
131-
# With msan we need the TLS implementation.
132-
other_linking_contexts = [ctx.attr._jit_emulated_tls[CcInfo].linking_context]
133-
else:
134-
args.add("--include_msan=false")
144+
# Non-skeleton run to create the object file.
135145
ctx.actions.run(
136-
outputs = [proto_file, obj_file] + extra_files,
146+
outputs = [obj_file] + extra_files,
137147
inputs = [src.ir_file],
138148
arguments = [args],
139149
executable = aot_compiler,
140150
mnemonic = "AOTCompiling",
141151
progress_message = "Aot(JIT)Compiling %{label}: %{input}",
142152
toolchain = None,
143153
)
154+
155+
# Skeleton run to create the proto file.
156+
ctx.actions.run(
157+
outputs = [proto_file],
158+
inputs = [src.ir_file],
159+
arguments = [skeleton_args],
160+
executable = aot_compiler,
161+
mnemonic = "AotSkeleton",
162+
progress_message = "Generating AOT skeleton %{label}: %{input}",
163+
toolchain = None,
164+
)
165+
144166
obj_file_outputs = cc_common.create_compilation_outputs(
145167
objects = depset([obj_file]),
146168
pic_objects = depset([obj_file]),

xls/jit/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ cc_binary(
122122
"//xls/common/status:ret_check",
123123
"//xls/common/status:status_macros",
124124
"//xls/dev_tools:extract_interface",
125+
"//xls/interpreter:evaluator_options",
125126
"//xls/ir",
126127
"//xls/ir:block_elaboration",
127128
"//xls/ir:ir_parser",
@@ -748,6 +749,7 @@ cc_library(
748749
":jit_evaluator_options",
749750
":llvm_compiler",
750751
":observer",
752+
"//xls/common/status:ret_check",
751753
"//xls/common/status:status_macros",
752754
"@com_google_absl//absl/flags:flag",
753755
"@com_google_absl//absl/status",

xls/jit/aot_compiler.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414

1515
#include "xls/jit/aot_compiler.h"
1616

17+
#include <cstdint>
1718
#include <memory>
1819
#include <string>
1920
#include <utility>
21+
#include <vector>
2022

2123
#include "absl/flags/flag.h"
2224
#include "absl/status/status.h"
25+
#include "absl/status/statusor.h"
2326
#include "absl/strings/str_cat.h"
2427
#include "llvm/include/llvm/ADT/SmallVector.h"
2528
#include "llvm/include/llvm/ADT/StringRef.h"
@@ -46,6 +49,7 @@
4649
#include "llvm/include/llvm/TargetParser/Triple.h"
4750
#include "llvm/include/llvm/TargetParser/X86TargetParser.h"
4851
#include "llvm/include/llvm/Transforms/Utils/Cloning.h"
52+
#include "xls/common/status/ret_check.h"
4953
#include "xls/common/status/status_macros.h"
5054
#include "xls/jit/jit_emulated_tls.h"
5155
#include "xls/jit/jit_evaluator_options.h"
@@ -201,6 +205,16 @@ absl::Status AddWeakEmuTls(llvm::Module& module, llvm::LLVMContext* context) {
201205
absl::Status AotCompiler::CompileModule(
202206
std::unique_ptr<llvm::Module>&& module) {
203207
JitObserverRequests notification;
208+
if (jit_options_.generate_skeleton()) {
209+
// No need to actually compile anything.
210+
object_code_.emplace();
211+
XLS_RET_CHECK(jit_options_.jit_observer() == nullptr ||
212+
!jit_options_.jit_observer()
213+
->GetNotificationOptions()
214+
.has_any_requests())
215+
<< "skeleton not compatible with observers";
216+
return absl::OkStatus();
217+
}
204218
if (jit_options_.jit_observer() != nullptr) {
205219
notification = jit_options_.jit_observer()->GetNotificationOptions();
206220
}

xls/jit/aot_compiler.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class AotCompiler final : public LlvmCompiler {
6262
return *object_code_;
6363
}
6464

65+
bool is_skeleton() const override { return jit_options_.generate_skeleton(); }
66+
6567
protected:
6668
absl::Status InitInternal() override;
6769

xls/jit/aot_compiler_main.cc

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "xls/common/status/ret_check.h"
4040
#include "xls/common/status/status_macros.h"
4141
#include "xls/dev_tools/extract_interface.h"
42+
#include "xls/interpreter/evaluator_options.h"
4243
#include "xls/ir/block_elaboration.h"
4344
#include "xls/ir/function.h"
4445
#include "xls/ir/function_base.h"
@@ -59,6 +60,18 @@
5960
#include "xls/jit/type_buffer_metadata.h"
6061
#include "xls/jit/type_layout.pb.h"
6162

63+
static constexpr std::string_view kUsage = R"(
64+
aot_compiler <flags>
65+
66+
Compile the given IR file to an object file and a proto file.
67+
68+
If --output_proto/--output_textproto is given and all of --output_object,
69+
--output_llvm_ir, --output_llvm_opt_ir, and --output_asm are not given, the
70+
compilation will not be performed and only the proto will be generated. This is
71+
significantly faster than performing the compilation and can be used to quickly
72+
check the output of the compiler.
73+
)";
74+
6275
ABSL_FLAG(std::string, input, "", "Path to the IR to compile.");
6376
ABSL_FLAG(std::string, symbol_salt, "",
6477
"Additional text to append to symbol names to ensure no collisions.");
@@ -264,11 +277,14 @@ absl::Status RealMain(const std::string& input_ir_path,
264277
}
265278

266279
std::optional<JitObjectCode> object_code;
280+
bool generate_skeleton = !output_object_path && !output_llvm_ir_path &&
281+
!output_llvm_opt_ir_path && !output_asm_path;
267282
JitEvaluatorOptions jit_opts;
268283
jit_opts.set_opt_level(llvm_opt_level)
269284
.set_jit_observer(&obs)
270285
.set_symbol_salt(absl::GetFlag(FLAGS_symbol_salt))
271-
.set_include_msan(include_msan);
286+
.set_include_msan(include_msan)
287+
.set_generate_skeleton(generate_skeleton);
272288
if (f->IsFunction()) {
273289
XLS_ASSIGN_OR_RETURN(
274290
object_code, FunctionJit::CreateObjectCode(
@@ -288,33 +304,35 @@ absl::Status RealMain(const std::string& input_ir_path,
288304
XLS_ASSIGN_OR_RETURN(object_code,
289305
BlockJit::CreateObjectCode(elab, jit_opts));
290306
}
291-
AotPackageEntrypointsProto all_entrypoints;
292307
if (output_object_path) {
293308
XLS_RETURN_IF_ERROR(SetFileContents(
294309
*output_object_path, std::string(object_code->object_code.begin(),
295310
object_code->object_code.end())));
296311
}
297312

298-
*all_entrypoints.mutable_data_layout() =
299-
object_code->data_layout.getStringRepresentation();
313+
if (output_proto_path || output_textproto_path) {
314+
AotPackageEntrypointsProto all_entrypoints;
315+
*all_entrypoints.mutable_data_layout() =
316+
object_code->data_layout.getStringRepresentation();
300317

301-
auto context = std::make_unique<llvm::LLVMContext>();
302-
LlvmTypeConverter type_converter(context.get(), object_code->data_layout);
303-
for (const FunctionEntrypoint& oc : object_code->entrypoints) {
304-
XLS_ASSIGN_OR_RETURN(
305-
*all_entrypoints.add_entrypoint(),
306-
GenerateEntrypointProto(
307-
object_code->package ? object_code->package.get() : package.get(),
308-
oc, include_msan, type_converter));
309-
}
310-
if (output_textproto_path) {
311-
std::string text;
312-
XLS_RET_CHECK(google::protobuf::TextFormat::PrintToString(all_entrypoints, &text));
313-
XLS_RETURN_IF_ERROR(SetFileContents(*output_textproto_path, text));
314-
}
315-
if (output_proto_path) {
316-
XLS_RETURN_IF_ERROR(SetFileContents(*output_proto_path,
317-
all_entrypoints.SerializeAsString()));
318+
auto context = std::make_unique<llvm::LLVMContext>();
319+
LlvmTypeConverter type_converter(context.get(), object_code->data_layout);
320+
for (const FunctionEntrypoint& oc : object_code->entrypoints) {
321+
XLS_ASSIGN_OR_RETURN(
322+
*all_entrypoints.add_entrypoint(),
323+
GenerateEntrypointProto(
324+
object_code->package ? object_code->package.get() : package.get(),
325+
oc, include_msan, type_converter));
326+
}
327+
if (output_textproto_path) {
328+
std::string text;
329+
XLS_RET_CHECK(google::protobuf::TextFormat::PrintToString(all_entrypoints, &text));
330+
XLS_RETURN_IF_ERROR(SetFileContents(*output_textproto_path, text));
331+
}
332+
if (output_proto_path) {
333+
XLS_RETURN_IF_ERROR(SetFileContents(*output_proto_path,
334+
all_entrypoints.SerializeAsString()));
335+
}
318336
}
319337
if (output_llvm_ir_path) {
320338
XLS_RETURN_IF_ERROR(
@@ -335,7 +353,7 @@ absl::Status RealMain(const std::string& input_ir_path,
335353
} // namespace xls
336354

337355
int main(int argc, char** argv) {
338-
xls::InitXls(argv[0], argc, argv);
356+
xls::InitXls(kUsage, argc, argv);
339357
std::string input_ir_path = absl::GetFlag(FLAGS_input);
340358
QCHECK(!input_ir_path.empty()) << "--input must be specified.";
341359

xls/jit/function_base_jit.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,15 @@ absl::StatusOr<PartitionedFunction> BuildFunctionInternal(
853853
.name = "continuation_point",
854854
.type = llvm::Type::getInt64Ty(jit_context.context())});
855855

856+
// NB Ideally we would be able to just immediately return here and not
857+
// generate any of the llvm ir code or partition functions. Unfortunately we
858+
// need to do that to figure out how big some of the buffers need to be. We
859+
// still skip actually doing anything with the llvm bitcode we generate so the
860+
// skeleton is still significantly faster than a full aot compile.
861+
// TODO(https://github.com/google/xls/issues/3724): This shouldn't be needed.
862+
// Ideally we should be able to determine the buffer sizes without writing out
863+
// the llvm ir and return immediately.
864+
856865
XLS_RETURN_IF_ERROR(AllocateBuffers(partitions, wrapper, allocator));
857866

858867
std::vector<llvm::Function*> partition_functions;
@@ -1174,6 +1183,15 @@ absl::StatusOr<llvm::Function*> BuildPackedWrapper(
11741183
LlvmFunctionWrapper::FunctionArg{
11751184
.name = "continuation_point",
11761185
.type = llvm::Type::getInt64Ty(*context)});
1186+
if (jit_context.is_skeleton()) {
1187+
// No need to actually generate any code for a skeleton compile
1188+
// Add a ret to ensure the code is well formed (though since we don't
1189+
// actually try to do anything with it this isn't necessary it still keeps
1190+
// it cleaner).
1191+
wrapper.entry_builder().CreateRet(
1192+
llvm::ConstantInt::get(llvm::Type::getInt64Ty(*context), /*value=*/0));
1193+
return wrapper.function();
1194+
}
11771195

11781196
// First load and unpack the arguments then store them in LLVM native data
11791197
// layout. These unpacked values are pointed to by an array of pointers passed
@@ -1333,6 +1351,8 @@ absl::StatusOr<JittedFunctionBase> JittedFunctionBase::BuildInternal(
13331351

13341352
jitted_function.function_name_ = function_name;
13351353
if (jit_context.llvm_compiler().IsOrcJit()) {
1354+
XLS_RET_CHECK(!jit_context.is_skeleton())
1355+
<< "Only AOT can generate skeleton compiles";
13361356
XLS_ASSIGN_OR_RETURN(auto* orc_jit, jit_context.llvm_compiler().AsOrcJit());
13371357
XLS_ASSIGN_OR_RETURN(auto fn_address, orc_jit->LoadSymbol(function_name));
13381358
jitted_function.function_ = absl::bit_cast<JitFunctionType>(fn_address);
@@ -1345,6 +1365,8 @@ absl::StatusOr<JittedFunctionBase> JittedFunctionBase::BuildInternal(
13451365
if (build_packed_wrapper) {
13461366
jitted_function.packed_function_name_ = packed_wrapper_name;
13471367
if (jit_context.llvm_compiler().IsOrcJit()) {
1368+
XLS_RET_CHECK(!jit_context.is_skeleton())
1369+
<< "Only AOT can generate skeleton compiles";
13481370
XLS_ASSIGN_OR_RETURN(auto* orc_jit,
13491371
jit_context.llvm_compiler().AsOrcJit());
13501372
XLS_ASSIGN_OR_RETURN(auto packed_fn_address,

xls/jit/ir_builder_visitor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class JitBuilderContext {
6161
LlvmCompiler& llvm_compiler() { return llvm_compiler_; }
6262
LlvmTypeConverter& type_converter() { return type_converter_; }
6363
FunctionBase* top() const { return top_; }
64+
bool is_skeleton() const { return llvm_compiler_.is_skeleton(); }
6465

6566
// Destructively returns the underlying llvm::Module.
6667
std::unique_ptr<llvm::Module> ConsumeModule() { return std::move(module_); }

xls/jit/jit_evaluator_options.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,22 @@ class JitEvaluatorOptions {
6666
}
6767
JitObserver* jit_observer() const { return jit_observer_; }
6868

69+
// Tell the JIT to generate skeleton object code. Don't actually compile
70+
// anything but just create the symbols we would generate if we were doing a
71+
// full compile.
72+
JitEvaluatorOptions& set_generate_skeleton(bool value) {
73+
generate_skeleton_ = value;
74+
return *this;
75+
}
76+
bool generate_skeleton() const { return generate_skeleton_; }
77+
6978
private:
7079
int64_t opt_level_ = LlvmCompiler::kDefaultOptLevel;
7180
std::string symbol_salt_;
7281
bool include_observer_callbacks_ = false;
7382
bool include_msan_ = false;
7483
JitObserver* jit_observer_ = nullptr;
84+
bool generate_skeleton_ = false;
7585
};
7686

7787
} // namespace xls

xls/jit/llvm_compiler.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ class LlvmCompiler {
9292
return include_observer_callbacks_;
9393
}
9494

95+
// Return true if this is a skeleton compilation. That is don't actually
96+
// compile anything just create the symbols.
97+
virtual bool is_skeleton() const { return false; }
98+
9599
protected:
96100
absl::Status Init();
97101

xls/jit/observer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ struct JitObserverRequests {
4545
bool optimized_module = false;
4646
// Do we want to get called with optimized asm code.
4747
bool assembly_code_str = false;
48+
49+
bool has_any_requests() const {
50+
return unoptimized_module || optimized_module || assembly_code_str;
51+
}
4852
};
4953

5054
// Basic observer for JIT compilation events

0 commit comments

Comments
 (0)