Skip to content

Commit 94a14e5

Browse files
committed
lower_physical_ptrs: handle init
1 parent b51f203 commit 94a14e5

File tree

1 file changed

+116
-75
lines changed

1 file changed

+116
-75
lines changed

src/shady/passes/lower_physical_ptrs.c

Lines changed: 116 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
#include "list.h"
1515
#include "dict.h"
1616

17+
#include <string.h>
1718
#include <assert.h>
1819

1920
typedef struct Context_ {
2021
Rewriter rewriter;
2122
const CompilerConfig* config;
2223

24+
Nodes collected[NumAddressSpaces];
25+
2326
struct Dict* serialisation_uniform[NumAddressSpaces];
2427
struct Dict* deserialisation_uniform[NumAddressSpaces];
2528

@@ -31,6 +34,8 @@ typedef struct Context_ {
3134
const Node* fake_shared_memory;
3235
} Context;
3336

37+
static void store_init_data(Context* ctx, AddressSpace as, Nodes collected, BodyBuilder* bb);
38+
3439
// TODO: make this configuration-dependant
3540
static bool is_as_emulated(SHADY_UNUSED Context* ctx, AddressSpace as) {
3641
switch (as) {
@@ -274,7 +279,7 @@ static const Node* gen_serdes_fn(Context* ctx, const Type* element_type, bool un
274279

275280
BodyBuilder* bb = begin_body(a);
276281
const Node* address = bytes_to_words(bb, address_param);
277-
const Node* base = ref_decl_helper(a, *get_emulated_as_word_array(ctx, as));
282+
const Node* base = *get_emulated_as_word_array(ctx, as);
278283
if (ser) {
279284
gen_serialisation(ctx, bb, element_type, base, address, value_param);
280285
fun->payload.fun.body = finish_body(bb, fn_ret(a, (Return) { .fn = fun, .args = empty(a) }));
@@ -286,56 +291,46 @@ static const Node* gen_serdes_fn(Context* ctx, const Type* element_type, bool un
286291
return fun;
287292
}
288293

289-
static const Node* process_let(Context* ctx, const Node* node) {
290-
assert(node->tag == Let_TAG);
291-
IrArena* a = ctx->rewriter.dst_arena;
292-
293-
const Node* tail = rewrite_node(&ctx->rewriter, node->payload.let.tail);
294-
const Node* old_instruction = node->payload.let.instruction;
295-
296-
if (old_instruction->tag == PrimOp_TAG) {
297-
const PrimOp* oprim_op = &old_instruction->payload.prim_op;
298-
switch (oprim_op->op) {
299-
case alloca_subgroup_op:
300-
case alloca_op: error("This needs to be lowered (see setup_stack_frames.c)")
301-
// lowering for either kind of memory accesses is similar
302-
case load_op:
303-
case store_op: {
304-
const Node* old_ptr = oprim_op->operands.nodes[0];
305-
const Type* ptr_type = old_ptr->type;
306-
bool uniform_ptr = deconstruct_qualified_type(&ptr_type);
307-
assert(ptr_type->tag == PtrType_TAG);
308-
if (!is_as_emulated(ctx, ptr_type->payload.ptr_type.address_space))
309-
break;
310-
BodyBuilder* bb = begin_body(a);
311-
312-
const Type* element_type = rewrite_node(&ctx->rewriter, ptr_type->payload.ptr_type.pointed_type);
313-
const Node* pointer_as_offset = rewrite_node(&ctx->rewriter, old_ptr);
314-
const Node* fn = gen_serdes_fn(ctx, element_type, uniform_ptr, oprim_op->op == store_op, ptr_type->payload.ptr_type.address_space);
315-
316-
if (oprim_op->op == load_op) {
317-
return finish_body(bb, let(a, call(a, (Call) {.callee = fn_addr_helper(a, fn), .args = singleton(pointer_as_offset)}), tail));
318-
} else {
319-
const Node* value = rewrite_node(&ctx->rewriter, oprim_op->operands.nodes[1]);
320-
return finish_body(bb, let(a, call(a, (Call) { .callee = fn_addr_helper(a, fn), .args = mk_nodes(a, pointer_as_offset, value) }), tail));
321-
}
322-
SHADY_UNREACHABLE;
323-
}
324-
default: break;
325-
}
326-
}
327-
328-
return let(a, rewrite_node(&ctx->rewriter, old_instruction), tail);
329-
}
330-
331294
static const Node* process_node(Context* ctx, const Node* old) {
332295
const Node* found = search_processed(&ctx->rewriter, old);
333296
if (found) return found;
334297

335298
IrArena* a = ctx->rewriter.dst_arena;
336299

337300
switch (old->tag) {
338-
case Let_TAG: return process_let(ctx, old);
301+
case PrimOp_TAG: {
302+
const PrimOp* oprim_op = &old->payload.prim_op;
303+
switch (oprim_op->op) {
304+
case alloca_subgroup_op:
305+
case alloca_op: error("This needs to be lowered (see setup_stack_frames.c)")
306+
// lowering for either kind of memory accesses is similar
307+
case load_op:
308+
case store_op: {
309+
const Node* old_ptr = oprim_op->operands.nodes[0];
310+
const Type* ptr_type = old_ptr->type;
311+
bool uniform_ptr = deconstruct_qualified_type(&ptr_type);
312+
assert(ptr_type->tag == PtrType_TAG);
313+
if (!is_as_emulated(ctx, ptr_type->payload.ptr_type.address_space))
314+
break;
315+
BodyBuilder* bb = begin_body(a);
316+
317+
const Type* element_type = rewrite_node(&ctx->rewriter, ptr_type->payload.ptr_type.pointed_type);
318+
const Node* pointer_as_offset = rewrite_node(&ctx->rewriter, old_ptr);
319+
const Node* fn = gen_serdes_fn(ctx, element_type, uniform_ptr, oprim_op->op == store_op, ptr_type->payload.ptr_type.address_space);
320+
321+
if (oprim_op->op == load_op) {
322+
Nodes r = bind_instruction(bb, call(a, (Call) {.callee = fn_addr_helper(a, fn), .args = singleton(pointer_as_offset)}));
323+
return yield_values_and_wrap_in_block(bb, r);
324+
} else {
325+
const Node* value = rewrite_node(&ctx->rewriter, oprim_op->operands.nodes[1]);
326+
bind_instruction(bb, call(a, (Call) { .callee = fn_addr_helper(a, fn), .args = mk_nodes(a, pointer_as_offset, value) }));
327+
return yield_values_and_wrap_in_block(bb, empty(a));
328+
}
329+
}
330+
default: break;
331+
}
332+
break;
333+
}
339334
case PtrType_TAG: {
340335
if (is_as_emulated(ctx, old->payload.ptr_type.address_space))
341336
return int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false });
@@ -354,6 +349,20 @@ static const Node* process_node(Context* ctx, const Node* old) {
354349
}
355350
break;
356351
}
352+
case Function_TAG: {
353+
if (strcmp(get_abstraction_name(old), "generated_init") == 0) {
354+
Node *new = recreate_decl_header_identity(&ctx->rewriter, old);
355+
BodyBuilder *bb = begin_body(a);
356+
357+
for (AddressSpace as = 0; as < NumAddressSpaces; as++) {
358+
if (is_as_emulated(ctx, as))
359+
store_init_data(ctx, as, ctx->collected[as], bb);
360+
}
361+
new->payload.fun.body = finish_body(bb, rewrite_node(&ctx->rewriter, old->payload.fun.body));
362+
return new;
363+
}
364+
break;
365+
}
357366
default: break;
358367
}
359368

@@ -363,76 +372,107 @@ static const Node* process_node(Context* ctx, const Node* old) {
363372
KeyHash hash_node(Node**);
364373
bool compare_node(Node**, Node**);
365374

366-
/// Collects all global variables in a specific AS, and creates a record type for them.
367-
static void collect_globals_into_record_type(Context* ctx, Node* global_struct_t, AddressSpace as) {
375+
static Nodes collect_globals(Context* ctx, AddressSpace as) {
368376
IrArena* a = ctx->rewriter.dst_arena;
369-
Module* m = ctx->rewriter.dst_module;
370377
Nodes old_decls = get_module_declarations(ctx->rewriter.src_module);
371-
372-
LARRAY(String, member_names, old_decls.count);
373-
LARRAY(const Type*, member_tys, old_decls.count);
378+
LARRAY(const Type*, collected, old_decls.count);
374379
size_t members_count = 0;
375380

376381
for (size_t i = 0; i < old_decls.count; i++) {
377382
const Node* decl = old_decls.nodes[i];
378383
if (decl->tag != GlobalVariable_TAG) continue;
379384
if (decl->payload.global_variable.address_space != as) continue;
385+
collected[members_count] = decl;
386+
members_count++;
387+
}
388+
389+
return nodes(a, members_count, collected);
390+
}
391+
392+
/// Collects all global variables in a specific AS, and creates a record type for them.
393+
static const Node* make_record_type(Context* ctx, AddressSpace as, Nodes collected) {
394+
IrArena* a = ctx->rewriter.dst_arena;
395+
Module* m = ctx->rewriter.dst_module;
396+
397+
String as_name = get_address_space_name(as);
398+
Node* global_struct_t = nominal_type(m, singleton(annotation(a, (Annotation) { .name = "Generated" })), format_string_arena(a->arena, "globals_physical_%s_t", as_name));
399+
400+
LARRAY(String, member_names, collected.count);
401+
LARRAY(const Type*, member_tys, collected.count);
402+
403+
for (size_t i = 0; i < collected.count; i++) {
404+
const Node* decl = collected.nodes[i];
380405
const Type* type = decl->payload.global_variable.type;
381406

382-
member_tys[members_count] = rewrite_node(&ctx->rewriter, type);
383-
member_names[members_count] = decl->payload.global_variable.name;
407+
member_tys[i] = rewrite_node(&ctx->rewriter, type);
408+
member_names[i] = decl->payload.global_variable.name;
384409

385410
// Turn the old global variable into a pointer (which are also now integers)
386411
const Type* emulated_ptr_type = int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false });
387412
Nodes annotations = rewrite_nodes(&ctx->rewriter, decl->payload.global_variable.annotations);
388-
Node* cnst = constant(ctx->rewriter.dst_module, annotations, emulated_ptr_type, decl->payload.global_variable.name);
413+
Node* new_address = constant(ctx->rewriter.dst_module, annotations, emulated_ptr_type, decl->payload.global_variable.name);
389414

390415
// we need to compute the actual pointer by getting the offset and dividing it
391416
// after lower_memory_layout, optimisations will eliminate this and resolve to a value
392417
BodyBuilder* bb = begin_body(a);
393-
const Node* offset = gen_primop_e(bb, offset_of_op, singleton(type_decl_ref(a, (TypeDeclRef) { .decl = global_struct_t })), singleton(size_t_literal(a, members_count)));
418+
const Node* offset = gen_primop_e(bb, offset_of_op, singleton(type_decl_ref(a, (TypeDeclRef) { .decl = global_struct_t })), singleton(size_t_literal(a, i)));
394419
// const Node* offset_in_words = bytes_to_words(bb, offset);
395-
cnst->payload.constant.instruction = yield_values_and_wrap_in_block(bb, singleton(offset));
396-
397-
register_processed(&ctx->rewriter, decl, cnst);
420+
new_address->payload.constant.instruction = yield_values_and_wrap_in_block(bb, singleton(offset));
398421

399-
members_count++;
400-
}
401-
402-
// add some dummy thing so we don't end up with a zero-sized thing, which SPIR-V hates
403-
if (members_count == 0) {
404-
member_tys[0] = int32_type(a);
405-
member_names[0] = "dummy";
406-
members_count++;
422+
register_processed(&ctx->rewriter, decl, new_address);
407423
}
408424

409425
const Type* record_t = record_type(a, (RecordType) {
410-
.members = nodes(a, members_count, member_tys),
411-
.names = strings(a, members_count, member_names)
426+
.members = nodes(a, collected.count, member_tys),
427+
.names = strings(a, collected.count, member_names)
412428
});
413429

414430
//return record_t;
415431
global_struct_t->payload.nom_type.body = record_t;
432+
return global_struct_t;
433+
}
434+
435+
static void store_init_data(Context* ctx, AddressSpace as, Nodes collected, BodyBuilder* bb) {
436+
IrArena* oa = ctx->rewriter.src_arena;
437+
IrArena* a = ctx->rewriter.dst_arena;
438+
for (size_t i = 0; i < collected.count; i++) {
439+
const Node* old_decl = collected.nodes[i];
440+
assert(old_decl->tag == GlobalVariable_TAG);
441+
const Node* old_init = old_decl->payload.global_variable.init;
442+
if (old_init) {
443+
const Node* old_store = prim_op_helper(oa, store_op, empty(oa), mk_nodes(oa, ref_decl_helper(oa, old_decl), old_init));
444+
bind_instruction(bb, rewrite_node(&ctx->rewriter, old_store));
445+
}
446+
}
416447
}
417448

418449
static void construct_emulated_memory_array(Context* ctx, AddressSpace as, AddressSpace logical_as) {
419450
IrArena* a = ctx->rewriter.dst_arena;
420451
Module* m = ctx->rewriter.dst_module;
421452
String as_name = get_address_space_name(as);
422-
Nodes annotations = singleton(annotation(a, (Annotation) { .name = "Generated" }));
423453

424-
Node* global_struct_t = nominal_type(m, annotations, format_string_arena(a->arena, "globals_physical_%s_t", as_name));
425-
//global_struct_t->payload.nom_type.body = collect_globals_into_record_type(ctx, as);
426-
collect_globals_into_record_type(ctx, global_struct_t, as);
454+
const Type* word_type = int_type(a, (Int) { .width = a->config.memory.word_size, .is_signed = false });
455+
const Type* ptr_size_type = int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false });
456+
457+
ctx->collected[as] = collect_globals(ctx, as);
458+
if (ctx->collected[as].count == 0) {
459+
const Type* words_array_type = arr_type(a, (ArrType) {
460+
.element_type = word_type,
461+
.size = NULL
462+
});
463+
*get_emulated_as_word_array(ctx, as) = undef(a, (Undef) { .type = ptr_type(a, (PtrType) { .address_space = logical_as, .pointed_type = words_array_type }) });
464+
return;
465+
}
466+
467+
const Node* global_struct_t = make_record_type(ctx, as, ctx->collected[as]);
468+
469+
Nodes annotations = singleton(annotation(a, (Annotation) { .name = "Generated" }));
427470

428471
// compute the size
429472
BodyBuilder* bb = begin_body(a);
430473
const Node* size_of = gen_primop_e(bb, size_of_op, singleton(type_decl_ref(a, (TypeDeclRef) { .decl = global_struct_t })), empty(a));
431474
const Node* size_in_words = bytes_to_words(bb, size_of);
432475

433-
const Type* word_type = int_type(a, (Int) { .width = a->config.memory.word_size, .is_signed = false });
434-
const Type* ptr_size_type = int_type(a, (Int) { .width = a->config.memory.ptr_size, .is_signed = false });
435-
436476
Node* constant_decl = constant(m, annotations, ptr_size_type, format_string_interned(a, "globals_physical_%s_size", as_name));
437477
constant_decl->payload.constant.instruction = yield_values_and_wrap_in_block(bb, singleton(size_in_words));
438478

@@ -442,7 +482,8 @@ static void construct_emulated_memory_array(Context* ctx, AddressSpace as, Addre
442482
});
443483

444484
Node* words_array = global_var(m, annotations, words_array_type, format_string_arena(a->arena, "addressable_word_memory_%s", as_name), logical_as);
445-
*get_emulated_as_word_array(ctx, as) = words_array;
485+
486+
*get_emulated_as_word_array(ctx, as) = ref_decl_helper(a, words_array);
446487
}
447488

448489
Module* lower_physical_ptrs(const CompilerConfig* config, Module* src) {

0 commit comments

Comments
 (0)