Skip to content

Commit eb37f77

Browse files
committed
lift_indirect_targets: compute the CFG where the liftee is the entry and use that to compute free variables
1 parent b5552d1 commit eb37f77

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

src/shady/passes/lift_indirect_targets.c

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ typedef struct Context_ {
2727
Rewriter rewriter;
2828
CFG* cfg;
2929
const UsesMap* uses;
30-
struct Dict* live_vars;
3130

3231
struct Dict* lifted;
3332
bool disable_lowering;
@@ -83,27 +82,32 @@ static void add_to_recover_context(struct List* recover_context, struct Dict* se
8382
}
8483
}
8584

86-
static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name) {
87-
assert(is_basic_block(cont) || is_case(cont));
88-
LiftedCont** found = find_value_dict(const Node*, LiftedCont*, ctx->lifted, cont);
85+
static LiftedCont* lambda_lift(Context* ctx, const Node* liftee, String given_name) {
86+
assert(is_basic_block(liftee) || is_case(liftee));
87+
LiftedCont** found = find_value_dict(const Node*, LiftedCont*, ctx->lifted, liftee);
8988
if (found)
9089
return *found;
9190

9291
IrArena* a = ctx->rewriter.dst_arena;
93-
Nodes oparams = get_abstraction_params(cont);
94-
const Node* obody = get_abstraction_body(cont);
92+
Nodes oparams = get_abstraction_params(liftee);
93+
const Node* obody = get_abstraction_body(liftee);
9594

96-
String name = is_basic_block(cont) ? format_string_arena(a->arena, "%s_%s", get_abstraction_name(cont->payload.basic_block.fn), get_abstraction_name(cont)) : unique_name(a, given_name);
95+
String name = is_basic_block(liftee) ? format_string_arena(a->arena, "%s_%s", get_abstraction_name(liftee->payload.basic_block.fn), get_abstraction_name(liftee)) : unique_name(a, given_name);
9796

9897
// Compute the live stuff we'll need
99-
CFNode* cf_node = cfg_lookup(ctx->cfg, cont);
100-
CFNodeVariables* node_vars = *find_value_dict(CFNode*, CFNodeVariables*, ctx->live_vars, cf_node);
98+
CFG* cfg_rooted_in_liftee = build_cfg(ctx->cfg->entry->node, liftee, NULL, false);
99+
CFNode* cf_node = cfg_lookup(cfg_rooted_in_liftee, liftee);
100+
struct Dict* live_vars = compute_cfg_variables_map(cfg_rooted_in_liftee);
101+
CFNodeVariables* node_vars = *find_value_dict(CFNode*, CFNodeVariables*, live_vars, cf_node);
101102
struct List* recover_context = new_list(const Node*);
102103

103-
add_to_recover_context(recover_context, node_vars->bound_set, cont);
104+
add_to_recover_context(recover_context, node_vars->free_set, liftee);
104105
size_t recover_context_size = entries_count_list(recover_context);
105106

106-
debugv_print("lambda_lift: free (to-be-spilled) variables at '%s' (count=%d): ", name, recover_context_size);
107+
destroy_cfg_variables_map(live_vars);
108+
destroy_cfg(cfg_rooted_in_liftee);
109+
110+
debugv_print("lambda_lift: free (to-be-spilled) variables at '%s' (count=%d): ", get_abstraction_name_safe(liftee), recover_context_size);
107111
for (size_t i = 0; i < recover_context_size; i++) {
108112
const Node* item = read_list(const Node*, recover_context)[i];
109113
debugv_print("%s %%%d", get_value_name(item) ? get_value_name(item) : "", item->id);
@@ -116,9 +120,9 @@ static LiftedCont* lambda_lift(Context* ctx, const Node* cont, String given_name
116120
Nodes new_params = recreate_variables(&ctx->rewriter, oparams);
117121

118122
LiftedCont* lifted_cont = calloc(sizeof(LiftedCont), 1);
119-
lifted_cont->old_cont = cont;
123+
lifted_cont->old_cont = liftee;
120124
lifted_cont->save_values = recover_context;
121-
insert_dict(const Node*, LiftedCont*, ctx->lifted, cont, lifted_cont);
125+
insert_dict(const Node*, LiftedCont*, ctx->lifted, liftee, lifted_cont);
122126

123127
Context lifting_ctx = *ctx;
124128
lifting_ctx.rewriter = create_children_rewriter(&ctx->rewriter);
@@ -175,15 +179,13 @@ static const Node* process_node(Context* ctx, const Node* node) {
175179
Context fn_ctx = *ctx;
176180
fn_ctx.cfg = build_fn_cfg(node);
177181
fn_ctx.uses = create_uses_map(node, (NcDeclaration | NcType));
178-
fn_ctx.live_vars = compute_cfg_variables_map(fn_ctx.cfg);
179182
fn_ctx.disable_lowering = lookup_annotation(node, "Internal");
180183
ctx = &fn_ctx;
181184

182185
Node* new = recreate_decl_header_identity(&ctx->rewriter, node);
183186
recreate_decl_body_identity(&ctx->rewriter, node, new);
184187

185188
destroy_uses_map(ctx->uses);
186-
destroy_cfg_variables_map(ctx->live_vars);
187189
destroy_cfg(ctx->cfg);
188190
return new;
189191
}
@@ -205,7 +207,7 @@ static const Node* process_node(Context* ctx, const Node* node) {
205207

206208
const Node* otail = get_let_tail(node);
207209
BodyBuilder* bb = begin_body(a);
208-
LiftedCont* lifted_tail = lambda_lift(ctx, otail, unique_name(a, format_string_arena(a->arena, "post_control_%s", get_abstraction_name(ctx->cfg->entry->node))));
210+
LiftedCont* lifted_tail = lambda_lift(ctx, otail, unique_name(a, format_string_arena(a->arena, "lifted %s", get_abstraction_name_safe(otail))));
209211
const Node* sp = add_spill_instrs(ctx, bb, lifted_tail->save_values);
210212
const Node* tail_ptr = fn_addr_helper(a, lifted_tail->lifted_fn);
211213

@@ -253,7 +255,7 @@ Module* lift_indirect_targets(const CompilerConfig* config, Module* src) {
253255
}
254256
destroy_dict(ctx.lifted);
255257
destroy_rewriter(&ctx.rewriter);
256-
log_module(DEBUGVV, config, dst);
258+
// log_module(DEBUGVV, config, dst);
257259
verify_module(config, dst);
258260
src = dst;
259261
if (oa)

0 commit comments

Comments
 (0)