Skip to content

Commit e0c3c22

Browse files
committed
Effects: cleaner and efficient substitutions
1 parent 646a3dd commit e0c3c22

File tree

1 file changed

+104
-122
lines changed

1 file changed

+104
-122
lines changed

compiler/lib/effects.ml

Lines changed: 104 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -661,19 +661,18 @@ let duplicate_code ~st pc =
661661
st.new_blocks <- new_blocks, free_pc;
662662
Addr.Map.find pc new_pc_of_old
663663

664-
let cps_instr ~st ~lifter_functions (instr : instr) : instr list * Var.t Var.Map.t =
664+
let cps_instr ~st ~lifter_functions (instr : instr) : instr list =
665665
match instr with
666666
| Let (x, Closure (params, ((pc, _) as cont)))
667667
when Var.Set.mem x st.cps_needed && not (Var.Set.mem x !(st.single_version_closures))
668668
->
669669
let direct_c = Var.fork x in
670670
let cps_c = Var.fork x in
671671
let cps_params, cps_cont = Hashtbl.find st.closure_info pc in
672-
( [ Let (direct_c, Closure (params, cont))
673-
; Let (cps_c, Closure (cps_params, cps_cont))
674-
; Let (x, Prim (Extern "caml_cps_closure", [ Pv direct_c; Pv cps_c ]))
675-
]
676-
, Var.Map.empty )
672+
[ Let (direct_c, Closure (params, cont))
673+
; Let (cps_c, Closure (cps_params, cps_cont))
674+
; Let (x, Prim (Extern "caml_cps_closure", [ Pv direct_c; Pv cps_c ]))
675+
]
677676
| Let (x, Closure (params, (pc, args)))
678677
when (not (Var.Set.mem x st.cps_needed))
679678
&& not (Var.Set.mem x !(st.single_version_closures)) ->
@@ -684,17 +683,16 @@ let cps_instr ~st ~lifter_functions (instr : instr) : instr list * Var.t Var.Map
684683
let new_pc = duplicate_code ~st pc in
685684
(* We leave [params] and [args] unchanged here because they will be
686685
replaced with fresh variables in a later, global substitution pass. *)
687-
[ Let (x, Closure (params, (new_pc, args))) ], Var.Map.empty
686+
[ Let (x, Closure (params, (new_pc, args))) ]
688687
| Let (x, Prim (Extern "caml_alloc_dummy_function", [ size; arity ])) -> (
689688
match arity with
690689
| Pc (Int a) ->
691-
( [ Let
692-
( x
693-
, Prim
694-
(Extern "caml_alloc_dummy_function", [ size; Pc (Int (Int32.succ a)) ])
695-
)
696-
]
697-
, Var.Map.empty )
690+
[ Let
691+
( x
692+
, Prim
693+
(Extern "caml_alloc_dummy_function", [ size; Pc (Int (Int32.succ a)) ])
694+
)
695+
]
698696
| _ -> assert false)
699697
| Let (x, Apply { f; args; _ }) when not (Var.Set.mem x st.cps_needed) ->
700698
(* At the moment, we turn into CPS any function not called with
@@ -704,19 +702,13 @@ let cps_instr ~st ~lifter_functions (instr : instr) : instr list * Var.t Var.Map
704702
introduced by the lambda lifting and does not require CPS *)
705703
Var.idx f >= Var.Tbl.length st.flow_info.info_approximation
706704
|| Global_flow.exact_call st.flow_info f (List.length args));
707-
[ Let (x, Apply { f; args; exact = true }) ], Var.Map.empty
705+
[ Let (x, Apply { f; args; exact = true }) ]
708706
| Let (_, Apply { f; args = _; exact = _ }) when Var.Set.mem f lifter_functions ->
709707
(* Nothing to do for lifter functions. *)
710-
[ instr ], Var.Map.empty
708+
[ instr ]
711709
| Let (_, (Apply _ | Prim (Extern ("%resume" | "%perform" | "%reperform"), _))) ->
712710
assert false
713-
| _ -> [ instr ], Var.Map.empty
714-
715-
let concat_union : ('a list * 'b Var.Map.t) list -> 'a list * 'b Var.Map.t =
716-
List.fold_left
717-
~f:(fun (instrs, subst) (is, s) ->
718-
instrs @ is, Var.Map.union (fun _ _ -> assert false) subst s)
719-
~init:([], Var.Map.empty)
711+
| _ -> [ instr ]
720712

721713
let cps_block ~st ~k ~lifter_functions ~orig_pc block =
722714
debug_print "cps_block %d\n" orig_pc;
@@ -804,40 +796,39 @@ let cps_block ~st ~k ~lifter_functions ~orig_pc block =
804796
| None, _ -> None
805797
in
806798

807-
let body, last, subst =
799+
let body, last =
808800
match rewritten_block with
809801
| Some (body_prefix, last_instrs, last) ->
810-
let body_prefix, subst =
802+
let body_prefix =
811803
(* For each instruction... *)
812804
List.map body_prefix ~f:(fun (i, loc) ->
813805
(* ... apply [cps_instr] ... *)
814-
let instrs, subst = cps_instr ~st ~lifter_functions i in
806+
cps_instr ~st ~lifter_functions i
815807
(* ... and decorate all resulting instructions with [loc] *)
816-
List.map ~f:(fun i -> i, loc) instrs, subst)
817-
|> concat_union (* Merge the resulting variable substitutions into one *)
808+
|> List.map ~f:(fun i -> i, loc))
809+
|> List.concat
818810
in
819-
body_prefix @ last_instrs, last, subst
811+
body_prefix @ last_instrs, last
820812
| None ->
821813
let last_instrs, last =
822814
cps_last ~st ~alloc_jump_closures orig_pc block.branch ~k
823815
in
824-
let body, subst =
816+
let body =
825817
(* For each instruction... *)
826818
List.map block.body ~f:(fun (i, loc) ->
827819
(* ... apply [cps_instr] ... *)
828-
let instrs, subst = cps_instr ~st ~lifter_functions i in
820+
cps_instr ~st ~lifter_functions i
829821
(* ... and decorate all resulting instructions with [loc] *)
830-
List.map ~f:(fun i -> i, loc) instrs, subst)
831-
|> concat_union (* Merge the resulting variable substitutions into one *)
822+
|> List.map ~f:(fun i -> i, loc))
823+
|> List.concat
832824
in
833-
body @ last_instrs, last, subst
825+
body @ last_instrs, last
834826
in
835827

836-
( { params = (if Addr.Set.mem orig_pc st.blocks_to_transform then [] else block.params)
837-
; body
838-
; branch = last
839-
}
840-
, subst )
828+
{ params = (if Addr.Set.mem orig_pc st.blocks_to_transform then [] else block.params)
829+
; body
830+
; branch = last
831+
}
841832

842833
(* Modify all function applications and closure creations to take into account
843834
the fact that closures are turned (direct style, CPS) closure pairs. Also
@@ -849,66 +840,50 @@ let cps_block ~st ~k ~lifter_functions ~orig_pc block =
849840
exactly once. This is done by creating fresh arguments for each CPS closure
850841
and returning a substitution from the original parameters to the new ones,
851842
that must be applied to all code that might use the original parameters. *)
852-
let rewrite_direct_block
853-
~cps_needed
854-
~closure_info
855-
~ident_fn
856-
~pc
857-
~lifter_functions
858-
~subst
859-
block =
843+
let rewrite_direct_block ~cps_needed ~closure_info ~ident_fn ~pc ~lifter_functions block =
860844
debug_print "@[<v>rewrite_direct_block %d@,@]" pc;
861-
let rewrite_instr subst = function
845+
let rewrite_instr = function
862846
| Let (x, Closure (params, ((pc, _) as cont)))
863847
when Var.Set.mem x cps_needed && not (Var.Set.mem x lifter_functions) ->
864848
let direct_c = Var.fork x in
865849
let cps_c = Var.fork x in
866850
let cps_params, cps_cont = Hashtbl.find closure_info pc in
867-
( [ Let (direct_c, Closure (params, cont))
868-
; Let (cps_c, Closure (cps_params, cps_cont))
869-
; Let (x, Prim (Extern "caml_cps_closure", [ Pv direct_c; Pv cps_c ]))
870-
]
871-
, subst )
851+
[ Let (direct_c, Closure (params, cont))
852+
; Let (cps_c, Closure (cps_params, cps_cont))
853+
; Let (x, Prim (Extern "caml_cps_closure", [ Pv direct_c; Pv cps_c ]))
854+
]
872855
| Let (x, Prim (Extern "%resume", [ Pv stack; Pv f; Pv arg ])) ->
873856
(* Pass the identity as a continuation and pass to
874857
[caml_trampoline_cps], which will 1. install a trampoline, 2. call
875858
the CPS version of [f] and 3. handle exceptions. *)
876859
let k = Var.fresh_n "cont" in
877860
let args = Var.fresh_n "args" in
878-
( [ Let (k, Prim (Extern "caml_resume_stack", [ Pv stack; Pv ident_fn ]))
879-
; Let (args, Prim (Extern "%js_array", [ Pv arg; Pv k ]))
880-
; Let (x, Prim (Extern "caml_trampoline_cps", [ Pv f; Pv args ]))
881-
]
882-
, subst )
861+
[ Let (k, Prim (Extern "caml_resume_stack", [ Pv stack; Pv ident_fn ]))
862+
; Let (args, Prim (Extern "%js_array", [ Pv arg; Pv k ]))
863+
; Let (x, Prim (Extern "caml_trampoline_cps", [ Pv f; Pv args ]))
864+
]
883865
| Let (x, Prim (Extern "%perform", [ Pv effect ])) ->
884866
(* Perform the effect, which should call the "Unhandled effect" handler. *)
885867
let k = Int 0l in
886868
(* Dummy continuation *)
887-
( [ Let (x, Prim (Extern "caml_perform_effect", [ Pv effect; Pc (Int 0l); Pc k ]))
888-
]
889-
, subst )
869+
[ Let (x, Prim (Extern "caml_perform_effect", [ Pv effect; Pc (Int 0l); Pc k ])) ]
890870
| Let (x, Prim (Extern "%reperform", [ Pv effect; Pv continuation ])) ->
891871
(* Similar to previous case *)
892872
let k = Int 0l in
893-
( [ Let
894-
( x
895-
, Prim (Extern "caml_perform_effect", [ Pv effect; Pv continuation; Pc k ])
896-
)
897-
]
898-
, subst )
899-
| (Let _ | Assign _ | Set_field _ | Offset_ref _ | Array_set _) as instr ->
900-
[ instr ], subst
873+
[ Let
874+
(x, Prim (Extern "caml_perform_effect", [ Pv effect; Pv continuation; Pc k ]))
875+
]
876+
| (Let _ | Assign _ | Set_field _ | Offset_ref _ | Array_set _) as instr -> [ instr ]
901877
in
902-
let subst, body =
878+
let body =
903879
(* For each instruction... *)
904-
List.fold_left_map block.body ~init:subst ~f:(fun subst (i, loc) ->
880+
List.concat_map block.body ~f:(fun (i, loc) ->
905881
(* ... apply [rewrite_instr] ... *)
906-
let instrs, subst = rewrite_instr subst i in
882+
rewrite_instr i
907883
(* ... and decorate all resulting instructions with [loc] *)
908-
subst, List.map ~f:(fun i -> i, loc) instrs)
884+
|> List.map ~f:(fun i -> i, loc))
909885
in
910-
let body = List.concat body in
911-
{ block with body }, subst
886+
{ block with body }
912887

913888
(* Apply a substitution in a set of blocks *)
914889
let subst_in_blocks blocks s =
@@ -953,10 +928,13 @@ let cps_transform ~lifter_functions ~live_vars ~flow_info ~cps_needed p =
953928
let cps_calls = ref Var.Set.empty in
954929
let single_version_closures = ref lifter_functions in
955930
let cps_pc_of_direct = Hashtbl.create 512 in
956-
let p =
931+
let p, bound_subst, param_subst, new_blocks =
957932
Code.fold_closures_innermost_first
958933
p
959-
(fun name_opt params (start, args) ({ blocks; free_pc; _ } as p) ->
934+
(fun name_opt
935+
params
936+
(start, args)
937+
(({ blocks; free_pc; _ } as p), bound_subst, param_subst, new_blocks) ->
960938
Option.iter name_opt ~f:(fun v -> debug_print "cname = %s" @@ Var.to_string v);
961939
(* We speculatively add a block at the beginning of the
962940
function. In case of tail-recursion optimization, the
@@ -1041,7 +1019,7 @@ let cps_transform ~lifter_functions ~live_vars ~flow_info ~cps_needed p =
10411019
start
10421020
blocks
10431021
());
1044-
let blocks, free_pc =
1022+
let blocks, free_pc, bound_subst, param_subst, new_blocks =
10451023
(* For every block in the closure,
10461024
1. add its CPS translation to the block map at a fresh address, if
10471025
needed
@@ -1055,101 +1033,105 @@ let cps_transform ~lifter_functions ~live_vars ~flow_info ~cps_needed p =
10551033
let k = Var.fresh_n "cont" in
10561034
let cps_start = mk_cps_pc_of_direct ~st start in
10571035
let params' = List.map ~f:Var.fork params in
1058-
let subst =
1036+
let param_subst =
10591037
List.fold_left2
10601038
~f:(fun m p p' -> Var.Map.add p p' m)
1061-
~init:Var.Map.empty
1039+
~init:param_subst
10621040
params
10631041
params'
10641042
in
1065-
let cps_args = List.map ~f:(Subst.from_map subst) args in
1043+
let cps_args = List.map ~f:(Subst.from_map param_subst) args in
10661044
Hashtbl.add
10671045
st.closure_info
10681046
initial_start
10691047
(params' @ [ k ], (cps_start, cps_args));
1070-
( subst
1048+
( param_subst
10711049
, fun pc block ->
1072-
let cps_block, subst =
1073-
cps_block ~st ~lifter_functions ~k ~orig_pc:pc block
1074-
in
1050+
let cps_block = cps_block ~st ~lifter_functions ~k ~orig_pc:pc block in
10751051
( rewrite_direct_block
10761052
~cps_needed
10771053
~closure_info:st.closure_info
10781054
~ident_fn
10791055
~pc
10801056
~lifter_functions
1081-
~subst
10821057
block
10831058
, Some cps_block ) ))
10841059
else
1085-
( Var.Map.empty
1060+
( param_subst
10861061
, fun pc block ->
10871062
( rewrite_direct_block
10881063
~cps_needed
10891064
~closure_info:st.closure_info
10901065
~ident_fn
10911066
~pc
10921067
~lifter_functions
1093-
~subst:Var.Map.empty
10941068
block
10951069
, None ) )
10961070
in
1097-
let blocks, direct_subst =
1071+
let blocks =
10981072
Code.traverse
10991073
{ fold = Code.fold_children }
1100-
(fun pc (blocks, direct_subst) ->
1101-
let (block, s), cps_block_opt =
1102-
transform_block pc (Addr.Map.find pc blocks)
1103-
in
1104-
let s = Var.Map.union (fun _ _ -> assert false) direct_subst s in
1074+
(fun pc blocks ->
1075+
let block, cps_block_opt = transform_block pc (Addr.Map.find pc blocks) in
11051076
let blocks = Addr.Map.add pc block blocks in
11061077
match cps_block_opt with
1107-
| None -> blocks, s
1078+
| None -> blocks
11081079
| Some b ->
11091080
let cps_pc = mk_cps_pc_of_direct ~st pc in
11101081
let new_blocks, free_pc = st.new_blocks in
11111082
st.new_blocks <- Addr.Map.add cps_pc b new_blocks, free_pc;
1112-
Addr.Map.add cps_pc b blocks, s)
1083+
Addr.Map.add cps_pc b blocks)
11131084
start
11141085
st.blocks
1115-
(st.blocks, Var.Map.empty)
1086+
st.blocks
11161087
in
1117-
let cps_blocks, free_pc = st.new_blocks in
1088+
let new_blocks_this_clos, free_pc = st.new_blocks in
11181089
(* Substitute all variables bound in the CPS version with fresh
11191090
variables to avoid clashing with the definitions in the original
11201091
blocks. *)
11211092
let bound =
11221093
Addr.Map.fold
11231094
(fun _ block bound ->
11241095
Var.Set.union bound (Freevars.block_bound_vars ~closure_params:true block))
1125-
cps_blocks
1096+
new_blocks_this_clos
11261097
Var.Set.empty
11271098
in
1128-
let s =
1129-
Var.Set.fold (fun v m -> Var.Map.add v (Var.fork v) m) bound Var.Map.empty
1130-
|> Subst.from_map
1099+
let bound_subst =
1100+
Var.Set.fold (fun v m -> Var.Map.add v (Var.fork v) m) bound bound_subst
11311101
in
1132-
let cps_blocks = subst_bound_in_blocks cps_blocks s in
1133-
(* Also apply susbstitution to set of CPS calls and lifter functions *)
1134-
st.cps_calls := Var.Set.map s !(st.cps_calls);
1135-
st.single_version_closures := Var.Set.map s !(st.single_version_closures);
1136-
(* All variables that were a closure parameter in a direct-style block must be
1137-
substituted by the CPS version of that parameter in CPS blocks (generated by
1138-
[rewrite_direct], because CPS closures are only ever defined in (toplevel)
1139-
direct-style blocks). *)
1140-
let subst =
1141-
Subst.from_map
1142-
@@ Var.Map.union (fun _ _ -> assert false) direct_subst param_subst
1143-
in
1144-
let cps_blocks = subst_in_blocks cps_blocks subst in
1145-
(* Also apply susbstitution to set of CPS calls and lifter functions *)
1146-
st.cps_calls := Var.Set.map subst !(st.cps_calls);
1147-
st.single_version_closures := Var.Set.map subst !(st.single_version_closures);
1148-
let blocks = Addr.Map.fold Addr.Map.add cps_blocks blocks in
1149-
blocks, free_pc
1102+
let blocks = Addr.Map.fold Addr.Map.add new_blocks_this_clos blocks in
1103+
( blocks
1104+
, free_pc
1105+
, bound_subst
1106+
, param_subst
1107+
, Addr.Map.union (fun _ _ -> assert false) new_blocks new_blocks_this_clos )
11501108
in
1151-
{ p with blocks; free_pc })
1152-
p
1109+
{ p with blocks; free_pc }, bound_subst, param_subst, new_blocks)
1110+
(p, Var.Map.empty, Var.Map.empty, Addr.Map.empty)
1111+
in
1112+
let bound_subst = Subst.from_map bound_subst in
1113+
let new_blocks = subst_bound_in_blocks new_blocks bound_subst in
1114+
(* Also apply that substitution to the sets of CPS calls and lifter functions *)
1115+
cps_calls := Var.Set.map bound_subst !cps_calls;
1116+
single_version_closures := Var.Set.map bound_subst !single_version_closures;
1117+
(* All variables that were a closure parameter in a direct-style block must be
1118+
substituted by a fresh name. *)
1119+
let param_subst = Subst.from_map param_subst in
1120+
let new_blocks = subst_in_blocks new_blocks param_subst in
1121+
(* Also apply that 2nd substitution to the sets of CPS calls and lifter functions *)
1122+
cps_calls := Var.Set.map param_subst !cps_calls;
1123+
single_version_closures := Var.Set.map param_subst !single_version_closures;
1124+
let p =
1125+
{ p with
1126+
blocks =
1127+
Addr.Map.merge
1128+
(fun _ a b ->
1129+
match a, b with
1130+
| _, Some b -> Some b
1131+
| a, None -> a)
1132+
p.blocks
1133+
new_blocks
1134+
}
11531135
in
11541136
let p =
11551137
match Hashtbl.find_opt closure_info p.start with

0 commit comments

Comments
 (0)