Skip to content

Commit 0d763be

Browse files
committed
rework infer pass
1 parent 76602a2 commit 0d763be

File tree

2 files changed

+88
-75
lines changed

2 files changed

+88
-75
lines changed

src/shady/passes/infer.c

Lines changed: 85 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ typedef struct {
3636
const Nodes* continue_types;
3737
} Context;
3838

39+
static const Node* infer_value(Context* ctx, const Node* node, const Type* expected_type);
40+
static const Node* infer_instruction(Context* ctx, const Node* node, const Nodes* expected_types);
41+
3942
static const Node* infer(Context* ctx, const Node* node, const Type* expect) {
4043
Context ctx2 = *ctx;
4144
ctx2.expected_type = expect;
@@ -51,7 +54,7 @@ static Nodes infer_nodes(Context* ctx, Nodes nodes) {
5154
#define rewrite_node error("don't use this directly, use the 'infer' and 'infer_node' helpers")
5255
#define rewrite_nodes rewrite_node
5356

54-
static const Node* _infer_annotation(Context* ctx, const Node* node) {
57+
static const Node* infer_annotation(Context* ctx, const Node* node) {
5558
IrArena* a = ctx->rewriter.dst_arena;
5659
assert(is_annotation(node));
5760
switch (node->tag) {
@@ -63,7 +66,7 @@ static const Node* _infer_annotation(Context* ctx, const Node* node) {
6366
}
6467
}
6568

66-
static const Node* _infer_type(Context* ctx, const Type* type) {
69+
static const Node* infer_type(Context* ctx, const Type* type) {
6770
IrArena* a = ctx->rewriter.dst_arena;
6871
switch (type->tag) {
6972
case ArrType_TAG: {
@@ -83,7 +86,7 @@ static const Node* _infer_type(Context* ctx, const Type* type) {
8386
}
8487
}
8588

86-
static const Node* _infer_decl(Context* ctx, const Node* node) {
89+
static const Node* infer_decl(Context* ctx, const Node* node) {
8790
assert(is_declaration(node));
8891
const Node* already_done = search_processed(&ctx->rewriter, node);
8992
if (already_done)
@@ -117,9 +120,10 @@ static const Node* _infer_decl(Context* ctx, const Node* node) {
117120
const Node* instruction;
118121
if (imported_hint) {
119122
assert(is_data_type(imported_hint));
120-
instruction = infer(ctx, oconstant->instruction, qualified_type_helper(imported_hint, true));
123+
Nodes s = singleton(qualified_type_helper(imported_hint, true));
124+
instruction = infer_instruction(ctx, oconstant->instruction, &s);
121125
} else {
122-
instruction = infer(ctx, oconstant->instruction, NULL);
126+
instruction = infer_instruction(ctx, oconstant->instruction, NULL);
123127
}
124128
imported_hint = get_unqualified_type(instruction->type);
125129

@@ -156,13 +160,9 @@ static const Type* remove_uniformity_qualifier(const Node* type) {
156160
return type;
157161
}
158162

159-
static const Node* _infer_value(Context* ctx, const Node* node, const Type* expected_type) {
163+
static const Node* infer_value(Context* ctx, const Node* node, const Type* expected_type) {
160164
if (!node) return NULL;
161165

162-
if (expected_type) {
163-
assert(is_value_type(expected_type));
164-
}
165-
166166
IrArena* a = ctx->rewriter.dst_arena;
167167
Rewriter* r = &ctx->rewriter;
168168
switch (is_value(node)) {
@@ -302,11 +302,9 @@ static const Node* _infer_value(Context* ctx, const Node* node, const Type* expe
302302
return recreate_node_identity(&ctx->rewriter, node);
303303
}
304304

305-
static const Node* _infer_case(Context* ctx, const Node* node, const Node* expected) {
305+
static const Node* infer_case(Context* ctx, const Node* node, Nodes inferred_arg_type) {
306306
IrArena* a = ctx->rewriter.dst_arena;
307307
assert(is_case(node));
308-
assert(expected);
309-
Nodes inferred_arg_type = unwrap_multiple_yield_types(a, expected);
310308
assert(inferred_arg_type.count == node->payload.case_.params.count || node->payload.case_.params.count == 0);
311309

312310
Context body_context = *ctx;
@@ -319,7 +317,7 @@ static const Node* _infer_case(Context* ctx, const Node* node, const Node* expec
319317
const Variable* old_param = &node->payload.case_.params.nodes[i]->payload.var;
320318
// for the param type: use the inferred one if none is already provided
321319
// if one is provided, check the inferred argument type is a subtype of the param type
322-
const Type* param_type = infer(ctx, old_param->type, NULL);
320+
const Type* param_type = old_param->type ? infer_type(ctx, old_param->type) : NULL;
323321
// and do not use the provided param type if it is an untyped ptr
324322
if (!param_type || param_type->tag != PtrType_TAG || param_type->payload.ptr_type.pointed_type)
325323
param_type = inferred_arg_type.nodes[i];
@@ -409,7 +407,7 @@ static void fix_source_pointer(BodyBuilder* bb, const Node** operand, const Type
409407
}
410408
}
411409

412-
static const Node* _infer_primop(Context* ctx, const Node* node, const Type* expected_type) {
410+
static const Node* infer_primop(Context* ctx, const Node* node, const Nodes* expected_types) {
413411
assert(node->tag == PrimOp_TAG);
414412
IrArena* a = ctx->rewriter.dst_arena;
415413

@@ -586,7 +584,7 @@ static const Node* _infer_primop(Context* ctx, const Node* node, const Type* exp
586584
}
587585
}
588586

589-
static const Node* _infer_indirect_call(Context* ctx, const Node* node, const Type* expected_type) {
587+
static const Node* infer_indirect_call(Context* ctx, const Node* node, const Nodes* expected_types) {
590588
assert(node->tag == Call_TAG);
591589
IrArena* a = ctx->rewriter.dst_arena;
592590

@@ -616,21 +614,21 @@ static const Node* _infer_indirect_call(Context* ctx, const Node* node, const Ty
616614
});
617615
}
618616

619-
static const Node* _infer_if(Context* ctx, const Node* node, const Type* expected_type) {
617+
static const Node* infer_if(Context* ctx, const Node* node, const Nodes* expected_types) {
620618
assert(node->tag == If_TAG);
621619
IrArena* a = ctx->rewriter.dst_arena;
622-
const Node* condition = infer(ctx, node->payload.if_instr.condition, bool_type(a));
620+
const Node* condition = infer(ctx, node->payload.if_instr.condition, qualified_type_helper(bool_type(a), false));
623621

624622
Nodes join_types = infer_nodes(ctx, node->payload.if_instr.yield_types);
625623
Context infer_if_body_ctx = *ctx;
626624
// When we infer the types of the arguments to a call to merge(), they are expected to be varying
627-
Nodes expected_join_types = annotate_all_types(a, join_types, false);
625+
Nodes expected_join_types = add_qualifiers(a, join_types, false);
628626
infer_if_body_ctx.merge_types = &expected_join_types;
629627

630-
const Node* true_body = infer(&infer_if_body_ctx, node->payload.if_instr.if_true, wrap_multiple_yield_types(a, nodes(a, 0, NULL)));
628+
const Node* true_body = infer_case(&infer_if_body_ctx, node->payload.if_instr.if_true, nodes(a, 0, NULL));
631629
// don't allow seeing the variables made available in the true branch
632630
infer_if_body_ctx.rewriter = ctx->rewriter;
633-
const Node* false_body = node->payload.if_instr.if_false ? infer(&infer_if_body_ctx, node->payload.if_instr.if_false, wrap_multiple_yield_types(a, nodes(a, 0, NULL))) : NULL;
631+
const Node* false_body = node->payload.if_instr.if_false ? infer_case(&infer_if_body_ctx, node->payload.if_instr.if_false, nodes(a, 0, NULL)) : NULL;
634632

635633
return if_instr(a, (If) {
636634
.yield_types = join_types,
@@ -640,7 +638,7 @@ static const Node* _infer_if(Context* ctx, const Node* node, const Type* expecte
640638
});
641639
}
642640

643-
static const Node* _infer_loop(Context* ctx, const Node* node, const Type* expected_type) {
641+
static const Node* infer_loop(Context* ctx, const Node* node, const Nodes* expected_types) {
644642
assert(node->tag == Loop_TAG);
645643
IrArena* a = ctx->rewriter.dst_arena;
646644
Context loop_body_ctx = *ctx;
@@ -649,19 +647,21 @@ static const Node* _infer_loop(Context* ctx, const Node* node, const Type* expec
649647
Nodes old_params = get_abstraction_params(old_body);
650648
Nodes old_params_types = get_variables_types(a, old_params);
651649
Nodes new_params_types = infer_nodes(ctx, old_params_types);
650+
new_params_types = annotate_all_types(a, new_params_types, false);
652651

653652
Nodes old_initial_args = node->payload.loop_instr.initial_args;
654653
LARRAY(const Node*, new_initial_args, old_params.count);
655654
for (size_t i = 0; i < old_params.count; i++)
656655
new_initial_args[i] = infer(ctx, old_initial_args.nodes[i], new_params_types.nodes[i]);
657656

658657
Nodes loop_yield_types = infer_nodes(ctx, node->payload.loop_instr.yield_types);
658+
Nodes qual_yield_types = add_qualifiers(a, loop_yield_types, false);
659659

660660
loop_body_ctx.merge_types = NULL;
661-
loop_body_ctx.break_types = &loop_yield_types;
661+
loop_body_ctx.break_types = &qual_yield_types;
662662
loop_body_ctx.continue_types = &new_params_types;
663663

664-
const Node* nbody = infer(&loop_body_ctx, old_body, wrap_multiple_yield_types(a, new_params_types));
664+
const Node* nbody = infer_case(&loop_body_ctx, old_body, new_params_types);
665665
// TODO check new body params match continue types
666666

667667
return loop_instr(a, (Loop) {
@@ -671,7 +671,7 @@ static const Node* _infer_loop(Context* ctx, const Node* node, const Type* expec
671671
});
672672
}
673673

674-
static const Node* _infer_control(Context* ctx, const Node* node, const Type* expected_type) {
674+
static const Node* infer_control(Context* ctx, const Node* node, const Nodes* expected_types) {
675675
assert(node->tag == Control_TAG);
676676
IrArena* a = ctx->rewriter.dst_arena;
677677

@@ -696,7 +696,7 @@ static const Node* _infer_control(Context* ctx, const Node* node, const Type* ex
696696
});
697697
}
698698

699-
static const Node* _infer_block(Context* ctx, const Node* node, const Type* expected_type) {
699+
static const Node* infer_block(Context* ctx, const Node* node, const Nodes* expected_types) {
700700
assert(node->tag == Block_TAG);
701701
IrArena* a = ctx->rewriter.dst_arena;
702702

@@ -712,35 +712,35 @@ static const Node* _infer_block(Context* ctx, const Node* node, const Type* expe
712712
});
713713
}
714714

715-
static const Node* _infer_instruction(Context* ctx, const Node* node, const Type* expected_type) {
715+
static const Node* infer_instruction(Context* ctx, const Node* node, const Nodes* expected_types) {
716716
switch (is_instruction(node)) {
717-
case PrimOp_TAG: return _infer_primop(ctx, node, expected_type);
718-
case Call_TAG: return _infer_indirect_call(ctx, node, expected_type);
719-
case If_TAG: return _infer_if (ctx, node, expected_type);
720-
case Loop_TAG: return _infer_loop (ctx, node, expected_type);
717+
case PrimOp_TAG: return infer_primop(ctx, node, expected_types);
718+
case Call_TAG: return infer_indirect_call(ctx, node, expected_types);
719+
case If_TAG: return infer_if (ctx, node, expected_types);
720+
case Loop_TAG: return infer_loop (ctx, node, expected_types);
721721
case Match_TAG: error("TODO")
722-
case Control_TAG: return _infer_control(ctx, node, expected_type);
723-
case Block_TAG: return _infer_block (ctx, node, expected_type);
722+
case Control_TAG: return infer_control(ctx, node, expected_types);
723+
case Block_TAG: return infer_block (ctx, node, expected_types);
724724
case Instruction_Comment_TAG: return recreate_node_identity(&ctx->rewriter, node);
725725
default: error("TODO")
726726
case NotAnInstruction: error("not an instruction");
727727
}
728728
SHADY_UNREACHABLE;
729729
}
730730

731-
static const Node* _infer_terminator(Context* ctx, const Node* node) {
731+
static const Node* infer_terminator(Context* ctx, const Node* node) {
732732
IrArena* a = ctx->rewriter.dst_arena;
733733
switch (is_terminator(node)) {
734734
case NotATerminator: assert(false);
735735
case Let_TAG: {
736736
const Node* otail = node->payload.let.tail;
737737
Nodes annotated_types = get_variables_types(a, otail->payload.case_.params);
738-
const Node* inferred_instruction = infer(ctx, node->payload.let.instruction, wrap_multiple_yield_types(a, annotated_types));
738+
const Node* inferred_instruction = infer_instruction(ctx, node->payload.let.instruction, &annotated_types);
739739
Nodes inferred_yield_types = unwrap_multiple_yield_types(a, inferred_instruction->type);
740740
for (size_t i = 0; i < inferred_yield_types.count; i++) {
741741
assert(is_value_type(inferred_yield_types.nodes[i]));
742742
}
743-
const Node* inferred_tail = infer(ctx, otail, wrap_multiple_yield_types(a, inferred_yield_types));
743+
const Node* inferred_tail = infer_case(ctx, otail, inferred_yield_types);
744744
return let(a, inferred_instruction, inferred_tail);
745745
}
746746
case Return_TAG: {
@@ -776,40 +776,49 @@ static const Node* _infer_terminator(Context* ctx, const Node* node) {
776776
case Terminator_Switch_TAG: break;
777777
case Terminator_TailCall_TAG: break;
778778
case Terminator_Yield_TAG: {
779-
const Nodes* expected_types = ctx->merge_types;
780779
// TODO: block nodes should set merge types
781-
assert(expected_types && "Merge terminator found but we're not within a suitable if instruction !");
782-
const Nodes* old_args = &node->payload.yield.args;
783-
assert(expected_types->count == old_args->count);
784-
LARRAY(const Node*, new_args, old_args->count);
785-
for (size_t i = 0; i < old_args->count; i++)
786-
new_args[i] = infer(ctx, old_args->nodes[i], (*expected_types).nodes[i]);
780+
assert(ctx->merge_types && "Merge terminator found but we're not within a suitable if instruction !");
781+
Nodes expected_types = *ctx->merge_types;
782+
Nodes old_args = node->payload.yield.args;
783+
assert(expected_types.count == old_args.count);
784+
LARRAY(const Node*, new_args, old_args.count);
785+
for (size_t i = 0; i < old_args.count; i++) {
786+
const Node* e = expected_types.nodes[i];
787+
assert(is_value_type(e));
788+
new_args[i] = infer(ctx, old_args.nodes[i], e);
789+
}
787790
return yield(a, (Yield) {
788-
.args = nodes(a, old_args->count, new_args)
791+
.args = nodes(a, old_args.count, new_args)
789792
});
790793
}
791794
case MergeContinue_TAG: {
792-
const Nodes* expected_types = ctx->continue_types;
793-
assert(expected_types && "Merge terminator found but we're not within a suitable loop instruction !");
794-
const Nodes* old_args = &node->payload.merge_continue.args;
795-
assert(expected_types->count == old_args->count);
796-
LARRAY(const Node*, new_args, old_args->count);
797-
for (size_t i = 0; i < old_args->count; i++)
798-
new_args[i] = infer(ctx, old_args->nodes[i], (*expected_types).nodes[i]);
795+
assert(ctx->continue_types && "Merge terminator found but we're not within a suitable loop instruction !");
796+
Nodes expected_types = *ctx->continue_types;
797+
Nodes old_args = node->payload.yield.args;
798+
assert(expected_types.count == old_args.count);
799+
LARRAY(const Node*, new_args, old_args.count);
800+
for (size_t i = 0; i < old_args.count; i++) {
801+
const Node* e = expected_types.nodes[i];
802+
assert(is_value_type(e));
803+
new_args[i] = infer(ctx, old_args.nodes[i], e);
804+
}
799805
return merge_continue(a, (MergeContinue) {
800-
.args = nodes(a, old_args->count, new_args)
806+
.args = nodes(a, old_args.count, new_args)
801807
});
802808
}
803809
case MergeBreak_TAG: {
804-
const Nodes* expected_types = ctx->break_types;
805-
assert(expected_types && "Merge terminator found but we're not within a suitable loop instruction !");
806-
const Nodes* old_args = &node->payload.merge_break.args;
807-
assert(expected_types->count == old_args->count);
808-
LARRAY(const Node*, new_args, old_args->count);
809-
for (size_t i = 0; i < old_args->count; i++)
810-
new_args[i] = infer(ctx, old_args->nodes[i], (*expected_types).nodes[i]);
810+
assert(ctx->break_types && "Merge terminator found but we're not within a suitable loop instruction !");
811+
Nodes expected_types = *ctx->break_types;
812+
Nodes old_args = node->payload.yield.args;
813+
assert(expected_types.count == old_args.count);
814+
LARRAY(const Node*, new_args, old_args.count);
815+
for (size_t i = 0; i < old_args.count; i++) {
816+
const Node* e = expected_types.nodes[i];
817+
assert(is_value_type(e));
818+
new_args[i] = infer(ctx, old_args.nodes[i], e);
819+
}
811820
return merge_break(a, (MergeBreak) {
812-
.args = nodes(a, old_args->count, new_args)
821+
.args = nodes(a, old_args.count, new_args)
813822
});
814823
}
815824
case Unreachable_TAG: return unreachable(a);
@@ -822,7 +831,7 @@ static const Node* _infer_terminator(Context* ctx, const Node* node) {
822831
}
823832

824833
static const Node* process(Context* src_ctx, const Node* node) {
825-
const Type* expect = src_ctx->expected_type;
834+
const Node* expected_type = src_ctx->expected_type;
826835
Context ctx = *src_ctx;
827836
ctx.expected_type = NULL;
828837

@@ -834,25 +843,27 @@ static const Node* process(Context* src_ctx, const Node* node) {
834843
}
835844

836845
if (is_type(node)) {
837-
assert(expect == NULL);
838-
return _infer_type(&ctx, node);
846+
assert(expected_type == NULL);
847+
return infer_type(&ctx, node);
839848
} else if (is_value(node)) {
840-
const Node* value = _infer_value(&ctx, node, expect);
849+
const Node* value = infer_value(&ctx, node, expected_type);
841850
assert(is_value_type(value->type));
842851
return value;
843-
}else if (is_instruction(node))
844-
return _infer_instruction(&ctx, node, expect);
845-
else if (is_terminator(node)) {
846-
assert(expect == NULL);
847-
return _infer_terminator(&ctx, node);
852+
} else if (is_instruction(node)) {
853+
assert(false);
854+
//return infer_instruction(&ctx, node, expected_type);
855+
} else if (is_terminator(node)) {
856+
assert(expected_type == NULL);
857+
return infer_terminator(&ctx, node);
848858
} else if (is_declaration(node)) {
849-
return _infer_decl(&ctx, node);
859+
return infer_decl(&ctx, node);
850860
} else if (is_annotation(node)) {
851-
assert(expect == NULL);
852-
return _infer_annotation(&ctx, node);
861+
assert(expected_type == NULL);
862+
return infer_annotation(&ctx, node);
853863
} else if (is_case(node)) {
854-
assert(expect != NULL);
855-
return _infer_case(&ctx, node, expect);
864+
assert(false);
865+
//assert(expected_types != NULL);
866+
//return infer_case(&ctx, node, expected_types);
856867
} else if (is_basic_block(node)) {
857868
return _infer_basic_block(&ctx, node);
858869
}

src/shady/type_helpers.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ Nodes unwrap_multiple_yield_types(IrArena* arena, const Type* type) {
2424
if (type->payload.record_type.special == MultipleReturn)
2525
return type->payload.record_type.members;
2626
// fallthrough
27-
default: return nodes(arena, 1, (const Node* []) { type });
27+
default:
28+
assert(is_value_type(type));
29+
return nodes(arena, 1, (const Node* []) { type });
2830
}
2931
}
3032

0 commit comments

Comments
 (0)