Skip to content

Commit 853af24

Browse files
committed
fixed lift_indirect_targets
1 parent ab4bb8d commit 853af24

File tree

3 files changed

+127
-55
lines changed

3 files changed

+127
-55
lines changed

src/shady/passes/lift_indirect_targets.c

Lines changed: 100 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "../analysis/free_variables.h"
1616
#include "../analysis/uses.h"
1717
#include "../analysis/leak.h"
18+
#include "../analysis/verify.h"
1819

1920
#include <assert.h>
2021
#include <string.h>
@@ -31,6 +32,8 @@ typedef struct Context_ {
3132
struct Dict* lifted;
3233
bool disable_lowering;
3334
const CompilerConfig* config;
35+
36+
bool* todo;
3437
} Context;
3538

3639
static const Node* process_node(Context* ctx, const Node* node);
@@ -66,6 +69,20 @@ static const Node* add_spill_instrs(Context* ctx, BodyBuilder* builder, struct L
6669
return sp;
6770
}
6871

72+
// Appends the entries of `set` to the list `recover_context`, leaving out any
// entry that is pointer-identical to one of the parameters of the abstraction
// `except` (as returned by get_abstraction_params).
// `item` is filled through dict_iter's third out-parameter (presumably the
// dictionary key; the fourth out-parameter is passed as NULL) - confirm.
// The `NN+` lines interleaved below are diff gutter markers from the commit
// view, not part of the C code.
static void add_to_recover_context(struct List* recover_context, struct Dict* set, const Node* except) {
73+
Nodes params = get_abstraction_params(except);
74+
size_t i = 0;
75+
const Node* item;
76+
while (dict_iter(set, &i, &item, NULL)) {
77+
for (size_t j = 0; j < params.count; j++) {
78+
if (item == params.nodes[j])
// item is one of `except`'s own parameters: jump past the append below
79+
goto skip;
80+
}
81+
append_list(const Node*, recover_context, item );
// the empty statement gives the label something to attach to; reaching it
// simply moves on to the next dict_iter step
82+
skip:;
83+
}
84+
}
85+
6986
static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name) {
7087
assert(is_basic_block(cont) || is_case(cont));
7188
LiftedCont** found = find_value_dict(const Node*, LiftedCont*, ctx->lifted, cont);
@@ -82,20 +99,19 @@ static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name
8299
CFNode* cf_node = scope_lookup(ctx->scope, cont);
83100
CFNodeVariables* node_vars = *find_value_dict(CFNode*, CFNodeVariables*, ctx->scope_vars, cf_node);
84101
struct List* recover_context = new_list(const Node*);
85-
size_t recover_context_size = entries_count_dict(node_vars->free_set);
86-
87-
{
88-
debugv_print("lambda_lift: free (to-be-spilled) variables at '%s' (count=%d): ", name, entries_count_dict(node_vars->free_set));
89-
size_t i = 0;
90-
const Node* item;
91-
while (dict_iter(node_vars->free_set, &i, &item, NULL)) {
92-
append_list(const Node*, recover_context, item );
93-
debugv_print(get_value_name_safe(item));
94-
if (i + 1 < recover_context_size)
95-
debugv_print(", ");
96-
}
97-
debugv_print("\n");
102+
103+
// add_to_recover_context(recover_context, node_vars->free_set, cont);
104+
add_to_recover_context(recover_context, node_vars->bound_set, cont);
105+
size_t recover_context_size = entries_count_list(recover_context);
106+
107+
debugv_print("lambda_lift: free (to-be-spilled) variables at '%s' (count=%d): ", name, recover_context_size);
108+
for (size_t i = 0; i < recover_context_size; i++) {
109+
const Node* item = read_list(const Node*, recover_context)[i];
110+
debugv_print(get_value_name_safe(item));
111+
if (i + 1 < recover_context_size)
112+
debugv_print(", ");
98113
}
114+
debugv_print("\n");
99115

100116
// Create and register new parameters for the lifted continuation
101117
Nodes new_params = recreate_variables(&ctx->rewriter, oparams);
@@ -106,7 +122,13 @@ static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name
106122
insert_dict(const Node*, LiftedCont*, ctx->lifted, cont, lifted_cont);
107123

108124
Context lifting_ctx = *ctx;
109-
lifting_ctx.rewriter = create_rewriter(ctx->rewriter.src_module, ctx->rewriter.dst_module, (RewriteNodeFn) process_node);
125+
// struct Dict* old_map = lifting_ctx.rewriter.map;
126+
// lifting_ctx.rewriter.map = clone_dict(lifting_ctx.rewriter.map);
127+
128+
// lifting_ctx.rewriter = create_rewriter(ctx->rewriter.src_module, ctx->rewriter.dst_module, (RewriteNodeFn) process_node);
129+
// lifting_ctx.rewriter.decls_map = NULL;
130+
lifting_ctx.rewriter.map = new_dict(const Node*, Node*, (HashFn) hash_node, (CmpFn) compare_node);
131+
lifting_ctx.rewriter.parent = &ctx->rewriter;
110132
register_processed_list(&lifting_ctx.rewriter, oparams, new_params);
111133

112134
const Node* payload = var(a, qualified_type_helper(uint32_type(a), false), "sp");
@@ -140,6 +162,7 @@ static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name
140162
const Node* substituted = rewrite_node(&lifting_ctx.rewriter, obody);
141163
//destroy_dict(lifting_ctx.rewriter.processed);
142164
destroy_rewriter(&lifting_ctx.rewriter);
165+
// lifting_ctx.rewriter.map = old_map;
143166

144167
assert(is_terminator(substituted));
145168
new_fn->payload.fun.body = finish_body(bb, substituted);
@@ -151,27 +174,18 @@ static const Node* process_node(Context* ctx, const Node* node) {
151174
const Node* found = search_processed(&ctx->rewriter, node);
152175
if (found) return found;
153176

154-
// TODO: share this code
155-
if (is_declaration(node)) {
156-
String name = get_declaration_name(node);
157-
Nodes decls = get_module_declarations(ctx->rewriter.dst_module);
158-
for (size_t i = 0; i < decls.count; i++) {
159-
if (strcmp(get_declaration_name(decls.nodes[i]), name) == 0)
160-
return decls.nodes[i];
161-
}
162-
}
163-
164177
IrArena* a = ctx->rewriter.dst_arena;
165178

166-
if (ctx->disable_lowering)
167-
return recreate_node_identity(&ctx->rewriter, node);
168-
169-
switch (node->tag) {
179+
switch (is_declaration(node)) {
170180
case Function_TAG: {
181+
while (ctx->rewriter.parent)
182+
ctx = (Context*) ctx->rewriter.parent;
183+
171184
Context fn_ctx = *ctx;
172185
fn_ctx.scope = new_scope(node);
173186
fn_ctx.scope_uses = create_uses_map(node, (NcDeclaration | NcType));
174187
fn_ctx.scope_vars = compute_scope_variables_map(fn_ctx.scope);
188+
fn_ctx.disable_lowering = lookup_annotation(node, "Internal");
175189
ctx = &fn_ctx;
176190

177191
Node* new = recreate_decl_header_identity(&ctx->rewriter, node);
@@ -182,51 +196,88 @@ static const Node* process_node(Context* ctx, const Node* node) {
182196
destroy_scope(ctx->scope);
183197
return new;
184198
}
199+
default:
200+
break;
201+
}
202+
203+
if (ctx->disable_lowering)
204+
return recreate_node_identity(&ctx->rewriter, node);
205+
206+
switch (node->tag) {
185207
case Let_TAG: {
186208
const Node* oinstruction = get_let_instruction(node);
187209
if (oinstruction->tag == Control_TAG) {
188210
const Node* oinside = oinstruction->payload.control.inside;
189211
assert(is_case(oinside));
190212
if (!is_control_static(ctx->scope_uses, oinstruction) || ctx->config->hacks.force_join_point_lifting) {
213+
*ctx->todo = true;
214+
191215
const Node* otail = get_let_tail(node);
192216
BodyBuilder* bb = begin_body(a);
193217
LiftedCont* lifted_tail = lambda_lift(ctx, otail, unique_name(a, format_string_arena(a->arena, "post_control_%s", get_abstraction_name(ctx->scope->entry->node))));
194218
const Node* sp = add_spill_instrs(ctx, bb, lifted_tail->save_values);
195219
const Node* tail_ptr = fn_addr_helper(a, lifted_tail->lifted_fn);
196220

197221
const Node* jp = gen_primop_e(bb, create_joint_point_op, rewrite_nodes(&ctx->rewriter, oinstruction->payload.control.yield_types), mk_nodes(a, tail_ptr, sp));
222+
// dumbass hack
223+
jp = gen_primop_e(bb, subgroup_assume_uniform_op, empty(a), singleton(jp));
198224

199225
return finish_body(bb, let(a, quote_helper(a, singleton(jp)), rewrite_node(&ctx->rewriter, oinside)));
200226
}
201227
}
202-
203-
return recreate_node_identity(&ctx->rewriter, node);
228+
break;
204229
}
205-
default: return recreate_node_identity(&ctx->rewriter, node);
230+
default: break;
206231
}
232+
return recreate_node_identity(&ctx->rewriter, node);
207233
}
208234

209235
// Runs the lifting rewrite (process_node) over `src` repeatedly, each round into
// a freshly created arena/module, until a round finishes without `todo` being
// set. Then enables aconfig.optimisations.weaken_non_leaking_allocas and imports
// the last round's module into one more arena (`a2`), whose module is returned.
// Diff view: lines under `NNN+` markers are added by this commit, lines under
// `NNN-` markers are the removed single-pass implementation.
Module* lift_indirect_targets(const CompilerConfig* config, Module* src) {
210236
ArenaConfig aconfig = get_arena_config(get_module_arena(src));
237+
IrArena* a = NULL;
238+
Module* dst;
239+
240+
int round = 0;
241+
while (true) {
242+
debugv_print("lift_indirect_target: round %d\n", round++);
// keep the previous round's arena (NULL on the first round): `src` still points
// into it while this round runs, so it is only destroyed at the end of the round
243+
IrArena* oa = a;
244+
a = new_ir_arena(aconfig);
245+
dst = new_module(a, get_module_name(src));
// process_node writes true through ctx->todo when it lifts a control
// (see `*ctx->todo = true` in process_node)
246+
bool todo = false;
247+
Context ctx = {
248+
.rewriter = create_rewriter(src, dst, (RewriteNodeFn) process_node),
249+
.lifted = new_dict(const Node*, LiftedCont*, (HashFn) hash_node, (CmpFn) compare_node),
250+
.config = config,
251+
252+
.todo = &todo
253+
};
254+
255+
rewrite_module(&ctx.rewriter);
256+
257+
size_t iter = 0;
258+
LiftedCont* lifted_cont;
// release every LiftedCont stored as a value in ctx.lifted, then the dict itself
259+
while (dict_iter(ctx.lifted, &iter, NULL, &lifted_cont)) {
260+
destroy_list(lifted_cont->save_values);
261+
free(lifted_cont);
262+
}
263+
destroy_dict(ctx.lifted);
264+
destroy_rewriter(&ctx.rewriter);
265+
log_module(DEBUGVV, config, dst);
266+
verify_module(config, dst);
// this round's output becomes the input of the next round (or of the final import)
267+
src = dst;
268+
if (oa)
269+
destroy_ir_arena(oa);
// nothing was lifted during this round: stop iterating
270+
if (!todo) {
271+
break;
272+
}
273+
}
274+
211275
// this will be safe now since we won't lift any more code after this pass
212276
aconfig.optimisations.weaken_non_leaking_allocas = true;
// removed by this commit (markers 213- to 230-): the former single-pass body
213-
IrArena* a = new_ir_arena(aconfig);
214-
Module* dst = new_module(a, get_module_name(src));
215-
Context ctx = {
216-
.rewriter = create_rewriter(src, dst, (RewriteNodeFn) process_node),
217-
.lifted = new_dict(const Node*, LiftedCont*, (HashFn) hash_node, (CmpFn) compare_node),
218-
.config = config,
219-
};
220-
221-
rewrite_module(&ctx.rewriter);
222-
223-
size_t iter = 0;
224-
LiftedCont* lifted_cont;
225-
while (dict_iter(ctx.lifted, &iter, NULL, &lifted_cont)) {
226-
destroy_list(lifted_cont->save_values);
227-
free(lifted_cont);
228-
}
229-
destroy_dict(ctx.lifted);
230-
destroy_rewriter(&ctx.rewriter);
// added: final import of the last round's module into an arena created with the
// updated aconfig
277+
IrArena* a2 = new_ir_arena(aconfig);
278+
dst = new_module(a2, get_module_name(src));
// NOTE(review): `r` is never passed to destroy_rewriter in this function, unlike
// ctx.rewriter above - confirm whether the importer's maps are leaked here
279+
Rewriter r = create_importer(src, dst);
280+
rewrite_module(&r);
// `a` is the arena of the last loop round, i.e. the one `src` points into
281+
destroy_ir_arena(a);
231282
return dst;
232283
}

src/shady/rewrite.c

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ Rewriter create_rewriter(Module* src, Module* dst, RewriteNodeFn fn) {
3333
// Frees the dictionaries held by rewriter `r`. `r->map` is asserted non-NULL and
// always destroyed.
// Diff view: the line under marker `36-` is the removed, unconditional destroy of
// decls_map; the lines under `36+`/`37+` replace it, so decls_map is destroyed
// only when `r->parent` is NULL.
// NOTE(review): presumably a rewriter with a parent shares its parent's decls_map
// (lambda_lift copies the whole Context and replaces only `rewriter.map`) - confirm.
void destroy_rewriter(Rewriter* r) {
3434
assert(r->map);
3535
destroy_dict(r->map);
36-
destroy_dict(r->decls_map);
36+
if (!r->parent)
37+
destroy_dict(r->decls_map);
3738
}
3839

3940
Rewriter create_importer(Module* src, Module* dst) {
@@ -137,11 +138,27 @@ static Nodes rewrite_ops_helper(Rewriter* rewriter, NodeClass class, String op_n
137138
return rewrite_nodes_with_fn(rewriter, old_nodes, rewriter->rewrite_fn);
138139
}
139140

141+
// Looks up the node previously registered as the replacement for `old`.
// Returns the stored node, or NULL when nothing is found.
//  - Declarations are looked up only in `ctx->decls_map` of the rewriter passed
//    in; the parent chain is not walked for them, whatever `deep` is.
//  - Other nodes are looked up in `ctx->map`; when `deep` is true the search
//    continues through `ctx->parent` until a hit or a NULL parent.
// NOTE(review): `decls_map` is not asserted non-NULL here, whereas the removed
// body of search_processed asserted whichever map it selected - confirm intended.
// The `NNN+` lines interleaved below are diff gutter markers, not C code.
static const Node* search_processed_(const Rewriter* ctx, const Node* old, bool deep) {
142+
if (is_declaration(old)) {
143+
const Node** found = find_value_dict(const Node*, const Node*, ctx->decls_map, old);
144+
return found ? *found : NULL;
145+
}
146+
147+
while (ctx) {
148+
assert(ctx->map && "this rewriter has no processed cache");
149+
const Node** found = find_value_dict(const Node*, const Node*, ctx->map, old);
150+
if (found)
151+
return *found;
// deep search: retry in the parent rewriter; shallow search: end the loop
152+
if (deep)
153+
ctx = ctx->parent;
154+
else
155+
ctx = NULL;
156+
}
157+
return NULL;
158+
}
159+
140160
// Returns the node registered as the replacement for `old` in rewriter `ctx`, or
// NULL when there is none.
// Diff view: the four lines under the `141-`..`144-` markers are the removed
// body; the line under `161+` is the new body, which delegates to
// search_processed_ with deep == false (so parent rewriters are not consulted).
const Node* search_processed(const Rewriter* ctx, const Node* old) {
141-
struct Dict* map = is_declaration(old) ? ctx->decls_map : ctx->map;
142-
assert(map && "this rewriter has no processed cache");
143-
const Node** found = find_value_dict(const Node*, const Node*, map, old);
144-
return found ? *found : NULL;
161+
return search_processed_(ctx, old, false);
145162
}
146163

147164
const Node* find_processed(const Rewriter* ctx, const Node* old) {
@@ -154,7 +171,7 @@ void register_processed(Rewriter* ctx, const Node* old, const Node* new) {
154171
assert(old->arena == ctx->src_arena);
155172
assert(new->arena == ctx->dst_arena);
156173
#ifndef NDEBUG
157-
const Node* found = search_processed(ctx, old);
174+
const Node* found = search_processed_(ctx, old, false);
158175
if (found) {
159176
error_print("Trying to replace ");
160177
log_node(ERROR, old);
@@ -190,6 +207,7 @@ bool compare_node(Node**, Node**);
190207
#include "rewrite_generated.c"
191208

192209
void rewrite_module(Rewriter* rewriter) {
210+
assert(rewriter->dst_module != rewriter->src_module);
193211
Nodes old_decls = get_module_declarations(rewriter->src_module);
194212
for (size_t i = 0; i < old_decls.count; i++) {
195213
if (old_decls.nodes[i]->tag == NominalType_TAG) continue;

src/shady/rewrite.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ struct Rewriter_ {
3737
bool fold_quote;
3838
bool process_variables;
3939
} config;
40+
41+
Rewriter* parent;
42+
4043
struct Dict* map;
4144
struct Dict* decls_map;
4245
};

0 commit comments

Comments
 (0)