14
14
#include "list.h"
15
15
#include "dict.h"
16
16
17
+ #include <string.h>
17
18
#include <assert.h>
18
19
19
20
typedef struct Context_ {
20
21
Rewriter rewriter ;
21
22
const CompilerConfig * config ;
22
23
24
+ Nodes collected [NumAddressSpaces ];
25
+
23
26
struct Dict * serialisation_uniform [NumAddressSpaces ];
24
27
struct Dict * deserialisation_uniform [NumAddressSpaces ];
25
28
@@ -31,6 +34,8 @@ typedef struct Context_ {
31
34
const Node * fake_shared_memory ;
32
35
} Context ;
33
36
37
+ static void store_init_data (Context * ctx , AddressSpace as , Nodes collected , BodyBuilder * bb );
38
+
34
39
// TODO: make this configuration-dependent
35
40
static bool is_as_emulated (SHADY_UNUSED Context * ctx , AddressSpace as ) {
36
41
switch (as ) {
@@ -274,7 +279,7 @@ static const Node* gen_serdes_fn(Context* ctx, const Type* element_type, bool un
274
279
275
280
BodyBuilder * bb = begin_body (a );
276
281
const Node * address = bytes_to_words (bb , address_param );
277
- const Node * base = ref_decl_helper ( a , * get_emulated_as_word_array (ctx , as ) );
282
+ const Node * base = * get_emulated_as_word_array (ctx , as );
278
283
if (ser ) {
279
284
gen_serialisation (ctx , bb , element_type , base , address , value_param );
280
285
fun -> payload .fun .body = finish_body (bb , fn_ret (a , (Return ) { .fn = fun , .args = empty (a ) }));
@@ -286,56 +291,46 @@ static const Node* gen_serdes_fn(Context* ctx, const Type* element_type, bool un
286
291
return fun ;
287
292
}
288
293
289
/// Rewrites a Let node (every line of this function carries the diff's '-' marker).
/// The tail is always rewritten. If the bound instruction is a load_op/store_op
/// PrimOp whose pointer operand lives in an emulated address space, the instruction
/// is replaced by a call to the function returned by gen_serdes_fn; any other
/// instruction is rewritten as-is and re-bound with let().
/// alloca_op / alloca_subgroup_op hit error(): they are expected to be lowered beforehand.
- static const Node * process_let (Context * ctx , const Node * node ) {
290
- assert (node -> tag == Let_TAG );
291
- IrArena * a = ctx -> rewriter .dst_arena ;
292
-
293
- const Node * tail = rewrite_node (& ctx -> rewriter , node -> payload .let .tail );
294
- const Node * old_instruction = node -> payload .let .instruction ;
295
-
296
- if (old_instruction -> tag == PrimOp_TAG ) {
297
- const PrimOp * oprim_op = & old_instruction -> payload .prim_op ;
298
- switch (oprim_op -> op ) {
299
- case alloca_subgroup_op :
300
- case alloca_op : error ("This needs to be lowered (see setup_stack_frames.c)" )
301
- // lowering for either kind of memory accesses is similar
302
- case load_op :
303
- case store_op : {
304
- const Node * old_ptr = oprim_op -> operands .nodes [0 ];
305
- const Type * ptr_type = old_ptr -> type ;
// strips the qualifier from ptr_type in place; the returned flag is forwarded to gen_serdes_fn
306
- bool uniform_ptr = deconstruct_qualified_type (& ptr_type );
307
- assert (ptr_type -> tag == PtrType_TAG );
// non-emulated address spaces leave the switch and take the generic let() at the bottom
308
- if (!is_as_emulated (ctx , ptr_type -> payload .ptr_type .address_space ))
309
- break ;
310
- BodyBuilder * bb = begin_body (a );
311
-
312
- const Type * element_type = rewrite_node (& ctx -> rewriter , ptr_type -> payload .ptr_type .pointed_type );
// the rewritten pointer is passed as the first argument of the generated function
313
- const Node * pointer_as_offset = rewrite_node (& ctx -> rewriter , old_ptr );
// the fourth argument selects the store variant (true) or the load variant (false)
314
- const Node * fn = gen_serdes_fn (ctx , element_type , uniform_ptr , oprim_op -> op == store_op , ptr_type -> payload .ptr_type .address_space );
315
-
316
- if (oprim_op -> op == load_op ) {
317
- return finish_body (bb , let (a , call (a , (Call ) {.callee = fn_addr_helper (a , fn ), .args = singleton (pointer_as_offset )}), tail ));
318
- } else {
// stores additionally pass the rewritten value operand (operand 1)
319
- const Node * value = rewrite_node (& ctx -> rewriter , oprim_op -> operands .nodes [1 ]);
320
- return finish_body (bb , let (a , call (a , (Call ) { .callee = fn_addr_helper (a , fn ), .args = mk_nodes (a , pointer_as_offset , value ) }), tail ));
321
- }
322
- SHADY_UNREACHABLE ;
323
- }
324
- default : break ;
325
- }
326
- }
327
-
// fallback: rewrite the instruction unchanged and bind it to the rewritten tail
328
- return let (a , rewrite_node (& ctx -> rewriter , old_instruction ), tail );
329
- }
330
-
331
294
static const Node * process_node (Context * ctx , const Node * old ) {
332
295
const Node * found = search_processed (& ctx -> rewriter , old );
333
296
if (found ) return found ;
334
297
335
298
IrArena * a = ctx -> rewriter .dst_arena ;
336
299
337
300
switch (old -> tag ) {
338
- case Let_TAG : return process_let (ctx , old );
301
+ case PrimOp_TAG : {
302
+ const PrimOp * oprim_op = & old -> payload .prim_op ;
303
+ switch (oprim_op -> op ) {
304
+ case alloca_subgroup_op :
305
+ case alloca_op : error ("This needs to be lowered (see setup_stack_frames.c)" )
306
+ // lowering for either kind of memory accesses is similar
307
+ case load_op :
308
+ case store_op : {
309
+ const Node * old_ptr = oprim_op -> operands .nodes [0 ];
310
+ const Type * ptr_type = old_ptr -> type ;
311
+ bool uniform_ptr = deconstruct_qualified_type (& ptr_type );
312
+ assert (ptr_type -> tag == PtrType_TAG );
313
+ if (!is_as_emulated (ctx , ptr_type -> payload .ptr_type .address_space ))
314
+ break ;
315
+ BodyBuilder * bb = begin_body (a );
316
+
317
+ const Type * element_type = rewrite_node (& ctx -> rewriter , ptr_type -> payload .ptr_type .pointed_type );
318
+ const Node * pointer_as_offset = rewrite_node (& ctx -> rewriter , old_ptr );
319
+ const Node * fn = gen_serdes_fn (ctx , element_type , uniform_ptr , oprim_op -> op == store_op , ptr_type -> payload .ptr_type .address_space );
320
+
321
+ if (oprim_op -> op == load_op ) {
322
+ Nodes r = bind_instruction (bb , call (a , (Call ) {.callee = fn_addr_helper (a , fn ), .args = singleton (pointer_as_offset )}));
323
+ return yield_values_and_wrap_in_block (bb , r );
324
+ } else {
325
+ const Node * value = rewrite_node (& ctx -> rewriter , oprim_op -> operands .nodes [1 ]);
326
+ bind_instruction (bb , call (a , (Call ) { .callee = fn_addr_helper (a , fn ), .args = mk_nodes (a , pointer_as_offset , value ) }));
327
+ return yield_values_and_wrap_in_block (bb , empty (a ));
328
+ }
329
+ }
330
+ default : break ;
331
+ }
332
+ break ;
333
+ }
339
334
case PtrType_TAG : {
340
335
if (is_as_emulated (ctx , old -> payload .ptr_type .address_space ))
341
336
return int_type (a , (Int ) { .width = a -> config .memory .ptr_size , .is_signed = false });
@@ -354,6 +349,20 @@ static const Node* process_node(Context* ctx, const Node* old) {
354
349
}
355
350
break ;
356
351
}
352
+ case Function_TAG : {
353
+ if (strcmp (get_abstraction_name (old ), "generated_init" ) == 0 ) {
354
+ Node * new = recreate_decl_header_identity (& ctx -> rewriter , old );
355
+ BodyBuilder * bb = begin_body (a );
356
+
357
+ for (AddressSpace as = 0 ; as < NumAddressSpaces ; as ++ ) {
358
+ if (is_as_emulated (ctx , as ))
359
+ store_init_data (ctx , as , ctx -> collected [as ], bb );
360
+ }
361
+ new -> payload .fun .body = finish_body (bb , rewrite_node (& ctx -> rewriter , old -> payload .fun .body ));
362
+ return new ;
363
+ }
364
+ break ;
365
+ }
357
366
default : break ;
358
367
}
359
368
@@ -363,76 +372,107 @@ static const Node* process_node(Context* ctx, const Node* old) {
363
372
KeyHash hash_node (Node * * );
364
373
bool compare_node (Node * * , Node * * );
365
374
366
- /// Collects all global variables in a specific AS, and creates a record type for them.
367
- static void collect_globals_into_record_type (Context * ctx , Node * global_struct_t , AddressSpace as ) {
375
/// Returns the GlobalVariable declarations of the source module whose address space
/// equals `as`, in module declaration order. The returned nodes are the *old*
/// (source-module) declarations, packed into a Nodes list by nodes() on the destination arena.
+ static Nodes collect_globals (Context * ctx , AddressSpace as ) {
368
376
IrArena * a = ctx -> rewriter .dst_arena ;
369
- Module * m = ctx -> rewriter .dst_module ;
370
377
Nodes old_decls = get_module_declarations (ctx -> rewriter .src_module );
371
-
372
- LARRAY (String , member_names , old_decls .count );
373
- LARRAY (const Type * , member_tys , old_decls .count );
// scratch array sized for the worst case (every declaration matches); it holds declarations
// even though its element type is spelled `const Type*`
378
+ LARRAY (const Type * , collected , old_decls .count );
374
379
size_t members_count = 0 ;
375
380
376
381
for (size_t i = 0 ; i < old_decls .count ; i ++ ) {
377
382
const Node * decl = old_decls .nodes [i ];
// keep only global variables that live in the requested address space
378
383
if (decl -> tag != GlobalVariable_TAG ) continue ;
379
384
if (decl -> payload .global_variable .address_space != as ) continue ;
385
+ collected [members_count ] = decl ;
386
+ members_count ++ ;
387
+ }
388
+
389
+ return nodes (a , members_count , collected );
390
+ }
391
+
392
+ /// Builds the nominal type "globals_physical_<as>_t" (annotated "Generated") whose body is a record with one member per global variable in `collected`, and returns it.
/// Each member takes the rewritten type and the name of its global. For every global this
/// also creates a constant with the same name and rewritten annotations, typed as an unsigned
/// integer of ptr_size width, whose instruction yields offset_of(<record type>, <member index>);
/// that constant is registered with the rewriter as the replacement of the old declaration.
/// `collected` is expected to hold old GlobalVariable declarations (it is read through
/// payload.global_variable without a tag check here).
393
+ static const Node * make_record_type (Context * ctx , AddressSpace as , Nodes collected ) {
394
+ IrArena * a = ctx -> rewriter .dst_arena ;
395
+ Module * m = ctx -> rewriter .dst_module ;
396
+
397
+ String as_name = get_address_space_name (as );
// the nominal type is created first, so the offset_of constants below can refer to it before its body is set
398
+ Node * global_struct_t = nominal_type (m , singleton (annotation (a , (Annotation ) { .name = "Generated" })), format_string_arena (a -> arena , "globals_physical_%s_t" , as_name ));
399
+
400
+ LARRAY (String , member_names , collected .count );
401
+ LARRAY (const Type * , member_tys , collected .count );
402
+
403
+ for (size_t i = 0 ; i < collected .count ; i ++ ) {
404
+ const Node * decl = collected .nodes [i ];
380
405
const Type * type = decl -> payload .global_variable .type ;
381
406
382
- member_tys [members_count ] = rewrite_node (& ctx -> rewriter , type );
383
- member_names [members_count ] = decl -> payload .global_variable .name ;
407
+ member_tys [i ] = rewrite_node (& ctx -> rewriter , type );
408
+ member_names [i ] = decl -> payload .global_variable .name ;
384
409
385
410
// Turn the old global variable into a pointer (which are also now integers)
386
411
const Type * emulated_ptr_type = int_type (a , (Int ) { .width = a -> config .memory .ptr_size , .is_signed = false });
387
412
Nodes annotations = rewrite_nodes (& ctx -> rewriter , decl -> payload .global_variable .annotations );
388
- Node * cnst = constant (ctx -> rewriter .dst_module , annotations , emulated_ptr_type , decl -> payload .global_variable .name );
413
+ Node * new_address = constant (ctx -> rewriter .dst_module , annotations , emulated_ptr_type , decl -> payload .global_variable .name );
389
414
390
415
// we need to compute the actual pointer by getting the offset and dividing it
391
416
// after lower_memory_layout, optimisations will eliminate this and resolve to a value
// NOTE(review): the division mentioned above is not performed here (the bytes_to_words call below is commented out); the constant yields the raw offset_of result
392
417
BodyBuilder * bb = begin_body (a );
393
- const Node * offset = gen_primop_e (bb , offset_of_op , singleton (type_decl_ref (a , (TypeDeclRef ) { .decl = global_struct_t })), singleton (size_t_literal (a , members_count )));
418
+ const Node * offset = gen_primop_e (bb , offset_of_op , singleton (type_decl_ref (a , (TypeDeclRef ) { .decl = global_struct_t })), singleton (size_t_literal (a , i )));
394
419
// const Node* offset_in_words = bytes_to_words(bb, offset);
395
- cnst -> payload .constant .instruction = yield_values_and_wrap_in_block (bb , singleton (offset ));
396
-
397
- register_processed (& ctx -> rewriter , decl , cnst );
420
+ new_address -> payload .constant .instruction = yield_values_and_wrap_in_block (bb , singleton (offset ));
398
421
399
- members_count ++ ;
400
- }
401
-
402
- // add some dummy thing so we don't end up with a zero-sized thing, which SPIR-V hates
403
- if (members_count == 0 ) {
404
- member_tys [0 ] = int32_type (a );
405
- member_names [0 ] = "dummy" ;
406
- members_count ++ ;
// from here on, the rewriter maps the old global to its offset constant
422
+ register_processed (& ctx -> rewriter , decl , new_address );
407
423
}
408
424
409
425
const Type * record_t = record_type (a , (RecordType ) {
410
- .members = nodes (a , members_count , member_tys ),
411
- .names = strings (a , members_count , member_names )
426
+ .members = nodes (a , collected . count , member_tys ),
427
+ .names = strings (a , collected . count , member_names )
412
428
});
413
429
414
430
//return record_t;
// the record body is attached last, completing the nominal type created at the top
415
431
global_struct_t -> payload .nom_type .body = record_t ;
432
+ return global_struct_t ;
433
+ }
434
+
435
/// For every global variable in `collected` that has an initializer, appends to `bb`
/// a store of that initializer into the global.
/// The store is first built in the *source* arena against the old declaration and then
/// passed through rewrite_node, presumably so it gets the same load/store lowering as
/// ordinary program stores (TODO confirm against the rewriter setup).
/// NOTE(review): the `as` parameter and the local `a` are not used in this body.
+ static void store_init_data (Context * ctx , AddressSpace as , Nodes collected , BodyBuilder * bb ) {
436
+ IrArena * oa = ctx -> rewriter .src_arena ;
437
+ IrArena * a = ctx -> rewriter .dst_arena ;
438
+ for (size_t i = 0 ; i < collected .count ; i ++ ) {
439
+ const Node * old_decl = collected .nodes [i ];
440
+ assert (old_decl -> tag == GlobalVariable_TAG );
441
+ const Node * old_init = old_decl -> payload .global_variable .init ;
// globals without an initializer produce no store
442
+ if (old_init ) {
// old-arena store: operands are a reference to the old global and its old initializer
443
+ const Node * old_store = prim_op_helper (oa , store_op , empty (oa ), mk_nodes (oa , ref_decl_helper (oa , old_decl ), old_init ));
444
+ bind_instruction (bb , rewrite_node (& ctx -> rewriter , old_store ));
445
+ }
446
+ }
416
447
}
417
448
418
449
static void construct_emulated_memory_array (Context * ctx , AddressSpace as , AddressSpace logical_as ) {
419
450
IrArena * a = ctx -> rewriter .dst_arena ;
420
451
Module * m = ctx -> rewriter .dst_module ;
421
452
String as_name = get_address_space_name (as );
422
- Nodes annotations = singleton (annotation (a , (Annotation ) { .name = "Generated" }));
423
453
424
- Node * global_struct_t = nominal_type (m , annotations , format_string_arena (a -> arena , "globals_physical_%s_t" , as_name ));
425
- //global_struct_t->payload.nom_type.body = collect_globals_into_record_type(ctx, as);
426
- collect_globals_into_record_type (ctx , global_struct_t , as );
454
+ const Type * word_type = int_type (a , (Int ) { .width = a -> config .memory .word_size , .is_signed = false });
455
+ const Type * ptr_size_type = int_type (a , (Int ) { .width = a -> config .memory .ptr_size , .is_signed = false });
456
+
457
+ ctx -> collected [as ] = collect_globals (ctx , as );
458
+ if (ctx -> collected [as ].count == 0 ) {
459
+ const Type * words_array_type = arr_type (a , (ArrType ) {
460
+ .element_type = word_type ,
461
+ .size = NULL
462
+ });
463
+ * get_emulated_as_word_array (ctx , as ) = undef (a , (Undef ) { .type = ptr_type (a , (PtrType ) { .address_space = logical_as , .pointed_type = words_array_type }) });
464
+ return ;
465
+ }
466
+
467
+ const Node * global_struct_t = make_record_type (ctx , as , ctx -> collected [as ]);
468
+
469
+ Nodes annotations = singleton (annotation (a , (Annotation ) { .name = "Generated" }));
427
470
428
471
// compute the size
429
472
BodyBuilder * bb = begin_body (a );
430
473
const Node * size_of = gen_primop_e (bb , size_of_op , singleton (type_decl_ref (a , (TypeDeclRef ) { .decl = global_struct_t })), empty (a ));
431
474
const Node * size_in_words = bytes_to_words (bb , size_of );
432
475
433
- const Type * word_type = int_type (a , (Int ) { .width = a -> config .memory .word_size , .is_signed = false });
434
- const Type * ptr_size_type = int_type (a , (Int ) { .width = a -> config .memory .ptr_size , .is_signed = false });
435
-
436
476
Node * constant_decl = constant (m , annotations , ptr_size_type , format_string_interned (a , "globals_physical_%s_size" , as_name ));
437
477
constant_decl -> payload .constant .instruction = yield_values_and_wrap_in_block (bb , singleton (size_in_words ));
438
478
@@ -442,7 +482,8 @@ static void construct_emulated_memory_array(Context* ctx, AddressSpace as, Addre
442
482
});
443
483
444
484
Node * words_array = global_var (m , annotations , words_array_type , format_string_arena (a -> arena , "addressable_word_memory_%s" , as_name ), logical_as );
445
- * get_emulated_as_word_array (ctx , as ) = words_array ;
485
+
486
+ * get_emulated_as_word_array (ctx , as ) = ref_decl_helper (a , words_array );
446
487
}
447
488
448
489
Module * lower_physical_ptrs (const CompilerConfig * config , Module * src ) {
0 commit comments