Skip to content

Commit 46b2710

Browse files
committed
Fix ordering of globals in IR
1 parent 604eab1 commit 46b2710

File tree

9 files changed

+26
-20
lines changed

9 files changed

+26
-20
lines changed

src/call.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ void jitc_call_upload(ThreadState *ts) {
679679
memset(data, 0, call->offset_size);
680680

681681
for (uint32_t i = 0; i < call->n_inst; ++i) {
682-
auto it = globals_map.find(GlobalKey(call->inst_hash[i], call->n_inst != 1));
682+
auto it = globals_map.find(GlobalKey(call->inst_hash[i], call->n_inst != 1 ? GlobalType::IndirectCallable : GlobalType::Callable));
683683
if (it == globals_map.end())
684684
jitc_fail("jitc_call_upload(): could not find callable!");
685685

src/cuda_eval.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ void jitc_cuda_assemble(ThreadState *ts, ScheduledGroup group,
227227
put('\n');
228228
put(globals.get() + it.second.start, it.second.length);
229229
put('\n');
230-
if (!it.first.indirect_callable)
230+
if (it.first.type != GlobalType::IndirectCallable)
231231
continue;
232232
it.second.callable_index = ctr++;
233233
}
@@ -245,7 +245,7 @@ void jitc_cuda_assemble(ThreadState *ts, ScheduledGroup group,
245245
fmt("\n.visible .global .align 8 .u64 callables[$u] = {\n",
246246
indirect_callable_count_unique);
247247
for (auto const &it : globals_map) {
248-
if (!it.first.indirect_callable)
248+
if (it.first.type != GlobalType::IndirectCallable)
249249
continue;
250250

251251
fmt(" func_$Q$Q$s\n",
@@ -270,8 +270,10 @@ void jitc_cuda_assemble_func(const CallData *call, uint32_t inst,
270270
(flags & (uint32_t) JitFlag::PrintIR);
271271

272272
if (call->n_inst == 1)
273-
put(".func");
273+
// Marked as weak, in case a forward declaration is assembled after this
274+
put(".weak .func");
274275
else
276+
// Marked as globally visible for OptiX
275277
put(".visible .func");
276278

277279
if (out_size)
@@ -1585,7 +1587,7 @@ void jitc_var_call_assemble_cuda(CallData *call, uint32_t call_reg,
15851587
}
15861588

15871589
if (call->n_inst == 1) {
1588-
put(" .func ");
1590+
put(" .weak .func");
15891591
if (out_size)
15901592
fmt("(.param .align $u .b8 result[$u]) ", out_align, out_size);
15911593
put("func_unique_");

src/cuda_scatter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ void jitc_cuda_render_scatter_reduce_bfly_32(const char *tp, const char *op,
6767
const char *op_ftz, uint32_t shiftamt) {
6868
fmt_intrinsic(
6969
".func reduce_$s_$s(.param .u64 ptr, .param .$s value) {\n"
70-
" .reg .b32 %active, %index, %mask_lt, %mask_gt, %peers, \n"
70+
" .reg .b32 %active, %index, %mask_lt, %mask_gt, %peers,\n"
7171
" %peers_lt, %peers_rev, %rank, %rank_bit, %rank_ballot;\n"
7272
" .reg .b64 %ptr, %ptr_shift;\n"
7373
" .reg .$s %q0, %q1;\n"
@@ -151,7 +151,7 @@ void jitc_cuda_render_scatter_reduce_bfly_64(const char *tp, const char *op,
151151
const char *op_ftz, uint32_t shiftamt) {
152152
fmt_intrinsic(
153153
".func reduce_$s_$s(.param .u64 ptr, .param .$s value) {\n"
154-
" .reg .b32 %active, %index, %mask_lt, %mask_gt, %peers, \n"
154+
" .reg .b32 %active, %index, %mask_lt, %mask_gt, %peers,\n"
155155
" %peers_lt, %peers_rev, %rank, %rank_bit, %rank_ballot;\n"
156156
" .reg .b64 %ptr, %ptr_shift;\n"
157157
" .reg .b32 %q0l, %q0h, %q1l, %q1h;\n"

src/eval.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ XXH128_hash_t jitc_assemble_func(const CallData *call, uint32_t inst,
914914
kernel_hash.high64 = 0;
915915
}
916916

917-
if (globals_map.emplace(GlobalKey(kernel_hash, call->n_inst != 1),
917+
if (globals_map.emplace(GlobalKey(kernel_hash, call->n_inst != 1 ? GlobalType::IndirectCallable : GlobalType::Callable),
918918
GlobalValue(globals.size(), kernel_length)).second) {
919919
// Replace '^'s in 'func_^^^..' or '__direct_callable__^^^..' with hash
920920
size_t hash_offset = strchr(buffer.get() + kernel_offset, '^') - buffer.get(),
@@ -957,7 +957,7 @@ XXH128_hash_t jitc_assemble_func(const CallData *call, uint32_t inst,
957957
/// Register a global declaration that will be included in the final program
958958
void jitc_register_global(const char *str) {
959959
size_t length = strlen(str);
960-
if (globals_map.emplace(GlobalKey(XXH128(str, length, 0), false),
960+
if (globals_map.emplace(GlobalKey(XXH128(str, length, 0), GlobalType::Global),
961961
GlobalValue(globals.size(), length)).second)
962962
globals.put(str, length);
963963
}

src/eval.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,25 @@ struct ScheduledGroup {
3434
: size(size), start(start), end(end) { }
3535
};
3636

37+
enum class GlobalType : uint32_t {
38+
IndirectCallable = 0, // Multi-target vcalls, assembled first
39+
Callable = 1, // Single-target vcalls, assembled second
40+
Global = 2 // Other globals (intrinsics, etc.), assembled last
41+
};
42+
3743
struct GlobalKey {
3844
XXH128_hash_t hash;
39-
bool indirect_callable;
45+
GlobalType type;
4046

41-
GlobalKey(XXH128_hash_t hash, bool callable)
42-
: hash(hash), indirect_callable(callable) { }
47+
GlobalKey(XXH128_hash_t hash, GlobalType type)
48+
: hash(hash), type(type) { }
4349

4450
/* Order by type first (indirect callables, then single-target callables, then other globals), but don't use
4551
the callable ID itself for ordering (it can be non-deterministic in
4652
programs that use Dr.Jit with parallelization) */
4753
bool operator<(const GlobalKey &v) const {
48-
int callable_key_t = indirect_callable ? 0 : 1,
49-
callable_key_v = v.indirect_callable ? 0 : 1;
50-
return std::tie(callable_key_t, hash.high64, hash.low64) <
51-
std::tie(callable_key_v, v.hash.high64, v.hash.low64);
54+
return std::tie(type, hash.high64, hash.low64) <
55+
std::tie(v.type, v.hash.high64, v.hash.low64);
5256
}
5357
};
5458

src/llvm_eval.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ void jitc_llvm_assemble(ThreadState *ts, ScheduledGroup group) {
241241
put('\n');
242242
put(globals.get() + it.second.start, it.second.length);
243243
put('\n');
244-
if (!it.first.indirect_callable)
244+
if (it.first.type != GlobalType::IndirectCallable)
245245
continue;
246246
it.second.callable_index = 1 + ctr++;
247247
}

src/llvm_mcjit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ void jitc_llvm_mcjit_compile(void *llvm_module,
131131
symbols[symbol_pos++] = resolve("callables");
132132

133133
for (auto const &kv: globals_map) {
134-
if (!kv.first.indirect_callable)
134+
if (kv.first.type != GlobalType::IndirectCallable)
135135
continue;
136136

137137
char name_buf[38];

src/llvm_orcv2.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ void jitc_llvm_orcv2_compile(void *llvm_module,
122122
symbols[symbol_pos++] = resolve("callables");
123123

124124
for (auto const &kv: globals_map) {
125-
if (!kv.first.indirect_callable)
125+
if (kv.first.type != GlobalType::IndirectCallable)
126126
continue;
127127

128128
char name_buf[38];

src/optix_core.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ bool jitc_optix_compile(ThreadState *ts, const char *buf, size_t buf_size,
324324
bool continuation_callables = jitc_optix_use_continuation_callables();
325325

326326
for (auto const &it : globals_map) {
327-
if (!it.first.indirect_callable)
327+
if (it.first.type != GlobalType::IndirectCallable)
328328
continue;
329329

330330
char *name = (char *) malloc_check(58);

0 commit comments

Comments
 (0)