Skip to content

Commit 73568e8

Browse files
committed
Restore early release of threadsafe context for parallel compilation
by creating a manual co-routine to get out of the caller scope.
1 parent d08d3ed commit 73568e8

File tree

1 file changed

+165
-167
lines changed

1 file changed

+165
-167
lines changed

src/aotcompile.cpp

Lines changed: 165 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,63 +1852,142 @@ static inline void schedule_uv_thread(uv_thread_t *worker, CB &&cb)
18521852

18531853
// Entrypoint to optionally-multithreaded image compilation. This handles global coordination of the threading,
18541854
// as well as partitioning, serialization, and deserialization.
1855-
template<typename ModuleReleasedFunc>
1856-
static SmallVector<AOTOutputs, 16> add_output(Module &M, TargetMachine &TM, StringRef name, unsigned threads,
1857-
bool unopt_out, bool opt_out, bool obj_out, bool asm_out, ModuleReleasedFunc module_released) {
1858-
SmallVector<AOTOutputs, 16> outputs(threads);
1859-
assert(threads);
1860-
assert(unopt_out || opt_out || obj_out || asm_out);
1861-
// Timers for timing purposes
1862-
TimerGroup timer_group("add_output", ("Time to optimize and emit LLVM module " + name).str());
1863-
SmallVector<ShardTimers, 1> timers(threads);
1864-
for (unsigned i = 0; i < threads; ++i) {
1865-
auto idx = std::to_string(i);
1866-
timers[i].name = "shard_" + idx;
1867-
timers[i].desc = ("Timings for " + name + " module shard " + idx).str();
1868-
timers[i].deserialize.init("deserialize_" + idx, "Deserialize module");
1869-
timers[i].materialize.init("materialize_" + idx, "Materialize declarations");
1870-
timers[i].construct.init("construct_" + idx, "Construct partitioned definitions");
1871-
timers[i].unopt.init("unopt_" + idx, "Emit unoptimized bitcode");
1872-
timers[i].optimize.init("optimize_" + idx, "Optimize shard");
1873-
timers[i].opt.init("opt_" + idx, "Emit optimized bitcode");
1874-
timers[i].obj.init("obj_" + idx, "Emit object file");
1875-
timers[i].asm_.init("asm_" + idx, "Emit assembly file");
1876-
}
1877-
Timer partition_timer("partition", "Partition module", timer_group);
1878-
Timer serialize_timer("serialize", "Serialize module", timer_group);
1879-
Timer output_timer("output", "Add outputs", timer_group);
1880-
bool report_timings = false;
1881-
if (auto env = getenv("JULIA_IMAGE_TIMINGS")) {
1882-
char *endptr;
1883-
unsigned long val = strtoul(env, &endptr, 10);
1884-
if (endptr != env && !*endptr && val <= 1) {
1885-
report_timings = val;
1886-
} else {
1887-
if (StringRef("true").compare_insensitive(env) == 0)
1888-
report_timings = true;
1889-
else if (StringRef("false").compare_insensitive(env) == 0)
1890-
report_timings = false;
1891-
else
1892-
errs() << "WARNING: Invalid value for JULIA_IMAGE_TIMINGS: " << env << "\n";
1855+
1856+
// This is more or less a manual co-routine version of add_output
1857+
// which allows exiting calling scope when the module is released.
1858+
struct OutputAdder {
1859+
OutputAdder(Module &M, TargetMachine &TM, StringRef name, unsigned threads,
1860+
bool unopt_out, bool opt_out, bool obj_out, bool asm_out)
1861+
: TM(TM), threads(threads), unopt_out(unopt_out),
1862+
opt_out(opt_out), obj_out(obj_out), asm_out(asm_out),
1863+
timer_group("add_output", ("Time to optimize and emit LLVM module " + name).str())
1864+
{
1865+
assert(threads);
1866+
assert(unopt_out || opt_out || obj_out || asm_out);
1867+
for (unsigned i = 0; i < threads; ++i) {
1868+
auto idx = std::to_string(i);
1869+
timers[i].name = "shard_" + idx;
1870+
timers[i].desc = ("Timings for " + name + " module shard " + idx).str();
1871+
timers[i].deserialize.init("deserialize_" + idx, "Deserialize module");
1872+
timers[i].materialize.init("materialize_" + idx, "Materialize declarations");
1873+
timers[i].construct.init("construct_" + idx, "Construct partitioned definitions");
1874+
timers[i].unopt.init("unopt_" + idx, "Emit unoptimized bitcode");
1875+
timers[i].optimize.init("optimize_" + idx, "Optimize shard");
1876+
timers[i].opt.init("opt_" + idx, "Emit optimized bitcode");
1877+
timers[i].obj.init("obj_" + idx, "Emit object file");
1878+
timers[i].asm_.init("asm_" + idx, "Emit assembly file");
1879+
}
1880+
if (auto env = getenv("JULIA_IMAGE_TIMINGS")) {
1881+
char *endptr;
1882+
unsigned long val = strtoul(env, &endptr, 10);
1883+
if (endptr != env && !*endptr && val <= 1) {
1884+
report_timings = val;
1885+
} else {
1886+
if (StringRef("true").compare_insensitive(env) == 0)
1887+
report_timings = true;
1888+
else if (StringRef("false").compare_insensitive(env) == 0)
1889+
report_timings = false;
1890+
else
1891+
errs() << "WARNING: Invalid value for JULIA_IMAGE_TIMINGS: " << env << "\n";
1892+
}
1893+
}
1894+
// Single-threaded case
1895+
if (threads == 1) {
1896+
output_timer.startTimer();
1897+
{
1898+
JL_TIMING(NATIVE_AOT, NATIVE_Opt);
1899+
// convert gvars to the expected offset table format for shard 0
1900+
if (M.getGlobalVariable("jl_gvars")) {
1901+
auto gvars = consume_gv<Constant>(M, "jl_gvars", false);
1902+
Type *T_size = M.getDataLayout().getIntPtrType(M.getContext());
1903+
emit_offset_table(M, T_size, gvars, "jl_gvar", "_0"); // module flag "julia.mv.suffix"
1904+
M.getGlobalVariable("jl_gvar_idxs")->setName("jl_gvar_idxs_0");
1905+
}
1906+
output0 = add_output_impl(M, TM, timers[0], unopt_out, opt_out, obj_out, asm_out);
1907+
}
1908+
output_timer.stopTimer();
1909+
return;
18931910
}
1911+
1912+
partition_timer.startTimer();
1913+
uint64_t counter = 0;
1914+
// Partitioning requires all globals to have names.
1915+
// We use a prefix to avoid name conflicts with user code.
1916+
for (auto &G : M.global_values()) {
1917+
if (!G.isDeclaration() && !G.hasName()) {
1918+
G.setName("jl_ext_" + Twine(counter++));
1919+
}
1920+
}
1921+
partitions = partitionModule(M, threads);
1922+
partition_timer.stopTimer();
1923+
1924+
serialize_timer.startTimer();
1925+
serialized = serializeModule(M);
1926+
serialize_timer.stopTimer();
18941927
}
1895-
// Single-threaded case
1896-
if (threads == 1) {
1928+
1929+
auto finish()
1930+
{
1931+
SmallVector<AOTOutputs, 16> outputs(threads);
1932+
if (threads == 1) {
1933+
outputs[0] = std::move(output0);
1934+
if (!report_timings) {
1935+
timer_group.clear();
1936+
} else {
1937+
timer_group.print(dbgs(), true);
1938+
for (auto &t : timers) {
1939+
t.print(dbgs(), true);
1940+
}
1941+
}
1942+
return outputs;
1943+
}
18971944
output_timer.startTimer();
1945+
1946+
// Start all of the worker threads
18981947
{
18991948
JL_TIMING(NATIVE_AOT, NATIVE_Opt);
1900-
// convert gvars to the expected offset table format for shard 0
1901-
if (M.getGlobalVariable("jl_gvars")) {
1902-
auto gvars = consume_gv<Constant>(M, "jl_gvars", false);
1903-
Type *T_size = M.getDataLayout().getIntPtrType(M.getContext());
1904-
emit_offset_table(M, T_size, gvars, "jl_gvar", "_0"); // module flag "julia.mv.suffix"
1905-
M.getGlobalVariable("jl_gvar_idxs")->setName("jl_gvar_idxs_0");
1949+
std::vector<uv_thread_t> workers(threads);
1950+
for (unsigned i = 0; i < threads; i++) {
1951+
schedule_uv_thread(&workers[i], [&, i]() {
1952+
LLVMContext ctx;
1953+
ctx.setDiscardValueNames(true);
1954+
// Lazily deserialize the entire module
1955+
timers[i].deserialize.startTimer();
1956+
auto EM = getLazyBitcodeModule(MemoryBufferRef(StringRef(serialized.data(), serialized.size()), "Optimized"), ctx);
1957+
// Make sure this also fails with only julia, but not LLVM assertions enabled,
1958+
// otherwise, the first error we hit is the LLVM module verification failure,
1959+
// which will look very confusing, because the module was partially deserialized.
1960+
bool deser_succeeded = (bool)EM;
1961+
auto M = cantFail(std::move(EM), "Error loading module");
1962+
assert(deser_succeeded); (void)deser_succeeded;
1963+
timers[i].deserialize.stopTimer();
1964+
1965+
timers[i].materialize.startTimer();
1966+
materializePreserved(*M, partitions[i]);
1967+
timers[i].materialize.stopTimer();
1968+
1969+
timers[i].construct.startTimer();
1970+
std::string suffix = "_" + std::to_string(i);
1971+
construct_vars(*M, partitions[i], suffix);
1972+
M->setModuleFlag(Module::Error, "julia.mv.suffix", MDString::get(M->getContext(), suffix));
1973+
// The DICompileUnit file is not used for anything, but ld64 requires it be a unique string per object file
1974+
// or it may skip emitting debug info for that file. Here set it to ./julia#N
1975+
DIFile *topfile = DIFile::get(M->getContext(), "julia#" + std::to_string(i), ".");
1976+
if (M->getNamedMetadata("llvm.dbg.cu"))
1977+
for (auto CU: M->getNamedMetadata("llvm.dbg.cu")->operands())
1978+
CU->replaceOperandWith(0, topfile);
1979+
timers[i].construct.stopTimer();
1980+
1981+
outputs[i] = add_output_impl(*M, TM, timers[i], unopt_out, opt_out, obj_out, asm_out);
1982+
});
19061983
}
1907-
outputs[0] = add_output_impl(M, TM, timers[0], unopt_out, opt_out, obj_out, asm_out);
1984+
1985+
// Wait for all of the worker threads to finish
1986+
for (unsigned i = 0; i < threads; i++)
1987+
uv_thread_join(&workers[i]);
19081988
}
1989+
19091990
output_timer.stopTimer();
1910-
// Don't need M anymore
1911-
module_released(M);
19121991

19131992
if (!report_timings) {
19141993
timer_group.clear();
@@ -1917,97 +1996,37 @@ static SmallVector<AOTOutputs, 16> add_output(Module &M, TargetMachine &TM, Stri
19171996
for (auto &t : timers) {
19181997
t.print(dbgs(), true);
19191998
}
1999+
dbgs() << "Partition weights: [";
2000+
bool comma = false;
2001+
for (auto &p : partitions) {
2002+
if (comma)
2003+
dbgs() << ", ";
2004+
else
2005+
comma = true;
2006+
dbgs() << p.weight;
2007+
}
2008+
dbgs() << "]\n";
19202009
}
19212010
return outputs;
19222011
}
19232012

1924-
partition_timer.startTimer();
1925-
uint64_t counter = 0;
1926-
// Partitioning requires all globals to have names.
1927-
// We use a prefix to avoid name conflicts with user code.
1928-
for (auto &G : M.global_values()) {
1929-
if (!G.isDeclaration() && !G.hasName()) {
1930-
G.setName("jl_ext_" + Twine(counter++));
1931-
}
1932-
}
1933-
auto partitions = partitionModule(M, threads);
1934-
partition_timer.stopTimer();
1935-
1936-
serialize_timer.startTimer();
1937-
auto serialized = serializeModule(M);
1938-
serialize_timer.stopTimer();
1939-
1940-
// Don't need M anymore, since we'll only read from serialized from now on
1941-
module_released(M);
1942-
1943-
output_timer.startTimer();
1944-
1945-
// Start all of the worker threads
1946-
{
1947-
JL_TIMING(NATIVE_AOT, NATIVE_Opt);
1948-
std::vector<uv_thread_t> workers(threads);
1949-
for (unsigned i = 0; i < threads; i++) {
1950-
schedule_uv_thread(&workers[i], [&, i]() {
1951-
LLVMContext ctx;
1952-
ctx.setDiscardValueNames(true);
1953-
// Lazily deserialize the entire module
1954-
timers[i].deserialize.startTimer();
1955-
auto EM = getLazyBitcodeModule(MemoryBufferRef(StringRef(serialized.data(), serialized.size()), "Optimized"), ctx);
1956-
// Make sure this also fails with only julia, but not LLVM assertions enabled,
1957-
// otherwise, the first error we hit is the LLVM module verification failure,
1958-
// which will look very confusing, because the module was partially deserialized.
1959-
bool deser_succeeded = (bool)EM;
1960-
auto M = cantFail(std::move(EM), "Error loading module");
1961-
assert(deser_succeeded); (void)deser_succeeded;
1962-
timers[i].deserialize.stopTimer();
1963-
1964-
timers[i].materialize.startTimer();
1965-
materializePreserved(*M, partitions[i]);
1966-
timers[i].materialize.stopTimer();
1967-
1968-
timers[i].construct.startTimer();
1969-
std::string suffix = "_" + std::to_string(i);
1970-
construct_vars(*M, partitions[i], suffix);
1971-
M->setModuleFlag(Module::Error, "julia.mv.suffix", MDString::get(M->getContext(), suffix));
1972-
// The DICompileUnit file is not used for anything, but ld64 requires it be a unique string per object file
1973-
// or it may skip emitting debug info for that file. Here set it to ./julia#N
1974-
DIFile *topfile = DIFile::get(M->getContext(), "julia#" + std::to_string(i), ".");
1975-
if (M->getNamedMetadata("llvm.dbg.cu"))
1976-
for (auto CU: M->getNamedMetadata("llvm.dbg.cu")->operands())
1977-
CU->replaceOperandWith(0, topfile);
1978-
timers[i].construct.stopTimer();
1979-
1980-
outputs[i] = add_output_impl(*M, TM, timers[i], unopt_out, opt_out, obj_out, asm_out);
1981-
});
1982-
}
1983-
1984-
// Wait for all of the worker threads to finish
1985-
for (unsigned i = 0; i < threads; i++)
1986-
uv_thread_join(&workers[i]);
1987-
}
1988-
1989-
output_timer.stopTimer();
1990-
1991-
if (!report_timings) {
1992-
timer_group.clear();
1993-
} else {
1994-
timer_group.print(dbgs(), true);
1995-
for (auto &t : timers) {
1996-
t.print(dbgs(), true);
1997-
}
1998-
dbgs() << "Partition weights: [";
1999-
bool comma = false;
2000-
for (auto &p : partitions) {
2001-
if (comma)
2002-
dbgs() << ", ";
2003-
else
2004-
comma = true;
2005-
dbgs() << p.weight;
2006-
}
2007-
dbgs() << "]\n";
2008-
}
2009-
return outputs;
2010-
}
2013+
TargetMachine &TM;
2014+
unsigned threads;
2015+
bool unopt_out;
2016+
bool opt_out;
2017+
bool obj_out;
2018+
bool asm_out;
2019+
AOTOutputs output0;
2020+
// Timers for timing purposes
2021+
TimerGroup timer_group;
2022+
SmallVector<ShardTimers, 1> timers{threads};
2023+
Timer partition_timer{"partition", "Partition module", timer_group};
2024+
Timer serialize_timer{"serialize", "Serialize module", timer_group};
2025+
Timer output_timer{"output", "Add outputs", timer_group};
2026+
bool report_timings{false};
2027+
SmallVector<Partition, 32> partitions;
2028+
SmallVector<char, 0> serialized;
2029+
};
20112030

20122031
extern int jl_is_timing_passes;
20132032
static unsigned compute_image_thread_count(const ModuleInfo &info) {
@@ -2145,8 +2164,8 @@ void jl_dump_native_impl(void *native_code,
21452164
OverrideStackAlignment = M.getOverrideStackAlignment();
21462165
});
21472166

2148-
auto compile = [&](Module &M, StringRef name, unsigned threads, auto module_released) {
2149-
return add_output(M, *SourceTM, name, threads, !!unopt_bc_fname, !!bc_fname, !!obj_fname, !!asm_fname, module_released);
2167+
auto start_compile = [&](Module &M, StringRef name, unsigned threads) {
2168+
return OutputAdder(M, *SourceTM, name, threads, !!unopt_bc_fname, !!bc_fname, !!obj_fname, !!asm_fname);
21502169
};
21512170

21522171
SmallVector<AOTOutputs, 16> sysimg_outputs;
@@ -2215,7 +2234,7 @@ void jl_dump_native_impl(void *native_code,
22152234
// Note that we don't set z to null, this allows the check in WRITE_ARCHIVE
22162235
// to function as expected
22172236
// no need to free the module/context, destructor handles that
2218-
sysimg_outputs = compile(sysimgM, "sysimg", 1, [](Module &) {});
2237+
sysimg_outputs = start_compile(sysimgM, "sysimg", 1).finish();
22192238
}
22202239

22212240
const bool imaging_mode = true;
@@ -2314,34 +2333,13 @@ void jl_dump_native_impl(void *native_code,
23142333
});
23152334

23162335
{
2317-
// Don't use withModuleDo here since we delete the TSM midway through
2318-
auto TSCtx = data->M.getContext();
2319-
#if JL_LLVM_VERSION < 210000
2320-
auto lock = TSCtx.getLock();
2321-
auto dataM = data->M.getModuleUnlocked();
2322-
2323-
data_outputs = compile(*dataM, "text", threads, [data, &lock, &TSCtx](Module &) {
2324-
// Delete data when add_output thinks it's done with it
2325-
// Saves memory for use when multithreading
2326-
auto lock2 = std::move(lock);
2327-
delete data;
2328-
// Drop last reference to shared LLVM::Context
2329-
auto TSCtx2 = std::move(TSCtx);
2330-
});
2331-
#else
2332-
TSCtx.withContextDo([&] (LLVMContext*) {
2333-
auto dataM = data->M.getModuleUnlocked();
2334-
data_outputs = compile(*dataM, "text", threads, [data, &TSCtx](Module &) {
2335-
// Delete data when add_output thinks it's done with it
2336-
// Saves memory for use when multithreading
2337-
// Unfortunately, for LLVM 21, this does not release the lock
2338-
// or release the LLVM context.
2339-
delete data;
2340-
// Drop our reference to shared LLVM::Context
2341-
auto TSCtx2 = std::move(TSCtx);
2342-
});
2336+
auto adder = data->M.withModuleDo([&] (auto &dataM) {
2337+
return start_compile(dataM, "text", threads);
23432338
});
2344-
#endif
2339+
// Delete data when add_output thinks it's done with it
2340+
// Saves memory for use when multithreading
2341+
delete data;
2342+
data_outputs = adder.finish();
23452343
}
23462344

23472345
if (params->emit_metadata) {
@@ -2444,7 +2442,7 @@ void jl_dump_native_impl(void *native_code,
24442442
}
24452443

24462444
// no need to free module/context, destructor handles that
2447-
metadata_outputs = compile(metadataM, "data", 1, [](Module &) {});
2445+
metadata_outputs = start_compile(metadataM, "data", 1).finish();
24482446
}
24492447

24502448
{

0 commit comments

Comments
 (0)