@@ -36,6 +36,9 @@ typedef struct {
     const Nodes* continue_types;
 } Context;

+static const Node* infer_value(Context* ctx, const Node* node, const Type* expected_type);
+static const Node* infer_instruction(Context* ctx, const Node* node, const Nodes* expected_types);
+
 static const Node* infer(Context* ctx, const Node* node, const Type* expect) {
     Context ctx2 = *ctx;
     ctx2.expected_type = expect;
@@ -51,7 +54,7 @@ static Nodes infer_nodes(Context* ctx, Nodes nodes) {
 #define rewrite_node error("don't use this directly, use the 'infer' and 'infer_node' helpers")
 #define rewrite_nodes rewrite_node

-static const Node* _infer_annotation(Context* ctx, const Node* node) {
+static const Node* infer_annotation(Context* ctx, const Node* node) {
     IrArena* a = ctx->rewriter.dst_arena;
     assert(is_annotation(node));
     switch (node->tag) {
@@ -63,7 +66,7 @@ static const Node* _infer_annotation(Context* ctx, const Node* node) {
     }
 }

-static const Node* _infer_type(Context* ctx, const Type* type) {
+static const Node* infer_type(Context* ctx, const Type* type) {
     IrArena* a = ctx->rewriter.dst_arena;
     switch (type->tag) {
         case ArrType_TAG: {
@@ -83,7 +86,7 @@ static const Node* _infer_type(Context* ctx, const Type* type) {
     }
 }

-static const Node* _infer_decl(Context* ctx, const Node* node) {
+static const Node* infer_decl(Context* ctx, const Node* node) {
     assert(is_declaration(node));
     const Node* already_done = search_processed(&ctx->rewriter, node);
     if (already_done)
@@ -117,9 +120,10 @@ static const Node* _infer_decl(Context* ctx, const Node* node) {
             const Node* instruction;
             if (imported_hint) {
                 assert(is_data_type(imported_hint));
-                instruction = infer(ctx, oconstant->instruction, qualified_type_helper(imported_hint, true));
+                Nodes s = singleton(qualified_type_helper(imported_hint, true));
+                instruction = infer_instruction(ctx, oconstant->instruction, &s);
             } else {
-                instruction = infer(ctx, oconstant->instruction, NULL);
+                instruction = infer_instruction(ctx, oconstant->instruction, NULL);
             }
             imported_hint = get_unqualified_type(instruction->type);

@@ -156,13 +160,9 @@ static const Type* remove_uniformity_qualifier(const Node* type) {
     return type;
 }

-static const Node* _infer_value(Context* ctx, const Node* node, const Type* expected_type) {
+static const Node* infer_value(Context* ctx, const Node* node, const Type* expected_type) {
     if (!node) return NULL;

-    if (expected_type) {
-        assert(is_value_type(expected_type));
-    }
-
     IrArena* a = ctx->rewriter.dst_arena;
     Rewriter* r = &ctx->rewriter;
     switch (is_value(node)) {
@@ -302,11 +302,9 @@ static const Node* _infer_value(Context* ctx, const Node* node, const Type* expected_type) {
     return recreate_node_identity(&ctx->rewriter, node);
 }

-static const Node* _infer_case(Context* ctx, const Node* node, const Node* expected) {
+static const Node* infer_case(Context* ctx, const Node* node, Nodes inferred_arg_type) {
     IrArena* a = ctx->rewriter.dst_arena;
     assert(is_case(node));
-    assert(expected);
-    Nodes inferred_arg_type = unwrap_multiple_yield_types(a, expected);
     assert(inferred_arg_type.count == node->payload.case_.params.count || node->payload.case_.params.count == 0);

     Context body_context = *ctx;
@@ -319,7 +317,7 @@ static const Node* _infer_case(Context* ctx, const Node* node, const Node* expected) {
         const Variable* old_param = &node->payload.case_.params.nodes[i]->payload.var;
         // for the param type: use the inferred one if none is already provided
         // if one is provided, check the inferred argument type is a subtype of the param type
-        const Type* param_type = infer(ctx, old_param->type, NULL);
+        const Type* param_type = old_param->type ? infer_type(ctx, old_param->type) : NULL;
         // and do not use the provided param type if it is an untyped ptr
         if (!param_type || param_type->tag != PtrType_TAG || param_type->payload.ptr_type.pointed_type)
             param_type = inferred_arg_type.nodes[i];
@@ -409,7 +407,7 @@ static void fix_source_pointer(BodyBuilder* bb, const Node** operand, const Type
     }
 }

-static const Node* _infer_primop(Context* ctx, const Node* node, const Type* expected_type) {
+static const Node* infer_primop(Context* ctx, const Node* node, const Nodes* expected_types) {
     assert(node->tag == PrimOp_TAG);
     IrArena* a = ctx->rewriter.dst_arena;

@@ -586,7 +584,7 @@ static const Node* _infer_primop(Context* ctx, const Node* node, const Type* expected_type) {
     }
 }

-static const Node* _infer_indirect_call(Context* ctx, const Node* node, const Type* expected_type) {
+static const Node* infer_indirect_call(Context* ctx, const Node* node, const Nodes* expected_types) {
     assert(node->tag == Call_TAG);
     IrArena* a = ctx->rewriter.dst_arena;

@@ -616,21 +614,21 @@ static const Node* _infer_indirect_call(Context* ctx, const Node* node, const Type* expected_type) {
     });
 }

-static const Node* _infer_if(Context* ctx, const Node* node, const Type* expected_type) {
+static const Node* infer_if(Context* ctx, const Node* node, const Nodes* expected_types) {
     assert(node->tag == If_TAG);
     IrArena* a = ctx->rewriter.dst_arena;
-    const Node* condition = infer(ctx, node->payload.if_instr.condition, bool_type(a));
+    const Node* condition = infer(ctx, node->payload.if_instr.condition, qualified_type_helper(bool_type(a), false));

     Nodes join_types = infer_nodes(ctx, node->payload.if_instr.yield_types);
     Context infer_if_body_ctx = *ctx;
     // When we infer the types of the arguments to a call to merge(), they are expected to be varying
-    Nodes expected_join_types = annotate_all_types(a, join_types, false);
+    Nodes expected_join_types = add_qualifiers(a, join_types, false);
     infer_if_body_ctx.merge_types = &expected_join_types;

-    const Node* true_body = infer(&infer_if_body_ctx, node->payload.if_instr.if_true, wrap_multiple_yield_types(a, nodes(a, 0, NULL)));
+    const Node* true_body = infer_case(&infer_if_body_ctx, node->payload.if_instr.if_true, nodes(a, 0, NULL));
     // don't allow seeing the variables made available in the true branch
     infer_if_body_ctx.rewriter = ctx->rewriter;
-    const Node* false_body = node->payload.if_instr.if_false ? infer(&infer_if_body_ctx, node->payload.if_instr.if_false, wrap_multiple_yield_types(a, nodes(a, 0, NULL))) : NULL;
+    const Node* false_body = node->payload.if_instr.if_false ? infer_case(&infer_if_body_ctx, node->payload.if_instr.if_false, nodes(a, 0, NULL)) : NULL;

     return if_instr(a, (If) {
         .yield_types = join_types,
@@ -640,7 +638,7 @@ static const Node* _infer_if(Context* ctx, const Node* node, const Type* expected_type) {
     });
 }

-static const Node* _infer_loop(Context* ctx, const Node* node, const Type* expected_type) {
+static const Node* infer_loop(Context* ctx, const Node* node, const Nodes* expected_types) {
     assert(node->tag == Loop_TAG);
     IrArena* a = ctx->rewriter.dst_arena;
     Context loop_body_ctx = *ctx;
@@ -649,19 +647,21 @@ static const Node* _infer_loop(Context* ctx, const Node* node, const Type* expected_type) {
     Nodes old_params = get_abstraction_params(old_body);
     Nodes old_params_types = get_variables_types(a, old_params);
     Nodes new_params_types = infer_nodes(ctx, old_params_types);
+    new_params_types = annotate_all_types(a, new_params_types, false);

     Nodes old_initial_args = node->payload.loop_instr.initial_args;
     LARRAY(const Node*, new_initial_args, old_params.count);
     for (size_t i = 0; i < old_params.count; i++)
         new_initial_args[i] = infer(ctx, old_initial_args.nodes[i], new_params_types.nodes[i]);

     Nodes loop_yield_types = infer_nodes(ctx, node->payload.loop_instr.yield_types);
+    Nodes qual_yield_types = add_qualifiers(a, loop_yield_types, false);

     loop_body_ctx.merge_types = NULL;
-    loop_body_ctx.break_types = &loop_yield_types;
+    loop_body_ctx.break_types = &qual_yield_types;
     loop_body_ctx.continue_types = &new_params_types;

-    const Node* nbody = infer(&loop_body_ctx, old_body, wrap_multiple_yield_types(a, new_params_types));
+    const Node* nbody = infer_case(&loop_body_ctx, old_body, new_params_types);
     // TODO check new body params match continue types

     return loop_instr(a, (Loop) {
@@ -671,7 +671,7 @@ static const Node* _infer_loop(Context* ctx, const Node* node, const Type* expected_type) {
     });
 }

-static const Node* _infer_control(Context* ctx, const Node* node, const Type* expected_type) {
+static const Node* infer_control(Context* ctx, const Node* node, const Nodes* expected_types) {
     assert(node->tag == Control_TAG);
     IrArena* a = ctx->rewriter.dst_arena;

@@ -696,7 +696,7 @@ static const Node* _infer_control(Context* ctx, const Node* node, const Type* expected_type) {
     });
 }

-static const Node* _infer_block(Context* ctx, const Node* node, const Type* expected_type) {
+static const Node* infer_block(Context* ctx, const Node* node, const Nodes* expected_types) {
     assert(node->tag == Block_TAG);
     IrArena* a = ctx->rewriter.dst_arena;

@@ -712,35 +712,35 @@ static const Node* _infer_block(Context* ctx, const Node* node, const Type* expected_type) {
     });
 }

-static const Node* _infer_instruction(Context* ctx, const Node* node, const Type* expected_type) {
+static const Node* infer_instruction(Context* ctx, const Node* node, const Nodes* expected_types) {
     switch (is_instruction(node)) {
-        case PrimOp_TAG: return _infer_primop(ctx, node, expected_type);
-        case Call_TAG: return _infer_indirect_call(ctx, node, expected_type);
-        case If_TAG: return _infer_if(ctx, node, expected_type);
-        case Loop_TAG: return _infer_loop(ctx, node, expected_type);
+        case PrimOp_TAG: return infer_primop(ctx, node, expected_types);
+        case Call_TAG: return infer_indirect_call(ctx, node, expected_types);
+        case If_TAG: return infer_if(ctx, node, expected_types);
+        case Loop_TAG: return infer_loop(ctx, node, expected_types);
         case Match_TAG: error("TODO")
-        case Control_TAG: return _infer_control(ctx, node, expected_type);
-        case Block_TAG: return _infer_block(ctx, node, expected_type);
+        case Control_TAG: return infer_control(ctx, node, expected_types);
+        case Block_TAG: return infer_block(ctx, node, expected_types);
         case Instruction_Comment_TAG: return recreate_node_identity(&ctx->rewriter, node);
         default: error("TODO")
         case NotAnInstruction: error("not an instruction");
     }
     SHADY_UNREACHABLE;
 }

-static const Node* _infer_terminator(Context* ctx, const Node* node) {
+static const Node* infer_terminator(Context* ctx, const Node* node) {
     IrArena* a = ctx->rewriter.dst_arena;
     switch (is_terminator(node)) {
         case NotATerminator: assert(false);
         case Let_TAG: {
             const Node* otail = node->payload.let.tail;
             Nodes annotated_types = get_variables_types(a, otail->payload.case_.params);
-            const Node* inferred_instruction = infer(ctx, node->payload.let.instruction, wrap_multiple_yield_types(a, annotated_types));
+            const Node* inferred_instruction = infer_instruction(ctx, node->payload.let.instruction, &annotated_types);
             Nodes inferred_yield_types = unwrap_multiple_yield_types(a, inferred_instruction->type);
             for (size_t i = 0; i < inferred_yield_types.count; i++) {
                 assert(is_value_type(inferred_yield_types.nodes[i]));
             }
-            const Node* inferred_tail = infer(ctx, otail, wrap_multiple_yield_types(a, inferred_yield_types));
+            const Node* inferred_tail = infer_case(ctx, otail, inferred_yield_types);
             return let(a, inferred_instruction, inferred_tail);
         }
         case Return_TAG: {
@@ -776,40 +776,49 @@ static const Node* _infer_terminator(Context* ctx, const Node* node) {
         case Terminator_Switch_TAG: break;
         case Terminator_TailCall_TAG: break;
         case Terminator_Yield_TAG: {
-            const Nodes* expected_types = ctx->merge_types;
             // TODO: block nodes should set merge types
-            assert(expected_types && "Merge terminator found but we're not within a suitable if instruction!");
-            const Nodes* old_args = &node->payload.yield.args;
-            assert(expected_types->count == old_args->count);
-            LARRAY(const Node*, new_args, old_args->count);
-            for (size_t i = 0; i < old_args->count; i++)
-                new_args[i] = infer(ctx, old_args->nodes[i], (*expected_types).nodes[i]);
+            assert(ctx->merge_types && "Merge terminator found but we're not within a suitable if instruction!");
+            Nodes expected_types = *ctx->merge_types;
+            Nodes old_args = node->payload.yield.args;
+            assert(expected_types.count == old_args.count);
+            LARRAY(const Node*, new_args, old_args.count);
+            for (size_t i = 0; i < old_args.count; i++) {
+                const Node* e = expected_types.nodes[i];
+                assert(is_value_type(e));
+                new_args[i] = infer(ctx, old_args.nodes[i], e);
+            }
             return yield(a, (Yield) {
-                .args = nodes(a, old_args->count, new_args)
+                .args = nodes(a, old_args.count, new_args)
             });
         }
         case MergeContinue_TAG: {
-            const Nodes* expected_types = ctx->continue_types;
-            assert(expected_types && "Merge terminator found but we're not within a suitable loop instruction!");
-            const Nodes* old_args = &node->payload.merge_continue.args;
-            assert(expected_types->count == old_args->count);
-            LARRAY(const Node*, new_args, old_args->count);
-            for (size_t i = 0; i < old_args->count; i++)
-                new_args[i] = infer(ctx, old_args->nodes[i], (*expected_types).nodes[i]);
+            assert(ctx->continue_types && "Merge terminator found but we're not within a suitable loop instruction!");
+            Nodes expected_types = *ctx->continue_types;
+            Nodes old_args = node->payload.yield.args;
+            assert(expected_types.count == old_args.count);
+            LARRAY(const Node*, new_args, old_args.count);
+            for (size_t i = 0; i < old_args.count; i++) {
+                const Node* e = expected_types.nodes[i];
+                assert(is_value_type(e));
+                new_args[i] = infer(ctx, old_args.nodes[i], e);
+            }
             return merge_continue(a, (MergeContinue) {
-                .args = nodes(a, old_args->count, new_args)
+                .args = nodes(a, old_args.count, new_args)
             });
         }
         case MergeBreak_TAG: {
-            const Nodes* expected_types = ctx->break_types;
-            assert(expected_types && "Merge terminator found but we're not within a suitable loop instruction!");
-            const Nodes* old_args = &node->payload.merge_break.args;
-            assert(expected_types->count == old_args->count);
-            LARRAY(const Node*, new_args, old_args->count);
-            for (size_t i = 0; i < old_args->count; i++)
-                new_args[i] = infer(ctx, old_args->nodes[i], (*expected_types).nodes[i]);
+            assert(ctx->break_types && "Merge terminator found but we're not within a suitable loop instruction!");
+            Nodes expected_types = *ctx->break_types;
+            Nodes old_args = node->payload.yield.args;
+            assert(expected_types.count == old_args.count);
+            LARRAY(const Node*, new_args, old_args.count);
+            for (size_t i = 0; i < old_args.count; i++) {
+                const Node* e = expected_types.nodes[i];
+                assert(is_value_type(e));
+                new_args[i] = infer(ctx, old_args.nodes[i], e);
+            }
             return merge_break(a, (MergeBreak) {
-                .args = nodes(a, old_args->count, new_args)
+                .args = nodes(a, old_args.count, new_args)
             });
         }
         case Unreachable_TAG: return unreachable(a);
@@ -822,7 +831,7 @@ static const Node* _infer_terminator(Context* ctx, const Node* node) {
 }

 static const Node* process(Context* src_ctx, const Node* node) {
-    const Type* expect = src_ctx->expected_type;
+    const Node* expected_type = src_ctx->expected_type;
     Context ctx = *src_ctx;
     ctx.expected_type = NULL;

@@ -834,25 +843,27 @@ static const Node* process(Context* src_ctx, const Node* node) {
     }

     if (is_type(node)) {
-        assert(expect == NULL);
-        return _infer_type(&ctx, node);
+        assert(expected_type == NULL);
+        return infer_type(&ctx, node);
     } else if (is_value(node)) {
-        const Node* value = _infer_value(&ctx, node, expect);
+        const Node* value = infer_value(&ctx, node, expected_type);
         assert(is_value_type(value->type));
         return value;
-    }else if (is_instruction(node))
-        return _infer_instruction(&ctx, node, expect);
-    else if (is_terminator(node)) {
-        assert(expect == NULL);
-        return _infer_terminator(&ctx, node);
+    } else if (is_instruction(node)) {
+        assert(false);
+        //return infer_instruction(&ctx, node, expected_type);
+    } else if (is_terminator(node)) {
+        assert(expected_type == NULL);
+        return infer_terminator(&ctx, node);
     } else if (is_declaration(node)) {
-        return _infer_decl(&ctx, node);
+        return infer_decl(&ctx, node);
     } else if (is_annotation(node)) {
-        assert(expect == NULL);
-        return _infer_annotation(&ctx, node);
+        assert(expected_type == NULL);
+        return infer_annotation(&ctx, node);
     } else if (is_case(node)) {
-        assert(expect != NULL);
-        return _infer_case(&ctx, node, expect);
+        assert(false);
+        //assert(expected_types != NULL);
+        //return infer_case(&ctx, node, expected_types);
     } else if (is_basic_block(node)) {
         return _infer_basic_block(&ctx, node);
     }
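
The recurring change above is a calling convention: helpers that infer instructions now take a nullable const Nodes* of expected result types, one entry per yielded value, instead of a single const Type* built with wrap_multiple_yield_types. Below is a minimal, self-contained sketch of that convention; the Type and Nodes definitions here are hypothetical stand-ins, not shady's real structs.

// Illustrative only: mimics the shape of the new 'expected_types' parameters.
#include <stdio.h>
#include <stddef.h>

typedef struct { const char* name; } Type; // stand-in for shady's Type

typedef struct {                           // stand-in for shady's Nodes list
    size_t count;
    const Type** nodes;
} Nodes;

// Values are checked against a single, nullable expected type.
static void infer_value(const Type* expected_type) {
    printf("value expected: %s\n", expected_type ? expected_type->name : "(none)");
}

// Instructions receive a nullable *list* of expected types; NULL means
// "no expectation, infer result types bottom-up".
static void infer_instruction(const Nodes* expected_types) {
    if (!expected_types) {
        printf("instruction: no expectation\n");
        return;
    }
    for (size_t i = 0; i < expected_types->count; i++)
        infer_value(expected_types->nodes[i]);
}

int main(void) {
    const Type i32 = { "i32" }, f32 = { "f32" };
    const Type* two[] = { &i32, &f32 };
    Nodes expected = { 2, two };
    infer_instruction(&expected); // e.g. an instruction yielding (i32, f32)
    infer_instruction(NULL);      // e.g. a constant body with no declared type
    return 0;
}

This avoids round-tripping multiple yield types through a wrapper type just to unwrap them again at the callee, which is what the removed unwrap_multiple_yield_types/wrap_multiple_yield_types pairs were doing.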