Skip to content

Commit a644f2d

Browse files
committed
Support values with arbitrary types in function environments
1 parent 11e0e52 commit a644f2d

File tree

5 files changed

+110
-32
lines changed

5 files changed

+110
-32
lines changed

compiler/lib-wasm/closure_conversion.ml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ open Code
2222
type closure =
2323
{ functions : (Var.t * int) list
2424
; free_variables : Var.t list
25+
; mutable id : int option
2526
}
2627

2728
module SCC = Strongly_connected_components.Make (Var)
@@ -144,7 +145,8 @@ let rec traverse var_depth closures program pc depth =
144145
in
145146
List.iter
146147
~f:(fun (f, _) ->
147-
closures := Var.Map.add f { functions; free_variables } !closures)
148+
closures :=
149+
Var.Map.add f { functions; free_variables; id = None } !closures)
148150
functions;
149151
fun_lst)
150152
components

compiler/lib-wasm/closure_conversion.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
type closure =
2020
{ functions : (Code.Var.t * int) list
2121
; free_variables : Code.Var.t list
22+
; mutable id : int option
2223
}
2324

2425
val f : Code.program -> Code.program * closure Code.Var.Map.t

compiler/lib-wasm/code_generation.ml

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ https://github.com/llvm/llvm-project/issues/58438
3434
type constant_global =
3535
{ init : W.expression option
3636
; constant : bool
37+
; typ : W.value_type
3738
}
3839

3940
type context =
@@ -46,6 +47,7 @@ type context =
4647
; types : Wasm_ast.type_field Var.Hashtbl.t
4748
; mutable closure_envs : Var.t Var.Map.t
4849
(** GC: mapping of recursive functions to their shared environment *)
50+
; closure_types : (W.value_type option list, int) Hashtbl.t
4951
; mutable apply_funs : Var.t IntMap.t
5052
; mutable cps_apply_funs : Var.t IntMap.t
5153
; mutable curry_funs : Var.t IntMap.t
@@ -68,6 +70,7 @@ let make_context ~value_type =
6870
; type_names = String.Hashtbl.create 128
6971
; types = Var.Hashtbl.create 128
7072
; closure_envs = Var.Map.empty
73+
; closure_types = Poly.Hashtbl.create 128
7174
; apply_funs = IntMap.empty
7275
; cps_apply_funs = IntMap.empty
7376
; curry_funs = IntMap.empty
@@ -198,6 +201,7 @@ let register_global name ?exported_name ?(constant = false) typ init st =
198201
name
199202
{ init = (if not typ.mut then Some init else None)
200203
; constant = (not typ.mut) || constant
204+
; typ = typ.typ
201205
}
202206
st.context.constant_globals;
203207
(), st
@@ -484,6 +488,68 @@ let load x =
484488
| Local (_, x, _) -> return (W.LocalGet x)
485489
| Expr e -> e
486490

491+
let rec variable_type x st =
492+
match Var.Map.find_opt x st.vars with
493+
| Some (Local (_, _, typ)) -> typ, st
494+
| Some (Expr e) ->
495+
(let* e = e in
496+
expression_type e)
497+
st
498+
| None -> None, st
499+
500+
and expression_type (e : W.expression) st =
501+
match e with
502+
| Const _
503+
| UnOp _
504+
| BinOp _
505+
| I32WrapI64 _
506+
| I64ExtendI32 _
507+
| F32DemoteF64 _
508+
| F64PromoteF32 _
509+
| BlockExpr _
510+
| Call _
511+
| RefFunc _
512+
| Call_ref _
513+
| I31Get _
514+
| ArrayGet _
515+
| ArrayLen _
516+
| RefTest _
517+
| RefEq _
518+
| RefNull _
519+
| Try _
520+
| Br_on_null _ -> None, st
521+
| LocalGet x | LocalTee (x, _) -> variable_type x st
522+
| GlobalGet x ->
523+
( (try
524+
let typ = (Var.Map.find x st.context.constant_globals).typ in
525+
if Poly.equal typ st.context.value_type
526+
then None
527+
else
528+
Some
529+
(match typ with
530+
| Ref { typ; nullable = true } -> Ref { typ; nullable = false }
531+
| _ -> typ)
532+
with Not_found -> None)
533+
, st )
534+
| Seq (_, e') -> expression_type e' st
535+
| Pop typ -> Some typ, st
536+
| RefI31 _ -> Some (Ref { nullable = false; typ = I31 }), st
537+
| ArrayNew (ty, _, _)
538+
| ArrayNewFixed (ty, _)
539+
| ArrayNewData (ty, _, _, _)
540+
| StructNew (ty, _) -> Some (Ref { nullable = false; typ = Type ty }), st
541+
| StructGet (_, ty, i, _) -> (
542+
match (Var.Hashtbl.find st.context.types ty).typ with
543+
| Struct l -> (
544+
match (List.nth l i).typ with
545+
| Value typ ->
546+
(if Poly.equal typ st.context.value_type then None else Some typ), st
547+
| Packed _ -> assert false)
548+
| Array _ | Func _ -> assert false)
549+
| RefCast (typ, _) | Br_on_cast (_, _, typ, _) | Br_on_cast_fail (_, typ, _, _) ->
550+
Some (Ref typ), st
551+
| IfExpr (_, _, _, _) | ExternConvertAny _ | AnyConvertExtern _ -> None, st
552+
487553
let tee ?typ x e =
488554
let* e = e in
489555
let* b = is_small_constant e in

compiler/lib-wasm/code_generation.mli

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type context =
3030
; types : Wasm_ast.type_field Code.Var.Hashtbl.t
3131
; mutable closure_envs : Code.Var.t Code.Var.Map.t
3232
(** GC: mapping of recursive functions to their shared environment *)
33+
; closure_types : (Wasm_ast.value_type option list, int) Hashtbl.t
3334
; mutable apply_funs : Code.Var.t Stdlib.IntMap.t
3435
; mutable cps_apply_funs : Code.Var.t Stdlib.IntMap.t
3536
; mutable curry_funs : Code.Var.t Stdlib.IntMap.t
@@ -57,7 +58,7 @@ val instr : Wasm_ast.instruction -> unit t
5758

5859
val seq : unit t -> expression -> expression
5960

60-
val expression_list : ('a -> expression) -> 'a list -> Wasm_ast.expression list t
61+
val expression_list : ('a -> 'b t) -> 'a list -> 'b list t
6162

6263
module Arith : sig
6364
val const : int32 -> expression
@@ -198,3 +199,5 @@ val function_body :
198199
-> param_names:Code.Var.t list
199200
-> body:unit t
200201
-> (Wasm_ast.var * Wasm_ast.value_type) list * Wasm_ast.instruction list
202+
203+
val variable_type : Code.Var.t -> Wasm_ast.value_type option t

compiler/lib-wasm/gc_target.ml

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,19 @@ module Type = struct
281281
])
282282
})
283283

284-
let env_type ~cps ~arity n =
284+
let make_env_type env_type =
285+
List.map
286+
~f:(fun typ ->
287+
{ W.mut = false
288+
; typ = W.Value (Option.value ~default:(W.Ref { nullable = false; typ = Eq }) typ)
289+
})
290+
env_type
291+
292+
let env_type ~cps ~arity ~env_type_id ~env_type =
285293
register_type
286294
(if cps
287-
then Printf.sprintf "cps_env_%d_%d" arity n
288-
else Printf.sprintf "env_%d_%d" arity n)
295+
then Printf.sprintf "cps_env_%d_%d" arity env_type_id
296+
else Printf.sprintf "env_%d_%d" arity env_type_id)
289297
(fun () ->
290298
let* cl_typ = closure_type ~usage:`Alloc ~cps arity in
291299
let* common = closure_common_fields ~cps in
@@ -309,18 +317,11 @@ module Type = struct
309317
; typ = Value (Ref { nullable = false; typ = Type fun_ty' })
310318
}
311319
])
312-
@ List.init
313-
~f:(fun _ ->
314-
{ W.mut = false
315-
; typ = W.Value (Ref { nullable = false; typ = Eq })
316-
})
317-
~len:n)
320+
@ make_env_type env_type)
318321
})
319322

320-
let rec_env_type ~function_count ~free_variable_count =
321-
register_type
322-
(Printf.sprintf "rec_env_%d_%d" function_count free_variable_count)
323-
(fun () ->
323+
let rec_env_type ~function_count ~env_type_id ~env_type =
324+
register_type (Printf.sprintf "rec_env_%d_%d" function_count env_type_id) (fun () ->
324325
return
325326
{ supertype = None
326327
; final = true
@@ -331,24 +332,20 @@ module Type = struct
331332
{ W.mut = i < function_count
332333
; typ = W.Value (Ref { nullable = false; typ = Eq })
333334
})
334-
~len:(function_count + free_variable_count))
335+
~len:function_count
336+
@ make_env_type env_type)
335337
})
336338

337-
let rec_closure_type ~cps ~arity ~function_count ~free_variable_count =
339+
let rec_closure_type ~cps ~arity ~function_count ~env_type_id ~env_type =
338340
register_type
339341
(if cps
340-
then
341-
Printf.sprintf
342-
"cps_closure_rec_%d_%d_%d"
343-
arity
344-
function_count
345-
free_variable_count
346-
else Printf.sprintf "closure_rec_%d_%d_%d" arity function_count free_variable_count)
342+
then Printf.sprintf "cps_closure_rec_%d_%d_%d" arity function_count env_type_id
343+
else Printf.sprintf "closure_rec_%d_%d_%d" arity function_count env_type_id)
347344
(fun () ->
348345
let* cl_typ = closure_type ~usage:`Alloc ~cps arity in
349346
let* common = closure_common_fields ~cps in
350347
let* fun_ty' = function_type ~cps arity in
351-
let* env_ty = rec_env_type ~function_count ~free_variable_count in
348+
let* env_ty = rec_env_type ~function_count ~env_type_id ~env_type in
352349
return
353350
{ supertype = Some cl_typ
354351
; final = true
@@ -1099,11 +1096,19 @@ module Closure = struct
10991096
in
11001097
return (W.GlobalGet name)
11011098
else
1102-
let free_variable_count = List.length free_variables in
1099+
let* env_type = expression_list variable_type free_variables in
1100+
let env_type_id =
1101+
try Hashtbl.find context.closure_types env_type
1102+
with Not_found ->
1103+
let id = Hashtbl.length context.closure_types in
1104+
Hashtbl.add context.closure_types env_type id;
1105+
id
1106+
in
1107+
info.id <- Some env_type_id;
11031108
match info.Closure_conversion.functions with
11041109
| [] -> assert false
11051110
| [ _ ] ->
1106-
let* typ = Type.env_type ~cps ~arity free_variable_count in
1111+
let* typ = Type.env_type ~cps ~arity ~env_type_id ~env_type in
11071112
let* l = expression_list load free_variables in
11081113
return
11091114
(W.StructNew
@@ -1122,7 +1127,7 @@ module Closure = struct
11221127
@ l ))
11231128
| (g, _) :: _ as functions ->
11241129
let function_count = List.length functions in
1125-
let* env_typ = Type.rec_env_type ~function_count ~free_variable_count in
1130+
let* env_typ = Type.rec_env_type ~function_count ~env_type_id ~env_type in
11261131
let env =
11271132
if Code.Var.equal f g
11281133
then
@@ -1144,7 +1149,7 @@ module Closure = struct
11441149
load env
11451150
in
11461151
let* typ =
1147-
Type.rec_closure_type ~cps ~arity ~function_count ~free_variable_count
1152+
Type.rec_closure_type ~cps ~arity ~function_count ~env_type_id ~env_type
11481153
in
11491154
let res =
11501155
let* env = env in
@@ -1189,12 +1194,13 @@ module Closure = struct
11891194
let* _ = add_var (Code.Var.fresh ()) in
11901195
return ()
11911196
else
1197+
let env_type_id = Option.value ~default:(-1) info.id in
11921198
let _, arity = List.find ~f:(fun (f', _) -> Code.Var.equal f f') info.functions in
11931199
let arity = if cps then arity - 1 else arity in
11941200
let offset = Memory.env_start arity in
11951201
match info.Closure_conversion.functions with
11961202
| [ _ ] ->
1197-
let* typ = Type.env_type ~cps ~arity free_variable_count in
1203+
let* typ = Type.env_type ~cps ~arity ~env_type_id ~env_type:[] in
11981204
let* _ = add_var f in
11991205
let env = Code.Var.fresh_n "env" in
12001206
let* () =
@@ -1214,11 +1220,11 @@ module Closure = struct
12141220
| functions ->
12151221
let function_count = List.length functions in
12161222
let* typ =
1217-
Type.rec_closure_type ~cps ~arity ~function_count ~free_variable_count
1223+
Type.rec_closure_type ~cps ~arity ~function_count ~env_type_id ~env_type:[]
12181224
in
12191225
let* _ = add_var f in
12201226
let env = Code.Var.fresh_n "env" in
1221-
let* env_typ = Type.rec_env_type ~function_count ~free_variable_count in
1227+
let* env_typ = Type.rec_env_type ~function_count ~env_type_id ~env_type:[] in
12221228
let* () =
12231229
store
12241230
~typ:(W.Ref { nullable = false; typ = Type env_typ })

0 commit comments

Comments
 (0)