Skip to content

Commit 321ead8

Browse files
committed
More precise return types
1 parent 44f5042 commit 321ead8

File tree

8 files changed

+253
-53
lines changed

8 files changed

+253
-53
lines changed

compiler/lib-wasm/code_generation.ml

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,68 @@ let heap_type_sub (ty : W.heap_type) (ty' : W.heap_type) st =
195195
(* I31, struct, array and none have no other subtype *)
196196
| _, (I31 | Type _ | Struct | Array | None_) -> false, st
197197

198+
let rec type_index_lub ty ty' st =
199+
(* Find the LUB efficiently by taking advantage of the fact that
200+
types are defined after their supertypes, making their variables
201+
compare greater. *)
202+
let c = Var.compare ty ty' in
203+
if c > 0
204+
then type_index_lub ty' ty st
205+
else if c = 0
206+
then Some ty
207+
else
208+
let type_field = Var.Hashtbl.find st.context.types ty' in
209+
match type_field.supertype with
210+
| None -> None
211+
| Some ty'' ->
212+
assert (Var.compare ty'' ty' < 0);
213+
type_index_lub ty ty'' st
214+
215+
let heap_type_lub (ty : W.heap_type) (ty' : W.heap_type) =
216+
match ty, ty' with
217+
| (Func | Extern), _ | _, (Func | Extern) -> assert false
218+
| None_, _ -> return ty'
219+
| _, None_ | Struct, Struct | Array, Array -> return ty
220+
| Any, _ | _, Any -> return W.Any
221+
| Eq, _
222+
| _, Eq
223+
| (Struct | Array | Type _), I31
224+
| I31, (Struct | Array | Type _)
225+
| Struct, Array
226+
| Array, Struct -> return (Eq : W.heap_type)
227+
| Struct, Type t | Type t, Struct -> (
228+
fun st ->
229+
let type_field = Var.Hashtbl.find st.context.types t in
230+
match type_field.typ with
231+
| Struct _ -> W.Struct, st
232+
| Array _ | Func _ -> W.Eq, st)
233+
| Array, Type t | Type t, Array -> (
234+
fun st ->
235+
let type_field = Var.Hashtbl.find st.context.types t in
236+
match type_field.typ with
237+
| Array _ -> W.Struct, st
238+
| Struct _ | Func _ -> W.Eq, st)
239+
| Type t, Type t' -> (
240+
let* r = fun st -> type_index_lub t t' st, st in
241+
match r with
242+
| Some t'' -> return (Type t'' : W.heap_type)
243+
| None -> (
244+
fun st ->
245+
let type_field = Var.Hashtbl.find st.context.types t in
246+
let type_field' = Var.Hashtbl.find st.context.types t' in
247+
match type_field.typ, type_field'.typ with
248+
| Struct _, Struct _ -> (Struct : W.heap_type), st
249+
| Array _, Array _ -> W.Array, st
250+
| (Array _ | Struct _ | Func _), (Array _ | Struct _ | Func _) -> W.Eq, st))
251+
| I31, I31 -> return W.I31
252+
253+
let value_type_lub (ty : W.value_type) (ty' : W.value_type) =
254+
match ty, ty' with
255+
| Ref { nullable; typ }, Ref { nullable = nullable'; typ = typ' } ->
256+
let* typ = heap_type_lub typ typ' in
257+
return (W.Ref { nullable = nullable || nullable'; typ })
258+
| _ -> assert false
259+
198260
let register_global name ?exported_name ?(constant = false) typ init st =
199261
st.context.other_fields <-
200262
W.Global { name; exported_name; typ; init } :: st.context.other_fields;
@@ -710,7 +772,7 @@ let init_code context = instrs context.init_code
710772

711773
let function_body ~context ~param_names ~body =
712774
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
713-
let (), st = body st in
775+
let res, st = body st in
714776
let local_count, body = st.var_count, List.rev st.instrs in
715777
let local_types = Array.make local_count (Var.fresh (), None) in
716778
List.iteri ~f:(fun i x -> local_types.(i) <- x, None) param_names;
@@ -728,4 +790,10 @@ let function_body ~context ~param_names ~body =
728790
|> (fun a -> Array.sub a ~pos:param_count ~len:(Array.length a - param_count))
729791
|> Array.to_list
730792
in
731-
locals, body
793+
locals, res, body
794+
795+
let eval ~context e =
796+
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
797+
let r, st = e st in
798+
assert (st.var_count = 0 && List.is_empty st.instrs);
799+
r

compiler/lib-wasm/code_generation.mli

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ val register_type : string -> (unit -> type_def t) -> Wasm_ast.var t
156156

157157
val heap_type_sub : Wasm_ast.heap_type -> Wasm_ast.heap_type -> bool t
158158

159+
val value_type_lub : Wasm_ast.value_type -> Wasm_ast.value_type -> Wasm_ast.value_type t
160+
159161
val register_import :
160162
?allow_tail_call:bool
161163
-> ?import_module:string
@@ -200,8 +202,8 @@ val need_dummy_fun : cps:bool -> arity:int -> Code.Var.t t
200202
val function_body :
201203
context:context
202204
-> param_names:Code.Var.t list
203-
-> body:unit t
204-
-> (Wasm_ast.var * Wasm_ast.value_type) list * Wasm_ast.instruction list
205+
-> body:'a t
206+
-> (Wasm_ast.var * Wasm_ast.value_type) list * 'a * Wasm_ast.instruction list
205207

206208
val variable_type : Code.Var.t -> Wasm_ast.value_type option t
207209

@@ -210,3 +212,5 @@ val array_placeholder : Code.Var.t -> expression
210212
val default_value :
211213
Wasm_ast.value_type
212214
-> (Wasm_ast.expression * Wasm_ast.value_type * Wasm_ast.ref_type option) t
215+
216+
val eval : context:context -> 'a t -> 'a

compiler/lib-wasm/curry.ml

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ module Make (Target : Target_sig.S) = struct
9595
loop m [] f None
9696
in
9797
let param_names = args @ [ f ] in
98-
let locals, body = function_body ~context ~param_names ~body in
98+
let locals, _, body = function_body ~context ~param_names ~body in
9999
W.Function
100100
{ name
101101
; exported_name = None
102-
; typ = None
102+
; typ = Some (eval ~context (Type.function_type ~cps:false 1))
103103
; signature = Type.func_type 1
104104
; param_names
105105
; locals
@@ -130,11 +130,11 @@ module Make (Target : Target_sig.S) = struct
130130
push (Closure.curry_allocate ~cps:false ~arity m ~f:name' ~closure:f ~arg:x)
131131
in
132132
let param_names = [ x; f ] in
133-
let locals, body = function_body ~context ~param_names ~body in
133+
let locals, _, body = function_body ~context ~param_names ~body in
134134
W.Function
135135
{ name
136136
; exported_name = None
137-
; typ = None
137+
; typ = Some (eval ~context (Type.function_type ~cps:false 1))
138138
; signature = Type.func_type 1
139139
; param_names
140140
; locals
@@ -181,11 +181,11 @@ module Make (Target : Target_sig.S) = struct
181181
loop m [] f None
182182
in
183183
let param_names = args @ [ f ] in
184-
let locals, body = function_body ~context ~param_names ~body in
184+
let locals, _, body = function_body ~context ~param_names ~body in
185185
W.Function
186186
{ name
187187
; exported_name = None
188-
; typ = None
188+
; typ = Some (eval ~context (Type.function_type ~cps:true 1))
189189
; signature = Type.func_type 2
190190
; param_names
191191
; locals
@@ -220,11 +220,11 @@ module Make (Target : Target_sig.S) = struct
220220
instr (W.Return (Some c))
221221
in
222222
let param_names = [ x; cont; f ] in
223-
let locals, body = function_body ~context ~param_names ~body in
223+
let locals, _, body = function_body ~context ~param_names ~body in
224224
W.Function
225225
{ name
226226
; exported_name = None
227-
; typ = None
227+
; typ = Some (eval ~context (Type.function_type ~cps:true 1))
228228
; signature = Type.func_type 2
229229
; param_names
230230
; locals
@@ -264,7 +264,7 @@ module Make (Target : Target_sig.S) = struct
264264
build_applies (load f) l)
265265
in
266266
let param_names = l @ [ f ] in
267-
let locals, body = function_body ~context ~param_names ~body in
267+
let locals, _, body = function_body ~context ~param_names ~body in
268268
W.Function
269269
{ name
270270
; exported_name = None
@@ -305,7 +305,7 @@ module Make (Target : Target_sig.S) = struct
305305
push (call ~cps:true ~arity:2 (load f) [ x; iterate ]))
306306
in
307307
let param_names = l @ [ f ] in
308-
let locals, body = function_body ~context ~param_names ~body in
308+
let locals, _, body = function_body ~context ~param_names ~body in
309309
W.Function
310310
{ name
311311
; exported_name = None
@@ -340,11 +340,13 @@ module Make (Target : Target_sig.S) = struct
340340
instr (W.Return (Some e))
341341
in
342342
let param_names = l @ [ f ] in
343-
let locals, body = function_body ~context ~param_names ~body in
343+
let locals, _, body = function_body ~context ~param_names ~body in
344344
W.Function
345345
{ name
346346
; exported_name = None
347-
; typ = None
347+
; typ =
348+
Some
349+
(eval ~context (Type.function_type ~cps (if cps then arity - 1 else arity)))
348350
; signature = Type.func_type arity
349351
; param_names
350352
; locals

compiler/lib-wasm/gc_target.ml

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,35 @@ module Type = struct
203203
let primitive_type n =
204204
{ W.params = List.init ~len:n ~f:(fun _ -> value); result = [ value ] }
205205

206-
let func_type n = primitive_type (n + 1)
207-
208-
let function_type ~cps n =
209-
let n = if cps then n + 1 else n in
210-
register_type (Printf.sprintf "function_%d" n) (fun () ->
211-
return { supertype = None; final = true; typ = W.Func (func_type n) })
206+
let func_type ?(ret = value) n =
207+
{ W.params = List.init ~len:(n + 1) ~f:(fun _ -> value); result = [ ret ] }
208+
209+
let rec function_type ~cps ?ret n =
210+
let n' = if cps then n + 1 else n in
211+
let ret_str =
212+
match ret with
213+
| None -> ""
214+
| Some (W.Ref { nullable = false; typ }) -> (
215+
match typ with
216+
| Eq -> "_eq" (*ZZZ remove ret in that case*)
217+
| I31 -> "_i31"
218+
| Struct -> "_struct"
219+
| Array -> "_array"
220+
| None_ -> "_none"
221+
| Type v -> (
222+
match Code.Var.get_name v with
223+
| None -> assert false
224+
| Some name -> "_" ^ name)
225+
| _ -> assert false)
226+
| _ -> assert false
227+
in
228+
register_type (Printf.sprintf "function_%d%s" n' ret_str) (fun () ->
229+
match ret with
230+
| None -> return { supertype = None; final = false; typ = W.Func (func_type n') }
231+
| Some ret ->
232+
let* super = function_type ~cps n in
233+
return
234+
{ supertype = Some super; final = false; typ = W.Func (func_type ~ret n') })
212235

213236
let closure_common_fields ~cps =
214237
let* fun_ty = function_type ~cps 1 in

0 commit comments

Comments
 (0)