15
15
#include "../analysis/free_variables.h"
16
16
#include "../analysis/uses.h"
17
17
#include "../analysis/leak.h"
18
+ #include "../analysis/verify.h"
18
19
19
20
#include <assert.h>
20
21
#include <string.h>
@@ -31,6 +32,8 @@ typedef struct Context_ {
31
32
struct Dict * lifted ;
32
33
bool disable_lowering ;
33
34
const CompilerConfig * config ;
35
+
36
+ bool * todo ;
34
37
} Context ;
35
38
36
39
static const Node * process_node (Context * ctx , const Node * node );
@@ -66,6 +69,20 @@ static const Node* add_spill_instrs(Context* ctx, BodyBuilder* builder, struct L
66
69
return sp ;
67
70
}
68
71
72
+ static void add_to_recover_context (struct List * recover_context , struct Dict * set , const Node * except ) {
73
+ Nodes params = get_abstraction_params (except );
74
+ size_t i = 0 ;
75
+ const Node * item ;
76
+ while (dict_iter (set , & i , & item , NULL )) {
77
+ for (size_t j = 0 ; j < params .count ; j ++ ) {
78
+ if (item == params .nodes [j ])
79
+ goto skip ;
80
+ }
81
+ append_list (const Node * , recover_context , item );
82
+ skip :;
83
+ }
84
+ }
85
+
69
86
static LiftedCont * lambda_lift (Context * ctx , const Node * cont , String given_name ) {
70
87
assert (is_basic_block (cont ) || is_case (cont ));
71
88
LiftedCont * * found = find_value_dict (const Node * , LiftedCont * , ctx -> lifted , cont );
@@ -82,20 +99,19 @@ static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name
82
99
CFNode * cf_node = scope_lookup (ctx -> scope , cont );
83
100
CFNodeVariables * node_vars = * find_value_dict (CFNode * , CFNodeVariables * , ctx -> scope_vars , cf_node );
84
101
struct List * recover_context = new_list (const Node * );
85
- size_t recover_context_size = entries_count_dict (node_vars -> free_set );
86
-
87
- {
88
- debugv_print ("lambda_lift: free (to-be-spilled) variables at '%s' (count=%d): " , name , entries_count_dict (node_vars -> free_set ));
89
- size_t i = 0 ;
90
- const Node * item ;
91
- while (dict_iter (node_vars -> free_set , & i , & item , NULL )) {
92
- append_list (const Node * , recover_context , item );
93
- debugv_print (get_value_name_safe (item ));
94
- if (i + 1 < recover_context_size )
95
- debugv_print (", " );
96
- }
97
- debugv_print ("\n" );
102
+
103
+ // add_to_recover_context(recover_context, node_vars->free_set, cont);
104
+ add_to_recover_context (recover_context , node_vars -> bound_set , cont );
105
+ size_t recover_context_size = entries_count_list (recover_context );
106
+
107
+ debugv_print ("lambda_lift: free (to-be-spilled) variables at '%s' (count=%d): " , name , recover_context_size );
108
+ for (size_t i = 0 ; i < recover_context_size ; i ++ ) {
109
+ const Node * item = read_list (const Node * , recover_context )[i ];
110
+ debugv_print (get_value_name_safe (item ));
111
+ if (i + 1 < recover_context_size )
112
+ debugv_print (", " );
98
113
}
114
+ debugv_print ("\n" );
99
115
100
116
// Create and register new parameters for the lifted continuation
101
117
Nodes new_params = recreate_variables (& ctx -> rewriter , oparams );
@@ -106,7 +122,13 @@ static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name
106
122
insert_dict (const Node * , LiftedCont * , ctx -> lifted , cont , lifted_cont );
107
123
108
124
Context lifting_ctx = * ctx ;
109
- lifting_ctx .rewriter = create_rewriter (ctx -> rewriter .src_module , ctx -> rewriter .dst_module , (RewriteNodeFn ) process_node );
125
+ // struct Dict* old_map = lifting_ctx.rewriter.map;
126
+ // lifting_ctx.rewriter.map = clone_dict(lifting_ctx.rewriter.map);
127
+
128
+ // lifting_ctx.rewriter = create_rewriter(ctx->rewriter.src_module, ctx->rewriter.dst_module, (RewriteNodeFn) process_node);
129
+ // lifting_ctx.rewriter.decls_map = NULL;
130
+ lifting_ctx .rewriter .map = new_dict (const Node * , Node * , (HashFn ) hash_node , (CmpFn ) compare_node );
131
+ lifting_ctx .rewriter .parent = & ctx -> rewriter ;
110
132
register_processed_list (& lifting_ctx .rewriter , oparams , new_params );
111
133
112
134
const Node * payload = var (a , qualified_type_helper (uint32_type (a ), false), "sp" );
@@ -140,6 +162,7 @@ static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name
140
162
const Node * substituted = rewrite_node (& lifting_ctx .rewriter , obody );
141
163
//destroy_dict(lifting_ctx.rewriter.processed);
142
164
destroy_rewriter (& lifting_ctx .rewriter );
165
+ // lifting_ctx.rewriter.map = old_map;
143
166
144
167
assert (is_terminator (substituted ));
145
168
new_fn -> payload .fun .body = finish_body (bb , substituted );
@@ -151,27 +174,18 @@ static const Node* process_node(Context* ctx, const Node* node) {
151
174
const Node * found = search_processed (& ctx -> rewriter , node );
152
175
if (found ) return found ;
153
176
154
- // TODO: share this code
155
- if (is_declaration (node )) {
156
- String name = get_declaration_name (node );
157
- Nodes decls = get_module_declarations (ctx -> rewriter .dst_module );
158
- for (size_t i = 0 ; i < decls .count ; i ++ ) {
159
- if (strcmp (get_declaration_name (decls .nodes [i ]), name ) == 0 )
160
- return decls .nodes [i ];
161
- }
162
- }
163
-
164
177
IrArena * a = ctx -> rewriter .dst_arena ;
165
178
166
- if (ctx -> disable_lowering )
167
- return recreate_node_identity (& ctx -> rewriter , node );
168
-
169
- switch (node -> tag ) {
179
+ switch (is_declaration (node )) {
170
180
case Function_TAG : {
181
+ while (ctx -> rewriter .parent )
182
+ ctx = (Context * ) ctx -> rewriter .parent ;
183
+
171
184
Context fn_ctx = * ctx ;
172
185
fn_ctx .scope = new_scope (node );
173
186
fn_ctx .scope_uses = create_uses_map (node , (NcDeclaration | NcType ));
174
187
fn_ctx .scope_vars = compute_scope_variables_map (fn_ctx .scope );
188
+ fn_ctx .disable_lowering = lookup_annotation (node , "Internal" );
175
189
ctx = & fn_ctx ;
176
190
177
191
Node * new = recreate_decl_header_identity (& ctx -> rewriter , node );
@@ -182,51 +196,88 @@ static const Node* process_node(Context* ctx, const Node* node) {
182
196
destroy_scope (ctx -> scope );
183
197
return new ;
184
198
}
199
+ default :
200
+ break ;
201
+ }
202
+
203
+ if (ctx -> disable_lowering )
204
+ return recreate_node_identity (& ctx -> rewriter , node );
205
+
206
+ switch (node -> tag ) {
185
207
case Let_TAG : {
186
208
const Node * oinstruction = get_let_instruction (node );
187
209
if (oinstruction -> tag == Control_TAG ) {
188
210
const Node * oinside = oinstruction -> payload .control .inside ;
189
211
assert (is_case (oinside ));
190
212
if (!is_control_static (ctx -> scope_uses , oinstruction ) || ctx -> config -> hacks .force_join_point_lifting ) {
213
+ * ctx -> todo = true;
214
+
191
215
const Node * otail = get_let_tail (node );
192
216
BodyBuilder * bb = begin_body (a );
193
217
LiftedCont * lifted_tail = lambda_lift (ctx , otail , unique_name (a , format_string_arena (a -> arena , "post_control_%s" , get_abstraction_name (ctx -> scope -> entry -> node ))));
194
218
const Node * sp = add_spill_instrs (ctx , bb , lifted_tail -> save_values );
195
219
const Node * tail_ptr = fn_addr_helper (a , lifted_tail -> lifted_fn );
196
220
197
221
const Node * jp = gen_primop_e (bb , create_joint_point_op , rewrite_nodes (& ctx -> rewriter , oinstruction -> payload .control .yield_types ), mk_nodes (a , tail_ptr , sp ));
222
+ // dumbass hack
223
+ jp = gen_primop_e (bb , subgroup_assume_uniform_op , empty (a ), singleton (jp ));
198
224
199
225
return finish_body (bb , let (a , quote_helper (a , singleton (jp )), rewrite_node (& ctx -> rewriter , oinside )));
200
226
}
201
227
}
202
-
203
- return recreate_node_identity (& ctx -> rewriter , node );
228
+ break ;
204
229
}
205
- default : return recreate_node_identity ( & ctx -> rewriter , node ) ;
230
+ default : break ;
206
231
}
232
+ return recreate_node_identity (& ctx -> rewriter , node );
207
233
}
208
234
209
235
/**
 * Runs the lifting rewrite over `src` repeatedly until a round performs no
 * lifting, then imports the result into a final arena.
 *
 * Each round rewrites the current module into a freshly created arena/module
 * using `process_node`. `process_node` sets the shared `todo` flag whenever it
 * lifts the tail of a control construct; a round that leaves the flag clear
 * ends the loop. After every round the produced module is logged and verified.
 *
 * The arena of the module passed in by the caller is never destroyed here;
 * only the intermediate arenas created by this function are.
 *
 * @param config compiler configuration forwarded to the rewrite, logging and verification
 * @param src    module to transform
 * @return the module created in the final arena
 */
Module* lift_indirect_targets(const CompilerConfig* config, Module* src) {
    ArenaConfig aconfig = get_arena_config(get_module_arena(src));
    IrArena* a = NULL;
    Module* dst;

    int round = 0;
    while (true) {
        debugv_print("lift_indirect_target: round %d\n", round++);
        // Remember the arena backing the current `src` (NULL on the first round,
        // where `src` still belongs to the caller).
        IrArena* oa = a;
        a = new_ir_arena(aconfig);
        dst = new_module(a, get_module_name(src));
        bool todo = false;
        Context ctx = {
            .rewriter = create_rewriter(src, dst, (RewriteNodeFn) process_node),
            .lifted = new_dict(const Node*, LiftedCont*, (HashFn) hash_node, (CmpFn) compare_node),
            .config = config,

            .todo = &todo
        };

        rewrite_module(&ctx.rewriter);

        // Release the per-round bookkeeping for lifted continuations.
        size_t iter = 0;
        LiftedCont* lifted_cont;
        while (dict_iter(ctx.lifted, &iter, NULL, &lifted_cont)) {
            destroy_list(lifted_cont->save_values);
            free(lifted_cont);
        }
        destroy_dict(ctx.lifted);
        destroy_rewriter(&ctx.rewriter);
        log_module(DEBUGVV, config, dst);
        verify_module(config, dst);

        // The freshly produced module becomes the input of the next round, so the
        // arena that held the previous intermediate module can now be dropped.
        src = dst;
        if (oa)
            destroy_ir_arena(oa);
        if (!todo) {
            break;
        }
    }

    // this will be safe now since we won't lift any more code after this pass
    aconfig.optimisations.weaken_non_leaking_allocas = true;
    IrArena* a2 = new_ir_arena(aconfig);
    dst = new_module(a2, get_module_name(src));
    Rewriter r = create_importer(src, dst);
    rewrite_module(&r);
    // Release the importer's rewriter state, as is done for every other rewriter
    // in this file; previously it was leaked.
    destroy_rewriter(&r);
    // `a` backs the last intermediate module, which has now been imported into `a2`.
    destroy_ir_arena(a);
    return dst;
}
0 commit comments