diff --git a/CHANGES.md b/CHANGES.md index fc4afa54cf..3c6be5232f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -42,6 +42,7 @@ * Compiler: deadcode elimination of cyclic values (#1978) * Compiler: directly write Wasm binary modules (#2000, #2003) * Compiler: rewrote inlining pass (#1935, #2018, #2027) +* Compiler/wasm: optimize integer operations (#2032) ## Bug fixes * Compiler: fix stack overflow issues with double translation (#1869) diff --git a/compiler/bin-wasm_of_ocaml/compile.ml b/compiler/bin-wasm_of_ocaml/compile.ml index 5149df29d3..d0f39f6c28 100644 --- a/compiler/bin-wasm_of_ocaml/compile.ml +++ b/compiler/bin-wasm_of_ocaml/compile.ml @@ -245,8 +245,8 @@ let generate_prelude ~out_file = @@ fun ch -> let code, uinfo = Parse_bytecode.predefined_exceptions () in let profile = Profile.O1 in - let Driver.{ program; variable_uses; in_cps; deadcode_sentinal; _ } = - Driver.optimize ~profile code + let Driver.{ program; variable_uses; in_cps; deadcode_sentinal; _ }, global_flow_data = + Driver.optimize_for_wasm ~profile code in let context = Generate.start () in let _ = @@ -256,6 +256,7 @@ let generate_prelude ~out_file = ~live_vars:variable_uses ~in_cps ~deadcode_sentinal + ~global_flow_data program in Generate.wasm_output ch ~opt_source_map_file:None ~context; @@ -397,8 +398,9 @@ let run check_debug one; let code = one.code in let standalone = Option.is_none unit_name in - let Driver.{ program; variable_uses; in_cps; deadcode_sentinal; _ } = - Driver.optimize ~profile code + let Driver.{ program; variable_uses; in_cps; deadcode_sentinal; _ }, global_flow_data + = + Driver.optimize_for_wasm ~profile code in let context = Generate.start () in let toplevel_name, generated_js = @@ -408,6 +410,7 @@ let run ~live_vars:variable_uses ~in_cps ~deadcode_sentinal + ~global_flow_data program in if standalone then Generate.add_start_function ~context toplevel_name; diff --git a/compiler/lib-wasm/closure_conversion.ml b/compiler/lib-wasm/closure_conversion.ml index 8f30e2fc90..88a5cb337d 100644 --- a/compiler/lib-wasm/closure_conversion.ml +++ b/compiler/lib-wasm/closure_conversion.ml @@ -22,6 +22,7 @@ open Code type closure = { functions : (Var.t * int) list ; free_variables : Var.t list + ; mutable id : int option } module SCC = Strongly_connected_components.Make (Var) @@ -144,7 +145,8 @@ let rec traverse var_depth closures program pc depth = in List.iter ~f:(fun (f, _) -> - closures := Var.Map.add f { functions; free_variables } !closures) + closures := + Var.Map.add f { functions; free_variables; id = None } !closures) functions; fun_lst) components diff --git a/compiler/lib-wasm/closure_conversion.mli b/compiler/lib-wasm/closure_conversion.mli index 41a5e0642c..f042f1806f 100644 --- a/compiler/lib-wasm/closure_conversion.mli +++ b/compiler/lib-wasm/closure_conversion.mli @@ -19,6 +19,7 @@ type closure = { functions : (Code.Var.t * int) list ; free_variables : Code.Var.t list + ; mutable id : int option } val f : Code.program -> Code.program * closure Code.Var.Map.t diff --git a/compiler/lib-wasm/code_generation.ml b/compiler/lib-wasm/code_generation.ml index 7e7cb6af95..4322356722 100644 --- a/compiler/lib-wasm/code_generation.ml +++ b/compiler/lib-wasm/code_generation.ml @@ -34,6 +34,7 @@ https://github.com/llvm/llvm-project/issues/58438 type constant_global = { init : W.expression option ; constant : bool + ; typ : W.value_type } type context = @@ -46,6 +47,7 @@ type context = ; types : Wasm_ast.type_field Var.Hashtbl.t ; mutable closure_envs : Var.t Var.Map.t (** GC: mapping of recursive functions to their shared environment *) + ; closure_types : (W.value_type option list, int) Hashtbl.t ; mutable apply_funs : Var.t IntMap.t ; mutable cps_apply_funs : Var.t IntMap.t ; mutable curry_funs : Var.t IntMap.t @@ -68,6 +70,7 @@ let make_context ~value_type = ; type_names = String.Hashtbl.create 128 ; types = Var.Hashtbl.create 128 ; closure_envs = Var.Map.empty + ; closure_types = Poly.Hashtbl.create 128 ; apply_funs = IntMap.empty ; cps_apply_funs = IntMap.empty ; curry_funs = IntMap.empty @@ -198,6 +201,7 @@ let register_global name ?exported_name ?(constant = false) typ init st = name { init = (if not typ.mut then Some init else None) ; constant = (not typ.mut) || constant + ; typ = typ.typ } st.context.constant_globals; (), st @@ -413,76 +417,73 @@ let is_small_constant e = | W.GlobalGet name -> global_is_constant name | _ -> return false -let un_op_is_smi op = - match op with - | W.Clz | Ctz | Popcnt | Eqz -> true - | TruncSatF64 _ | ReinterpretF -> false +let load x = + let* x = var x in + match x with + | Local (_, x, _) -> return (W.LocalGet x) + | Expr e -> e -let bin_op_is_smi (op : W.int_bin_op) = - match op with - | W.Add | Sub | Mul | Div _ | Rem _ | And | Or | Xor | Shl | Shr _ | Rotl | Rotr -> - false - | Eq | Ne | Lt _ | Gt _ | Le _ | Ge _ -> true +let rec variable_type x st = + match Var.Map.find_opt x st.vars with + | Some (Local (_, _, typ)) -> typ, st + | Some (Expr e) -> + (let* e = e in + expression_type e) + st + | None -> None, st -let rec is_smi e = +and expression_type (e : W.expression) st = match e with - | W.Const (I32 i) -> Int32.equal (Arith.wrap31 i) i - | UnOp ((I32 op | I64 op), _) -> un_op_is_smi op - | BinOp ((I32 op | I64 op), _, _) -> bin_op_is_smi op - | I31Get (S, _) -> true - | I31Get (U, _) - | Const (I64 _ | F32 _ | F64 _) - | UnOp ((F32 _ | F64 _), _) + | Const _ + | UnOp _ + | BinOp _ | I32WrapI64 _ | I64ExtendI32 _ | F32DemoteF64 _ | F64PromoteF32 _ - | LocalGet _ - | LocalTee _ - | GlobalGet _ | BlockExpr _ | Call _ - | Seq _ - | Pop _ | RefFunc _ | Call_ref _ - | RefI31 _ - | ArrayNew _ - | ArrayNewFixed _ - | ArrayNewData _ + | I31Get _ | ArrayGet _ | ArrayLen _ - | StructNew _ - | StructGet _ - | RefCast _ + | RefTest _ + | RefEq _ | RefNull _ - | Br_on_cast _ - | Br_on_cast_fail _ - | Br_on_null _ | Try _ - | ExternConvertAny _ - | AnyConvertExtern _ -> false - | BinOp ((F32 _ | F64 _), _, _) | RefTest _ | RefEq _ -> true - | IfExpr (_, _, ift, iff) -> is_smi ift && is_smi iff - -let get_i31_value x st = - match st.instrs with - | LocalSet (x', RefI31 e) :: rem when Code.Var.equal x x' && is_smi e -> - let x = Var.fresh () in - let x, st = add_var ~typ:I32 x st in - Some x, { st with instrs = LocalSet (x', RefI31 (LocalTee (x, e))) :: rem } - | Event loc :: LocalSet (x', RefI31 e) :: rem when Code.Var.equal x x' && is_smi e -> - let x = Var.fresh () in - let x, st = add_var ~typ:I32 x st in - ( Some x - , { st with instrs = Event loc :: LocalSet (x', RefI31 (LocalTee (x, e))) :: rem } ) - | _ -> None, st - -let load x = - let* x = var x in - match x with - | Local (_, x, _) -> return (W.LocalGet x) - | Expr e -> e + | Br_on_null _ -> None, st + | LocalGet x | LocalTee (x, _) -> variable_type x st + | GlobalGet x -> + ( (try + let typ = (Var.Map.find x st.context.constant_globals).typ in + if Poly.equal typ st.context.value_type + then None + else + Some + (match typ with + | Ref { typ; nullable = true } -> Ref { typ; nullable = false } + | _ -> typ) + with Not_found -> None) + , st ) + | Seq (_, e') -> expression_type e' st + | Pop typ -> Some typ, st + | RefI31 _ -> Some (Ref { nullable = false; typ = I31 }), st + | ArrayNew (ty, _, _) + | ArrayNewFixed (ty, _) + | ArrayNewData (ty, _, _, _) + | StructNew (ty, _) -> Some (Ref { nullable = false; typ = Type ty }), st + | StructGet (_, ty, i, _) -> ( + match (Var.Hashtbl.find st.context.types ty).typ with + | Struct l -> ( + match (List.nth l i).typ with + | Value typ -> + (if Poly.equal typ st.context.value_type then None else Some typ), st + | Packed _ -> assert false) + | Array _ | Func _ -> assert false) + | RefCast (typ, _) | Br_on_cast (_, _, typ, _) | Br_on_cast_fail (_, typ, _, _) -> + Some (Ref typ), st + | IfExpr (_, _, _, _) | ExternConvertAny _ | AnyConvertExtern _ -> None, st let tee ?typ x e = let* e = e in @@ -499,6 +500,47 @@ let should_make_global x st = Var.Set.mem x st.context.globalized_variables, st let value_type st = st.context.value_type, st +let get_constant x st = Var.Hashtbl.find_opt st.context.constants x, st + +let placeholder_value typ f = + let* c = get_constant typ in + match c with + | None -> + let x = Var.fresh () in + let* () = register_constant typ (W.GlobalGet x) in + let* () = + register_global + ~constant:true + x + { mut = false; typ = Ref { nullable = false; typ = Type typ } } + (f typ) + in + return (W.GlobalGet x) + | Some c -> return c + +let array_placeholder typ = placeholder_value typ (fun typ -> ArrayNewFixed (typ, [])) + +let default_value val_typ st = + match val_typ with + | W.Ref { typ = I31 | Eq | Any; _ } -> (W.RefI31 (Const (I32 0l)), val_typ, None), st + | W.Ref { typ = Type typ; nullable = false } -> ( + match (Var.Hashtbl.find st.context.types typ).typ with + | Array _ -> + (let* placeholder = array_placeholder typ in + return (placeholder, val_typ, None)) + st + | Struct _ | Func _ -> + ( ( W.RefNull (Type typ) + , W.Ref { typ = Type typ; nullable = true } + , Some { W.typ = Type typ; nullable = false } ) + , st )) + | I32 -> (Const (I32 0l), val_typ, None), st + | F32 -> (Const (F32 0.), val_typ, None), st + | I64 -> (Const (I64 0L), val_typ, None), st + | F64 -> (Const (F64 0.), val_typ, None), st + | W.Ref { nullable = true; _ } + | W.Ref { typ = Func | Extern | Struct | Array | None_; _ } -> assert false + let rec store ?(always = false) ?typ x e = let* e = e in match e with @@ -513,23 +555,26 @@ let rec store ?(always = false) ?typ x e = let* b = should_make_global x in if b then - let* typ = - match typ with - | Some typ -> return typ - | None -> value_type - in let* () = let* b = global_is_registered x in if b then return () else - register_global - ~constant:true - x - { mut = true; typ } - (W.RefI31 (Const (I32 0l))) + let* typ = + match typ with + | Some typ -> return typ + | None -> value_type + in + let* default, typ', cast = default_value typ in + let* () = + register_constant + x + (match cast with + | Some typ -> W.RefCast (typ, W.GlobalGet x) + | None -> W.GlobalGet x) + in + register_global ~constant:true x { mut = true; typ = typ' } default in - let* () = register_constant x (W.GlobalGet x) in instr (GlobalSet (x, e)) else let* i = add_var ?typ x in diff --git a/compiler/lib-wasm/code_generation.mli b/compiler/lib-wasm/code_generation.mli index bb72950262..8655450dda 100644 --- a/compiler/lib-wasm/code_generation.mli +++ b/compiler/lib-wasm/code_generation.mli @@ -30,6 +30,7 @@ type context = ; types : Wasm_ast.type_field Code.Var.Hashtbl.t ; mutable closure_envs : Code.Var.t Code.Var.Map.t (** GC: mapping of recursive functions to their shared environment *) + ; closure_types : (Wasm_ast.value_type option list, int) Hashtbl.t ; mutable apply_funs : Code.Var.t Stdlib.IntMap.t ; mutable cps_apply_funs : Code.Var.t Stdlib.IntMap.t ; mutable curry_funs : Code.Var.t Stdlib.IntMap.t @@ -57,7 +58,7 @@ val instr : Wasm_ast.instruction -> unit t val seq : unit t -> expression -> expression -val expression_list : ('a -> expression) -> 'a list -> Wasm_ast.expression list t +val expression_list : ('a -> 'b t) -> 'a list -> 'b list t module Arith : sig val const : int32 -> expression @@ -138,8 +139,6 @@ val define_var : Wasm_ast.var -> expression -> unit t val is_small_constant : Wasm_ast.expression -> bool t -val get_i31_value : Wasm_ast.var -> Wasm_ast.var option t - val event : Parse_info.t -> unit t val no_event : unit t @@ -198,3 +197,11 @@ val function_body : -> param_names:Code.Var.t list -> body:unit t -> (Wasm_ast.var * Wasm_ast.value_type) list * Wasm_ast.instruction list + +val variable_type : Code.Var.t -> Wasm_ast.value_type option t + +val array_placeholder : Code.Var.t -> expression + +val default_value : + Wasm_ast.value_type + -> (Wasm_ast.expression * Wasm_ast.value_type * Wasm_ast.ref_type option) t diff --git a/compiler/lib-wasm/curry.ml b/compiler/lib-wasm/curry.ml index c39dcb6910..b6d5ab0cab 100644 --- a/compiler/lib-wasm/curry.ml +++ b/compiler/lib-wasm/curry.ml @@ -298,6 +298,7 @@ module Make (Target : Target_sig.S) = struct Memory.allocate ~tag:0 ~deadcode_sentinal:(Code.Var.fresh ()) + ~load (List.map ~f:(fun x -> `Var x) (List.tl l)) in let* make_iterator = diff --git a/compiler/lib-wasm/gc_target.ml b/compiler/lib-wasm/gc_target.ml index 112e04fb7e..686b62d096 100644 --- a/compiler/lib-wasm/gc_target.ml +++ b/compiler/lib-wasm/gc_target.ml @@ -281,11 +281,19 @@ module Type = struct ]) }) - let env_type ~cps ~arity n = + let make_env_type env_type = + List.map + ~f:(fun typ -> + { W.mut = false + ; typ = W.Value (Option.value ~default:(W.Ref { nullable = false; typ = Eq }) typ) + }) + env_type + + let env_type ~cps ~arity ~env_type_id ~env_type = register_type (if cps - then Printf.sprintf "cps_env_%d_%d" arity n - else Printf.sprintf "env_%d_%d" arity n) + then Printf.sprintf "cps_env_%d_%d" arity env_type_id + else Printf.sprintf "env_%d_%d" arity env_type_id) (fun () -> let* cl_typ = closure_type ~usage:`Alloc ~cps arity in let* common = closure_common_fields ~cps in @@ -309,18 +317,11 @@ module Type = struct ; typ = Value (Ref { nullable = false; typ = Type fun_ty' }) } ]) - @ List.init - ~f:(fun _ -> - { W.mut = false - ; typ = W.Value (Ref { nullable = false; typ = Eq }) - }) - ~len:n) + @ make_env_type env_type) }) - let rec_env_type ~function_count ~free_variable_count = - register_type - (Printf.sprintf "rec_env_%d_%d" function_count free_variable_count) - (fun () -> + let rec_env_type ~function_count ~env_type_id ~env_type = + register_type (Printf.sprintf "rec_env_%d_%d" function_count env_type_id) (fun () -> return { supertype = None ; final = true @@ -331,24 +332,20 @@ module Type = struct { W.mut = i < function_count ; typ = W.Value (Ref { nullable = false; typ = Eq }) }) - ~len:(function_count + free_variable_count)) + ~len:function_count + @ make_env_type env_type) }) - let rec_closure_type ~cps ~arity ~function_count ~free_variable_count = + let rec_closure_type ~cps ~arity ~function_count ~env_type_id ~env_type = register_type (if cps - then - Printf.sprintf - "cps_closure_rec_%d_%d_%d" - arity - function_count - free_variable_count - else Printf.sprintf "closure_rec_%d_%d_%d" arity function_count free_variable_count) + then Printf.sprintf "cps_closure_rec_%d_%d_%d" arity function_count env_type_id + else Printf.sprintf "closure_rec_%d_%d_%d" arity function_count env_type_id) (fun () -> let* cl_typ = closure_type ~usage:`Alloc ~cps arity in let* common = closure_common_fields ~cps in let* fun_ty' = function_type ~cps arity in - let* env_ty = rec_env_type ~function_count ~free_variable_count in + let* env_ty = rec_env_type ~function_count ~env_type_id ~env_type in return { supertype = Some cl_typ ; final = true @@ -431,7 +428,7 @@ module Value = struct let dummy_block = let* t = Type.block_type in - return (W.ArrayNewFixed (t, [])) + array_placeholder t let as_block e = let* t = Type.block_type in @@ -446,25 +443,17 @@ module Value = struct let check_is_not_zero i = let* i = i in - match i with - | W.LocalGet x -> ( - let* x_opt = get_i31_value x in - match x_opt with - | Some x' -> return (W.LocalGet x') - | None -> return (W.UnOp (I32 Eqz, RefEq (i, W.RefI31 (Const (I32 0l)))))) - | _ -> return (W.UnOp (I32 Eqz, RefEq (i, W.RefI31 (Const (I32 0l))))) + return (W.UnOp (I32 Eqz, RefEq (i, W.RefI31 (Const (I32 0l))))) let check_is_int i = let* i = i in return (W.RefTest ({ nullable = false; typ = I31 }, i)) - let not i = val_int (Arith.eqz (int_val i)) - - let binop op i i' = val_int (op (int_val i) (int_val i')) + let not i = Arith.eqz i - let lt = binop Arith.( < ) + let lt = Arith.( < ) - let le = binop Arith.( <= ) + let le = Arith.( <= ) let ref_eq i i' = let* i = i in @@ -574,41 +563,41 @@ module Value = struct (let* () = store xv x in let* () = store yv y in return ()) - (val_int (if negate then Arith.eqz n else n)) + (if negate then Arith.eqz n else n) let eq x y = eq_gen ~negate:false x y let neq x y = eq_gen ~negate:true x y - let ult = binop Arith.(ult) + let ult = Arith.ult let is_int i = let* i = i in - val_int (return (W.RefTest ({ nullable = false; typ = I31 }, i))) + return (W.RefTest ({ nullable = false; typ = I31 }, i)) - let int_add = binop Arith.( + ) + let int_add = Arith.( + ) - let int_sub = binop Arith.( - ) + let int_sub = Arith.( - ) - let int_mul = binop Arith.( * ) + let int_mul = Arith.( * ) - let int_div = binop Arith.( / ) + let int_div = Arith.( / ) - let int_mod = binop Arith.( mod ) + let int_mod = Arith.( mod ) - let int_neg i = val_int Arith.(const 0l - int_val i) + let int_neg i = Arith.(const 0l - i) - let int_or = binop Arith.( lor ) + let int_or = Arith.( lor ) - let int_and = binop Arith.( land ) + let int_and = Arith.( land ) - let int_xor = binop Arith.( lxor ) + let int_xor = Arith.( lxor ) - let int_lsl = binop Arith.( lsl ) + let int_lsl = Arith.( lsl ) - let int_lsr i i' = val_int Arith.((int_val i land const 0x7fffffffl) lsr int_val i') + let int_lsr i i' = Arith.((i land const 0x7fffffffl) lsr i') - let int_asr = binop Arith.( asr ) + let int_asr = Arith.( asr ) end module Memory = struct @@ -660,7 +649,7 @@ module Memory = struct let* ty = Type.float_type in wasm_struct_get ty (wasm_cast ty e) 0 - let allocate ~tag ~deadcode_sentinal l = + let allocate ~tag ~deadcode_sentinal ~load l = if tag = 254 then let* l = @@ -731,15 +720,14 @@ module Memory = struct let* e = float_array_length (load a) in instr (W.Push e)) - let array_get e e' = wasm_array_get e Arith.(Value.int_val e' + const 1l) + let array_get e e' = wasm_array_get e Arith.(e' + const 1l) - let array_set e e' e'' = wasm_array_set e Arith.(Value.int_val e' + const 1l) e'' + let array_set e e' e'' = wasm_array_set e Arith.(e' + const 1l) e'' - let float_array_get e e' = - box_float (wasm_array_get ~ty:Type.float_array_type e (Value.int_val e')) + let float_array_get e e' = box_float (wasm_array_get ~ty:Type.float_array_type e e') let float_array_set e e' e'' = - wasm_array_set ~ty:Type.float_array_type e (Value.int_val e') (unbox_float e'') + wasm_array_set ~ty:Type.float_array_type e e' (unbox_float e'') let gen_array_get e e' = let a = Code.Var.fresh_n "a" in @@ -747,7 +735,7 @@ module Memory = struct block_expr { params = []; result = [ Type.value ] } (let* () = store a e in - let* () = store ~typ:I32 i (Value.int_val e') in + let* () = store ~typ:I32 i e' in let* () = drop (block_expr @@ -774,7 +762,7 @@ module Memory = struct let i = Code.Var.fresh_n "i" in let v = Code.Var.fresh_n "v" in let* () = store a e in - let* () = store ~typ:I32 i (Value.int_val e') in + let* () = store ~typ:I32 i e' in let* () = store v e'' in block { params = []; result = [] } @@ -804,11 +792,9 @@ module Memory = struct let* e = wasm_cast ty e in return (W.ArrayLen e) - let bytes_get e e' = - Value.val_int (wasm_array_get ~ty:Type.string_type e (Value.int_val e')) + let bytes_get e e' = wasm_array_get ~ty:Type.string_type e e' - let bytes_set e e' e'' = - wasm_array_set ~ty:Type.string_type e (Value.int_val e') (Value.int_val e'') + let bytes_set e e' e'' = wasm_array_set ~ty:Type.string_type e e' e'' let field e idx = wasm_array_get e (Arith.const (Int32.of_int (idx + 1))) @@ -1035,23 +1021,26 @@ module Constant = struct return (Const, e) let translate c = - let* const, c = translate_rec c in - match const with - | Const -> - let* b = is_small_constant c in - if b then return c else store_in_global c - | Const_named name -> store_in_global ~name c - | Mutated -> - let name = Code.Var.fresh_n "const" in - let* () = - register_global - ~constant:true - name - { mut = true; typ = Type.value } - (W.RefI31 (Const (I32 0l))) - in - let* () = register_init_code (instr (W.GlobalSet (name, c))) in - return (W.GlobalGet name) + match c with + | Code.Int i -> return (W.Const (I32 (Targetint.to_int32 i))) + | _ -> ( + let* const, c = translate_rec c in + match const with + | Const -> + let* b = is_small_constant c in + if b then return c else store_in_global c + | Const_named name -> store_in_global ~name c + | Mutated -> + let name = Code.Var.fresh_n "const" in + let* () = + register_global + ~constant:true + name + { mut = true; typ = Type.value } + (W.RefI31 (Const (I32 0l))) + in + let* () = register_init_code (instr (W.GlobalSet (name, c))) in + return (W.GlobalGet name)) end module Closure = struct @@ -1099,11 +1088,19 @@ module Closure = struct in return (W.GlobalGet name) else - let free_variable_count = List.length free_variables in + let* env_type = expression_list variable_type free_variables in + let env_type_id = + try Hashtbl.find context.closure_types env_type + with Not_found -> + let id = Hashtbl.length context.closure_types in + Hashtbl.add context.closure_types env_type id; + id + in + info.id <- Some env_type_id; match info.Closure_conversion.functions with | [] -> assert false | [ _ ] -> - let* typ = Type.env_type ~cps ~arity free_variable_count in + let* typ = Type.env_type ~cps ~arity ~env_type_id ~env_type in let* l = expression_list load free_variables in return (W.StructNew @@ -1122,7 +1119,7 @@ module Closure = struct @ l )) | (g, _) :: _ as functions -> let function_count = List.length functions in - let* env_typ = Type.rec_env_type ~function_count ~free_variable_count in + let* env_typ = Type.rec_env_type ~function_count ~env_type_id ~env_type in let env = if Code.Var.equal f g then @@ -1144,7 +1141,7 @@ module Closure = struct load env in let* typ = - Type.rec_closure_type ~cps ~arity ~function_count ~free_variable_count + Type.rec_closure_type ~cps ~arity ~function_count ~env_type_id ~env_type in let res = let* env = env in @@ -1189,12 +1186,13 @@ module Closure = struct let* _ = add_var (Code.Var.fresh ()) in return () else + let env_type_id = Option.value ~default:(-1) info.id in let _, arity = List.find ~f:(fun (f', _) -> Code.Var.equal f f') info.functions in let arity = if cps then arity - 1 else arity in let offset = Memory.env_start arity in match info.Closure_conversion.functions with | [ _ ] -> - let* typ = Type.env_type ~cps ~arity free_variable_count in + let* typ = Type.env_type ~cps ~arity ~env_type_id ~env_type:[] in let* _ = add_var f in let env = Code.Var.fresh_n "env" in let* () = @@ -1214,11 +1212,11 @@ module Closure = struct | functions -> let function_count = List.length functions in let* typ = - Type.rec_closure_type ~cps ~arity ~function_count ~free_variable_count + Type.rec_closure_type ~cps ~arity ~function_count ~env_type_id ~env_type:[] in let* _ = add_var f in let env = Code.Var.fresh_n "env" in - let* env_typ = Type.rec_env_type ~function_count ~free_variable_count in + let* env_typ = Type.rec_env_type ~function_count ~env_type_id ~env_type:[] in let* () = store ~typ:(W.Ref { nullable = false; typ = Type env_typ }) diff --git a/compiler/lib-wasm/generate.ml b/compiler/lib-wasm/generate.ml index 505c4d2a2b..d65a71d04d 100644 --- a/compiler/lib-wasm/generate.ml +++ b/compiler/lib-wasm/generate.ml @@ -36,6 +36,7 @@ module Generate (Target : Target_sig.S) = struct { live : int array ; in_cps : Effects.in_cps ; deadcode_sentinal : Var.t + ; types : Typing.typ Var.Tbl.t ; blocks : block Addr.Map.t ; closures : Closure_conversion.closure Var.Map.t ; global_context : Code_generation.context @@ -144,7 +145,7 @@ module Generate (Target : Target_sig.S) = struct let float_comparison op f g = let* f = Memory.unbox_float f in let* g = Memory.unbox_float g in - Value.val_int (return (W.BinOp (F64 op, f, g))) + return (W.BinOp (F64 op, f, g)) let int32_bin_op op f g = let* f = Memory.unbox_int32 f in @@ -153,7 +154,7 @@ module Generate (Target : Target_sig.S) = struct let int32_shift_op op f g = let* f = Memory.unbox_int32 f in - let* g = Value.int_val g in + let* g = g in Memory.box_int32 (return (W.BinOp (I32 op, f, g))) let int64_bin_op op f g = @@ -163,7 +164,7 @@ module Generate (Target : Target_sig.S) = struct let int64_shift_op op f g = let* f = Memory.unbox_int64 f in - let* g = Value.int_val g in + let* g = g in Memory.box_int64 (return (W.BinOp (I64 op, f, I64ExtendI32 (S, g)))) let nativeint_bin_op op f g = @@ -173,19 +174,62 @@ module Generate (Target : Target_sig.S) = struct let nativeint_shift_op op f g = let* f = Memory.unbox_nativeint f in - let* g = Value.int_val g in + let* g = g in Memory.box_nativeint (return (W.BinOp (I32 op, f, g))) - let transl_prim_arg x = - match x with - | Pv x -> load x - | Pc c -> Constant.translate c + let get_var_type ctx x = Var.Tbl.get ctx.types x + + let get_type ctx p = + match p with + | Pv x -> get_var_type ctx x + | Pc c -> Typing.constant_type c + + let convert ~(from : Typing.typ) ~(into : Typing.typ) e = + match from, into with + | Int Unnormalized, Int Normalized -> Arith.((e lsl const 1l) asr const 1l) + | Int (Normalized | Unnormalized), Int (Normalized | Unnormalized) -> e + | _, Int (Normalized | Unnormalized) -> Value.int_val e + | Int (Unnormalized | Normalized), _ -> Value.val_int e + | _ -> e + + let load_and_box ctx x = convert ~from:(get_var_type ctx x) ~into:Top (load x) + + let transl_prim_arg ctx ?(typ = Typing.Top) x = + convert + ~from:(get_type ctx x) + ~into:typ + (match x with + | Pv x -> load x + | Pc c -> Constant.translate c) + + let translate_int_comparison ctx op x y = + match get_type ctx x, get_type ctx y with + | Int Unnormalized, Int Unnormalized + | Int Normalized, Int Unnormalized + | Int Unnormalized, Int Normalized -> + op + Arith.(transl_prim_arg ctx ~typ:(Int Unnormalized) x lsl const 1l) + Arith.(transl_prim_arg ctx ~typ:(Int Unnormalized) y lsl const 1l) + | _ -> + op + (transl_prim_arg ctx ~typ:(Int Normalized) x) + (transl_prim_arg ctx ~typ:(Int Normalized) y) + + let translate_int_equality ctx op op' x y = + match get_type ctx x, get_type ctx y with + | (Int Normalized as typ), Int Normalized -> + op (transl_prim_arg ctx ~typ x) (transl_prim_arg ctx ~typ y) + | Int (Normalized | Unnormalized), Int (Normalized | Unnormalized) -> + op + Arith.(transl_prim_arg ctx ~typ:(Int Unnormalized) x lsl const 1l) + Arith.(transl_prim_arg ctx ~typ:(Int Unnormalized) y lsl const 1l) + | _ -> op' (transl_prim_arg ctx ~typ:Top x) (transl_prim_arg ctx ~typ:Top y) let internal_primitives = let h = String.Hashtbl.create 128 in List.iter ~f:(fun (nm, k, f) -> - String.Hashtbl.add h nm (k, fun _ _ transl_prim_arg l -> f transl_prim_arg l)) + String.Hashtbl.add h nm (k, fun ctx _ l -> f (fun x -> transl_prim_arg ctx x) l)) internal_primitives; h @@ -199,108 +243,212 @@ module Generate (Target : Target_sig.S) = struct expected (List.length l)) - let register_un_prim name k f = - register_prim name k (fun _ _ transl_prim_arg l -> + let register_un_prim name k ?typ f = + register_prim name k (fun ctx _ l -> match l with - | [ x ] -> f (transl_prim_arg x) + | [ x ] -> f (transl_prim_arg ctx ?typ x) | l -> invalid_arity name l ~expected:1) - let register_bin_prim name k f = - register_prim name k (fun _ _ transl_prim_arg l -> + let register_bin_prim name k ?tx ?ty f = + register_prim name k (fun ctx _ l -> match l with - | [ x; y ] -> f (transl_prim_arg x) (transl_prim_arg y) + | [ x; y ] -> f (transl_prim_arg ctx ?typ:tx x) (transl_prim_arg ctx ?typ:ty y) | _ -> invalid_arity name l ~expected:2) - let register_bin_prim_ctx name f = - register_prim name `Mutator (fun _ context transl_prim_arg l -> + let register_bin_prim_ctx name ?tx ?ty f = + register_prim name `Mutator (fun ctx context l -> match l with - | [ x; y ] -> f context (transl_prim_arg x) (transl_prim_arg y) + | [ x; y ] -> + f context (transl_prim_arg ctx ?typ:tx x) (transl_prim_arg ctx ?typ:ty y) | _ -> invalid_arity name l ~expected:2) - let register_tern_prim name f = - register_prim name `Mutator (fun _ _ transl_prim_arg l -> + let register_tern_prim name ?ty ?tz f = + register_prim name `Mutator (fun ctx _ l -> match l with - | [ x; y; z ] -> f (transl_prim_arg x) (transl_prim_arg y) (transl_prim_arg z) + | [ x; y; z ] -> + f + (transl_prim_arg ctx x) + (transl_prim_arg ctx ?typ:ty y) + (transl_prim_arg ctx ?typ:tz z) | _ -> invalid_arity name l ~expected:3) - let register_tern_prim_ctx name f = - register_prim name `Mutator (fun _ context transl_prim_arg l -> + let register_tern_prim_ctx name ?ty ?tz f = + register_prim name `Mutator (fun ctx context l -> match l with | [ x; y; z ] -> - f context (transl_prim_arg x) (transl_prim_arg y) (transl_prim_arg z) + f + context + (transl_prim_arg ctx x) + (transl_prim_arg ctx ?typ:ty y) + (transl_prim_arg ctx ?typ:tz z) | _ -> invalid_arity name l ~expected:3) let () = - register_bin_prim "caml_array_unsafe_get" `Mutable Memory.gen_array_get; - register_bin_prim "caml_floatarray_unsafe_get" `Mutable Memory.float_array_get; - register_tern_prim "caml_array_unsafe_set" (fun x y z -> + register_bin_prim + "caml_array_unsafe_get" + `Mutable + ~ty:(Int Normalized) + Memory.gen_array_get; + register_bin_prim + "caml_floatarray_unsafe_get" + `Mutable + ~ty:(Int Normalized) + Memory.float_array_get; + register_tern_prim "caml_array_unsafe_set" ~ty:(Int Normalized) (fun x y z -> seq (Memory.gen_array_set x y z) Value.unit); - register_tern_prim "caml_array_unsafe_set_addr" (fun x y z -> + register_tern_prim "caml_array_unsafe_set_addr" ~ty:(Int Normalized) (fun x y z -> seq (Memory.array_set x y z) Value.unit); - register_tern_prim "caml_floatarray_unsafe_set" (fun x y z -> + register_tern_prim "caml_floatarray_unsafe_set" ~ty:(Int Normalized) (fun x y z -> seq (Memory.float_array_set x y z) Value.unit); - register_bin_prim "caml_string_unsafe_get" `Pure Memory.bytes_get; - register_bin_prim "caml_bytes_unsafe_get" `Mutable Memory.bytes_get; - register_tern_prim "caml_string_unsafe_set" (fun x y z -> - seq (Memory.bytes_set x y z) Value.unit); - register_tern_prim "caml_bytes_unsafe_set" (fun x y z -> - seq (Memory.bytes_set x y z) Value.unit); + register_bin_prim "caml_string_unsafe_get" `Pure ~ty:(Int Normalized) Memory.bytes_get; + register_bin_prim + "caml_bytes_unsafe_get" + `Mutable + ~ty:(Int Normalized) + Memory.bytes_get; + register_tern_prim + "caml_string_unsafe_set" + ~ty:(Int Normalized) + ~tz:(Int Unnormalized) + (fun x y z -> seq (Memory.bytes_set x y z) Value.unit); + register_tern_prim + "caml_bytes_unsafe_set" + ~ty:(Int Normalized) + ~tz:(Int Unnormalized) + (fun x y z -> seq (Memory.bytes_set x y z) Value.unit); let bytes_get context x y = seq - (let* cond = Arith.uge (Value.int_val y) (Memory.bytes_length x) in + (let* cond = Arith.uge y (Memory.bytes_length x) in instr (W.Br_if (label_index context bound_error_pc, cond))) (Memory.bytes_get x y) in - register_bin_prim_ctx "caml_string_get" bytes_get; - register_bin_prim_ctx "caml_bytes_get" bytes_get; + register_bin_prim_ctx "caml_string_get" ~ty:(Int Normalized) bytes_get; + register_bin_prim_ctx "caml_bytes_get" ~ty:(Int Normalized) bytes_get; let bytes_set context x y z = seq - (let* cond = Arith.uge (Value.int_val y) (Memory.bytes_length x) in + (let* cond = Arith.uge y (Memory.bytes_length x) in let* () = instr (W.Br_if (label_index context bound_error_pc, cond)) in Memory.bytes_set x y z) Value.unit in - register_tern_prim_ctx "caml_string_set" bytes_set; - register_tern_prim_ctx "caml_bytes_set" bytes_set; - register_un_prim "caml_ml_string_length" `Pure (fun x -> - Value.val_int (Memory.bytes_length x)); - register_un_prim "caml_ml_bytes_length" `Pure (fun x -> - Value.val_int (Memory.bytes_length x)); - register_bin_prim "%int_add" `Pure Value.int_add; - register_bin_prim "%int_sub" `Pure Value.int_sub; - register_bin_prim "%int_mul" `Pure Value.int_mul; - register_bin_prim "%direct_int_mul" `Pure Value.int_mul; - register_bin_prim "%direct_int_div" `Pure Value.int_div; - register_bin_prim_ctx "%int_div" (fun context x y -> + register_tern_prim_ctx + "caml_string_set" + ~ty:(Int Normalized) + ~tz:(Int Unnormalized) + bytes_set; + register_tern_prim_ctx + "caml_bytes_set" + ~ty:(Int Normalized) + ~tz:(Int Unnormalized) + bytes_set; + register_un_prim "caml_ml_string_length" `Pure (fun x -> Memory.bytes_length x); + register_un_prim "caml_ml_bytes_length" `Pure (fun x -> Memory.bytes_length x); + register_bin_prim + "%int_add" + `Pure + ~tx:(Int Unnormalized) + ~ty:(Int Unnormalized) + Value.int_add; + register_bin_prim + "%int_sub" + `Pure + ~tx:(Int Unnormalized) + ~ty:(Int Unnormalized) + Value.int_sub; + register_bin_prim + "%int_mul" + `Pure + ~tx:(Int Unnormalized) + ~ty:(Int Unnormalized) + Value.int_mul; + register_bin_prim + "%direct_int_mul" + `Pure + ~tx:(Int Unnormalized) + ~ty:(Int Unnormalized) + Value.int_mul; + register_bin_prim + "%direct_int_div" + `Pure + ~tx:(Int Normalized) + ~ty:(Int Normalized) + Value.int_div; + register_bin_prim_ctx + "%int_div" + ~tx:(Int Normalized) + ~ty:(Int Normalized) + (fun context x y -> seq - (let* cond = Arith.eqz (Value.int_val y) in + (let* cond = Arith.eqz y in instr (W.Br_if (label_index context zero_divide_pc, cond))) (Value.int_div x y)); - register_bin_prim "%direct_int_mod" `Pure Value.int_mod; - register_bin_prim_ctx "%int_mod" (fun context x y -> + register_bin_prim + "%direct_int_mod" + `Pure + ~tx:(Int Normalized) + ~ty:(Int Normalized) + Value.int_mod; + register_bin_prim_ctx + "%int_mod" + ~tx:(Int Normalized) + ~ty:(Int Normalized) + (fun context x y -> seq - (let* cond = Arith.eqz (Value.int_val y) in + (let* cond = Arith.eqz y in instr (W.Br_if (label_index context zero_divide_pc, cond))) (Value.int_mod x y)); - register_un_prim "%int_neg" `Pure Value.int_neg; - register_bin_prim "%int_or" `Pure Value.int_or; - register_bin_prim "%int_and" `Pure Value.int_and; - register_bin_prim "%int_xor" `Pure Value.int_xor; - register_bin_prim "%int_lsl" `Pure Value.int_lsl; - register_bin_prim "%int_lsr" `Pure Value.int_lsr; - register_bin_prim "%int_asr" `Pure Value.int_asr; + register_un_prim "%int_neg" `Pure ~typ:(Int Unnormalized) Value.int_neg; + register_bin_prim + "%int_or" + `Pure + ~tx:(Int Unnormalized) + ~ty:(Int Unnormalized) + Value.int_or; + register_bin_prim + "%int_and" + `Pure + ~tx:(Int Unnormalized) + ~ty:(Int Unnormalized) + Value.int_and; + register_bin_prim + "%int_xor" + `Pure + ~tx:(Int Unnormalized) + ~ty:(Int Unnormalized) + Value.int_xor; + register_bin_prim + "%int_lsl" + `Pure + ~tx:(Int Unnormalized) + ~ty:(Int Unnormalized) + Value.int_lsl; + register_bin_prim + "%int_lsr" + `Pure + ~tx:(Int Unnormalized) + ~ty:(Int Unnormalized) + Value.int_lsr; + register_bin_prim + "%int_asr" + `Pure + ~tx:(Int Normalized) + ~ty:(Int Unnormalized) + Value.int_asr; register_un_prim "%direct_obj_tag" `Pure Memory.tag; - register_bin_prim_ctx "caml_check_bound" (fun context x y -> + register_bin_prim_ctx "caml_check_bound" ~ty:(Int Normalized) (fun context x y -> seq - (let* cond = Arith.uge (Value.int_val y) (Memory.array_length x) in + (let* cond = Arith.uge y (Memory.array_length x) in instr (W.Br_if (label_index context bound_error_pc, cond))) x); - register_bin_prim_ctx "caml_check_bound_gen" (fun context x y -> + register_bin_prim_ctx "caml_check_bound_gen" ~ty:(Int Normalized) (fun context x y -> seq - (let* cond = Arith.uge (Value.int_val y) (Memory.gen_array_length x) in + (let* cond = Arith.uge y (Memory.gen_array_length x) in instr (W.Br_if (label_index context bound_error_pc, cond))) x); - register_bin_prim_ctx "caml_check_bound_float" (fun context x y -> + register_bin_prim_ctx + "caml_check_bound_float" + ~ty:(Int Normalized) + (fun context x y -> seq (let a = Code.Var.fresh () in let* () = store a x in @@ -309,7 +457,7 @@ module Generate (Target : Target_sig.S) = struct empty array, and the bound check should fail. *) let* cond = Arith.eqz (Memory.check_is_float_array (load a)) in let* () = instr (W.Br_if (label, cond)) in - let* cond = Arith.uge (Value.int_val y) (Memory.float_array_length (load a)) in + let* cond = Arith.uge y (Memory.float_array_length (load a)) in instr (W.Br_if (label, cond))) x); register_bin_prim "caml_add_float" `Pure (fun f g -> float_bin_op Add f g); @@ -320,7 +468,7 @@ module Generate (Target : Target_sig.S) = struct register_un_prim "caml_signbit_float" `Pure (fun f -> let* f = Memory.unbox_float f in let sign = W.BinOp (F64 CopySign, Const (F64 1.), f) in - Value.val_int (return (W.BinOp (F64 Lt, sign, Const (F64 0.))))); + return (W.BinOp (F64 Lt, sign, Const (F64 0.)))); register_un_prim "caml_neg_float" `Pure (fun f -> float_un_op Neg f); register_un_prim "caml_abs_float" `Pure (fun f -> float_un_op Abs f); register_un_prim "caml_ceil_float" `Pure (fun f -> float_un_op Ceil f); @@ -336,9 +484,9 @@ module Generate (Target : Target_sig.S) = struct register_bin_prim "caml_lt_float" `Pure (fun f g -> float_comparison Lt f g); register_un_prim "caml_int_of_float" `Pure (fun f -> let* f = Memory.unbox_float f in - Value.val_int (return (W.UnOp (I32 (TruncSatF64 S), f)))); - register_un_prim "caml_float_of_int" `Pure (fun n -> - let* n = Value.int_val n in + return (W.UnOp (I32 (TruncSatF64 S), f))); + register_un_prim "caml_float_of_int" `Pure ~typ:(Int Normalized) (fun n -> + let* n = n in Memory.box_float (return (W.UnOp (F64 (Convert (`I32, S)), n)))); register_un_prim "caml_cos_float" `Pure (fun f -> float_un_op' Math.cos f); register_un_prim "caml_sin_float" `Pure (fun f -> float_un_op' Math.sin f); @@ -422,15 +570,18 @@ module Generate (Target : Target_sig.S) = struct (let* i = Memory.unbox_int32 i in let* j = load j' in Memory.box_int32 (return (W.BinOp (I32 (Rem S), i, j))))); - register_bin_prim "caml_int32_shift_left" `Pure (fun i j -> int32_shift_op Shl i j); - register_bin_prim "caml_int32_shift_right" `Pure (fun i j -> + register_bin_prim "caml_int32_shift_left" `Pure ~ty:(Int Unnormalized) (fun i j -> + int32_shift_op Shl i j); + register_bin_prim "caml_int32_shift_right" `Pure ~ty:(Int Unnormalized) (fun i j -> int32_shift_op (Shr S) i j); - register_bin_prim "caml_int32_shift_right_unsigned" `Pure (fun i j -> - int32_shift_op (Shr U) i j); - register_un_prim "caml_int32_to_int" `Pure (fun i -> - Value.val_int (Memory.unbox_int32 i)); - register_un_prim "caml_int32_of_int" `Pure (fun i -> - Memory.box_int32 (Value.int_val i)); + register_bin_prim + "caml_int32_shift_right_unsigned" + `Pure + ~ty:(Int Unnormalized) + (fun i j -> int32_shift_op (Shr U) i j); + register_un_prim "caml_int32_to_int" `Pure (fun i -> Memory.unbox_int32 i); + register_un_prim "caml_int32_of_int" `Pure ~typ:(Int Normalized) (fun i -> + Memory.box_int32 i); register_un_prim "caml_nativeint_of_int32" `Pure (fun i -> Memory.box_nativeint (Memory.unbox_int32 i)); register_un_prim "caml_nativeint_to_int32" `Pure (fun i -> @@ -493,16 +644,20 @@ module Generate (Target : Target_sig.S) = struct (let* i = Memory.unbox_int64 i in let* j = load j' in Memory.box_int64 (return (W.BinOp (I64 (Rem S), i, j))))); - register_bin_prim "caml_int64_shift_left" `Pure (fun i j -> int64_shift_op Shl i j); - register_bin_prim "caml_int64_shift_right" `Pure (fun i j -> + register_bin_prim "caml_int64_shift_left" `Pure ~ty:(Int Unnormalized) (fun i j -> + int64_shift_op Shl i j); + register_bin_prim "caml_int64_shift_right" `Pure ~ty:(Int Unnormalized) (fun i j -> int64_shift_op (Shr S) i j); - register_bin_prim "caml_int64_shift_right_unsigned" `Pure (fun i j -> - int64_shift_op (Shr U) i j); + register_bin_prim + "caml_int64_shift_right_unsigned" + ~ty:(Int Unnormalized) + `Pure + (fun i j -> int64_shift_op (Shr U) i j); register_un_prim "caml_int64_to_int" `Pure (fun i -> let* i = Memory.unbox_int64 i in - Value.val_int (return (W.I32WrapI64 i))); - register_un_prim "caml_int64_of_int" `Pure (fun i -> - let* i = Value.int_val i in + return (W.I32WrapI64 i)); + register_un_prim "caml_int64_of_int" `Pure ~typ:(Int Normalized) (fun i -> + let* i = i in Memory.box_int64 (return (match i with @@ -578,31 +733,38 @@ module Generate (Target : Target_sig.S) = struct (let* i = Memory.unbox_nativeint i in let* j = load j' in Memory.box_nativeint (return (W.BinOp (I32 (Rem S), i, j))))); - register_bin_prim "caml_nativeint_shift_left" `Pure (fun i j -> + register_bin_prim "caml_nativeint_shift_left" `Pure ~ty:(Int Unnormalized) (fun i j -> nativeint_shift_op Shl i j); - register_bin_prim "caml_nativeint_shift_right" `Pure (fun i j -> - nativeint_shift_op (Shr S) i j); - register_bin_prim "caml_nativeint_shift_right_unsigned" `Pure (fun i j -> - nativeint_shift_op (Shr U) i j); - register_un_prim "caml_nativeint_to_int" `Pure (fun i -> - Value.val_int (Memory.unbox_nativeint i)); - register_un_prim "caml_nativeint_of_int" `Pure (fun i -> - Memory.box_nativeint (Value.int_val i)); - register_bin_prim "caml_int_compare" `Pure (fun i j -> - Value.val_int - Arith.( - (Value.int_val j < Value.int_val i) - (Value.int_val i < Value.int_val j))); - register_prim "%js_array" `Pure (fun ctx _ transl_prim_arg l -> + register_bin_prim + "caml_nativeint_shift_right" + `Pure + ~ty:(Int Unnormalized) + (fun i j -> nativeint_shift_op (Shr S) i j); + register_bin_prim + "caml_nativeint_shift_right_unsigned" + `Pure + ~ty:(Int Unnormalized) + (fun i j -> nativeint_shift_op (Shr U) i j); + register_un_prim "caml_nativeint_to_int" `Pure (fun i -> Memory.unbox_nativeint i); + register_un_prim "caml_nativeint_of_int" `Pure ~typ:(Int Normalized) (fun i -> + Memory.box_nativeint i); + register_bin_prim + "caml_int_compare" + `Pure + ~tx:(Int Normalized) + ~ty:(Int Normalized) + (fun i j -> Arith.((j < i) - (i < j))); + register_prim "%js_array" `Pure (fun ctx _ l -> let* l = List.fold_right ~f:(fun x acc -> - let* x = transl_prim_arg x in + let* x = transl_prim_arg ctx x in let* acc = acc in return (`Expr x :: acc)) l ~init:(return []) in - Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal l) + Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal ~load l) let rec translate_expr ctx context x e = match e with @@ -633,7 +795,7 @@ module Generate (Target : Target_sig.S) = struct return (W.Call (g, List.rev (cl :: acc))) | _ -> return (W.Call_ref (ty, funct, List.rev (closure :: acc)))) | x :: r -> - let* x = load x in + let* x = load_and_box ctx x in loop (x :: acc) r in loop [] args @@ -641,18 +803,19 @@ module Generate (Target : Target_sig.S) = struct let* apply = need_apply_fun ~cps:(Var.Set.mem x ctx.in_cps) ~arity:(List.length args) in - let* args = expression_list load args in + let* args = expression_list (fun x -> load_and_box ctx x) args in let* closure = load f in return (W.Call (apply, args @ [ closure ])) | Block (tag, a, _, _) -> Memory.allocate ~deadcode_sentinal:ctx.deadcode_sentinal ~tag + ~load:(fun x -> load_and_box ctx x) (List.map ~f:(fun x -> `Var x) (Array.to_list a)) - | Field (x, n, Non_float) -> Memory.field (load x) n + | Field (x, n, Non_float) -> Memory.field (load_and_box ctx x) n | Field (x, n, Float) -> Memory.float_array_get - (load x) + (load_and_box ctx x) (Constant.translate (Int (Targetint.of_int_warning_on_overflow n))) | Closure _ -> Closure.translate @@ -686,7 +849,7 @@ module Generate (Target : Target_sig.S) = struct in return (W.GlobalGet x) | Prim (Extern "caml_set_global", [ Pc (String name); v ]) -> - let v = transl_prim_arg v in + let v = transl_prim_arg ctx v in let x = Var.fresh_n name in let* () = let* typ = Value.block_type in @@ -697,17 +860,22 @@ module Generate (Target : Target_sig.S) = struct (let* v = Value.as_block v in instr (W.GlobalSet (x, v))) Value.unit + | Prim (Not, [ x ]) -> Value.not (transl_prim_arg ctx ~typ:(Int Unnormalized) x) + | Prim (Lt, [ x; y ]) -> translate_int_comparison ctx Arith.( < ) x y + | Prim (Le, [ x; y ]) -> translate_int_comparison ctx Arith.( <= ) x y + | Prim (Ult, [ x; y ]) -> translate_int_comparison ctx Arith.ult x y + | Prim (Eq, [ x; y ]) -> translate_int_equality ctx Arith.( = ) Value.eq x y + | Prim (Neq, [ x; y ]) -> translate_int_equality ctx Arith.( <> ) Value.neq x y + | Prim (Array_get, [ x; y ]) -> + Memory.array_get + (transl_prim_arg ctx x) + (transl_prim_arg ctx ~typ:(Int Normalized) y) | Prim (p, l) -> ( match p with | Extern name when String.Hashtbl.mem internal_primitives name -> - snd - (String.Hashtbl.find internal_primitives name) - ctx - context - transl_prim_arg - l + snd (String.Hashtbl.find internal_primitives name) ctx context l | _ -> ( - let l = List.map ~f:transl_prim_arg l in + let l = List.map ~f:(fun x -> transl_prim_arg ctx x) l in match p, l with | Extern name, l -> ( try @@ -736,29 +904,31 @@ module Generate (Target : Target_sig.S) = struct loop (x :: acc) r in loop [] l) - | Not, [ x ] -> Value.not x - | Lt, [ x; y ] -> Value.lt x y - | Le, [ x; y ] -> Value.le x y - | Eq, [ x; y ] -> Value.eq x y - | Neq, [ x; y ] -> Value.neq x y - | Ult, [ x; y ] -> Value.ult x y - | Array_get, [ x; y ] -> Memory.array_get x y | IsInt, [ x ] -> Value.is_int x - | Vectlength, [ x ] -> Value.val_int (Memory.gen_array_length x) + | Vectlength, [ x ] -> Memory.gen_array_length x | (Not | Lt | Le | Eq | Neq | Ult | Array_get | IsInt | Vectlength), _ -> assert false)) and translate_instr ctx context i = match i with - | Assign (x, y) -> assign x (load y) + | Assign (x, y) -> + assign x (convert ~from:(get_var_type ctx y) ~into:(get_var_type ctx x) (load y)) | Let (x, e) -> if ctx.live.(Var.idx x) = 0 then drop (translate_expr ctx context x e) - else store x (translate_expr ctx context x e) - | Set_field (x, n, Non_float, y) -> Memory.set_field (load x) n (load y) + else + store + ?typ: + (match get_var_type ctx x with + | Int (Normalized | Unnormalized) -> Some I32 + | _ -> None) + x + (translate_expr ctx context x e) + | Set_field (x, n, Non_float, y) -> + Memory.set_field (load_and_box ctx x) n (load_and_box ctx y) | Set_field (x, n, Float, y) -> Memory.float_array_set - (load x) + (load_and_box ctx x) (Constant.translate (Int (Targetint.of_int_warning_on_overflow n))) (load y) | Offset_ref (x, n) -> @@ -767,7 +937,11 @@ module Generate (Target : Target_sig.S) = struct 0 (Value.val_int Arith.(Value.int_val (Memory.field (load x) 0) + const (Int32.of_int n))) - | Array_set (x, y, z) -> Memory.array_set (load x) (load y) (load z) + | Array_set (x, y, z) -> + Memory.array_set + (load x) + (convert ~from:(get_var_type ctx y) ~into:(Int Normalized) (load y)) + (load_and_box ctx z) | Event loc -> event loc and translate_instrs ctx context l = @@ -777,7 +951,7 @@ module Generate (Target : Target_sig.S) = struct let* () = translate_instr ctx context i in translate_instrs ctx context rem - let parallel_renaming params args = + let parallel_renaming ~ctx params args = let rec visit visited prev s m x l = if not (Var.Set.mem x visited) then @@ -785,18 +959,21 @@ module Generate (Target : Target_sig.S) = struct let y = Var.Map.find x m in if Code.Var.compare x y = 0 then visited, None, l - else if Var.Set.mem y prev - then - let t = Code.Var.fresh () in - visited, Some (y, t), (x, t) :: l - else if Var.Set.mem y s - then - let visited, aliases, l = visit visited (Var.Set.add x prev) s m y l in - match aliases with - | Some (a, b) when Code.Var.compare a x = 0 -> - visited, None, (b, a) :: (x, y) :: l - | _ -> visited, aliases, (x, y) :: l - else visited, None, (x, y) :: l + else + let tx = get_var_type ctx x in + let ty = get_var_type ctx y in + if Var.Set.mem y prev + then + let t = Code.Var.fresh () in + visited, Some (y, ty, t, tx), (x, tx, t, tx) :: l + else if Var.Set.mem y s + then + let visited, aliases, l = visit visited (Var.Set.add x prev) s m y l in + match aliases with + | Some (a, ta, b, tb) when Code.Var.compare a x = 0 -> + visited, None, (b, tb, a, ta) :: (x, tx, y, ty) :: l + | _ -> visited, aliases, (x, tx, y, ty) :: l + else visited, None, (x, tx, y, ty) :: l else visited, None, l in let visit_all params args = @@ -815,9 +992,16 @@ module Generate (Target : Target_sig.S) = struct let l = visit_all params args in List.fold_left l - ~f:(fun continuation (y, x) -> + ~f:(fun continuation (y, ty, x, tx) -> let* () = continuation in - store ~always:true y (load x)) + store + ~always:true + ?typ: + (match ty with + | Typing.Int (Normalized | Unnormalized) -> Some I32 + | _ -> None) + y + (convert ~from:tx ~into:ty (load x))) ~init:(return ()) let exception_name = "ocaml_exception" @@ -976,7 +1160,7 @@ module Generate (Target : Target_sig.S) = struct match branch with | Branch cont -> translate_branch result_typ fall_through pc cont context | Return x -> ( - let* e = load x in + let* e = load_and_box ctx x in match fall_through with | `Return -> instr (Push e) | `Block _ | `Catch | `Skip -> instr (Return (Some e))) @@ -984,7 +1168,10 @@ module Generate (Target : Target_sig.S) = struct let context' = extend_context fall_through context in if_ { params = []; result = result_typ } - (Value.check_is_not_zero (load x)) + (match get_var_type ctx x with + | Int Normalized -> load x + | Int Unnormalized -> Arith.(load x lsl const 1l) + | _ -> Value.check_is_not_zero (load x)) (translate_branch result_typ fall_through pc cont1 context') (translate_branch result_typ fall_through pc cont2 context') | Stop -> ( @@ -999,7 +1186,9 @@ module Generate (Target : Target_sig.S) = struct assert (List.is_empty args); label_index context pc in - let* e = Value.int_val (load x) in + let* e = + convert ~from:(get_var_type ctx x) ~into:(Int Normalized) (load x) + in instr (Br_table (e, List.map ~f:dest l, dest a.(len - 1))) | Raise (x, _) -> ( let* e = load x in @@ -1030,7 +1219,7 @@ module Generate (Target : Target_sig.S) = struct then return () else let block = Addr.Map.find dst ctx.blocks in - parallel_renaming block.params args + parallel_renaming ~ctx block.params args in match fall_through with | `Block dst' when dst = dst' -> return () @@ -1184,7 +1373,8 @@ module Generate (Target : Target_sig.S) = struct ~should_export ~warn_on_unhandled_effect *) - ~deadcode_sentinal = + ~deadcode_sentinal + ~types = global_context.unit_name <- unit_name; let p, closures = Closure_conversion.f p in (* @@ -1194,6 +1384,7 @@ module Generate (Target : Target_sig.S) = struct { live = live_vars ; in_cps ; deadcode_sentinal + ; types ; blocks = p.blocks ; closures ; global_context @@ -1300,10 +1491,12 @@ let init = G.init let start () = make_context ~value_type:Gc_target.Type.value -let f ~context ~unit_name p ~live_vars ~in_cps ~deadcode_sentinal = +let f ~context ~unit_name p ~live_vars ~in_cps ~deadcode_sentinal ~global_flow_data = + let state, info = global_flow_data in + let types = Typing.f ~state ~info ~deadcode_sentinal p in let t = Timer.make () in let p = fix_switch_branches p in - let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal p in + let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal ~types p in if times () then Format.eprintf " code gen.: %a@." Timer.print t; res diff --git a/compiler/lib-wasm/generate.mli b/compiler/lib-wasm/generate.mli index 7eb980822b..dc31cd455c 100644 --- a/compiler/lib-wasm/generate.mli +++ b/compiler/lib-wasm/generate.mli @@ -27,6 +27,7 @@ val f : -> live_vars:int array -> in_cps:Effects.in_cps -> deadcode_sentinal:Code.Var.t + -> global_flow_data:Global_flow.state * Global_flow.info -> Wasm_ast.var * (string * Javascript.expression) list val add_start_function : context:Code_generation.context -> Wasm_ast.var -> unit diff --git a/compiler/lib-wasm/target_sig.ml b/compiler/lib-wasm/target_sig.ml index 1182b30424..428327a3df 100644 --- a/compiler/lib-wasm/target_sig.ml +++ b/compiler/lib-wasm/target_sig.ml @@ -23,6 +23,7 @@ module type S = sig val allocate : tag:int -> deadcode_sentinal:Code.Var.t + -> load:(Code.Var.t -> expression) -> [ `Expr of Wasm_ast.expression | `Var of Wasm_ast.var ] list -> expression diff --git a/compiler/lib-wasm/typing.ml b/compiler/lib-wasm/typing.ml new file mode 100644 index 0000000000..1e7253fb6c --- /dev/null +++ b/compiler/lib-wasm/typing.ml @@ -0,0 +1,440 @@ +open! Stdlib +open Code +open Global_flow + +let debug = Debug.find "typing" + +module Integer = struct + type kind = + | Ref + | Normalized + | Unnormalized + + let join r r' = + match r, r' with + | Unnormalized, _ | _, Unnormalized -> Unnormalized + | Ref, Ref -> Ref + | _ -> Normalized +end + +type boxed_number = + | Int32 + | Int64 + | Nativeint + | Float + +type typ = + | Top + | Int of Integer.kind + | Number of boxed_number + | Tuple of typ array + | Bot + +module Domain = struct + type t = typ + + let rec join t t' = + match t, t' with + | Bot, t | t, Bot -> t + | Int r, Int r' -> Int (Integer.join r r') + | Number n, Number n' -> if Poly.equal n n' then t else Top + | Tuple t, Tuple t' -> + let l = Array.length t in + let l' = Array.length t' in + Tuple + (if l = l' + then Array.map2 ~f:join t t' + else + Array.init (max l l') ~f:(fun i -> + if i < l then if i < l' then join t.(i) t'.(i) else t.(i) else t'.(i))) + | Top, _ | _, Top -> Top + | (Int _ | Number _ | Tuple _), _ -> Top + + let join_set ?(others = false) f s = + if others then Top else Var.Set.fold (fun x a -> join (f x) a) s Bot + + let rec equal t t' = + match t, t' with + | Top, Top | Bot, Bot -> true + | Int t, Int t' -> Poly.equal t t' + | Number t, Number t' -> Poly.equal t t' + | Tuple t, Tuple t' -> + Array.length t = Array.length t' && Array.for_all2 ~f:equal t t' + | (Top | Tuple _ | Int _ | Number _ | Bot), _ -> false + + let bot = Bot + + let depth_treshold = 4 + + let rec depth t = + match t with + | Top | Bot | Number _ | Int _ -> 0 + | Tuple l -> 1 + Array.fold_left ~f:(fun acc t' -> max (depth t') acc) l ~init:0 + + let rec truncate depth t = + match t with + | Top | Bot | Number _ | Int _ -> t + | Tuple l -> + if depth = 0 + then Top + else Tuple (Array.map ~f:(fun t' -> truncate (depth - 1) t') l) + + let limit t = if depth t > depth_treshold then truncate depth_treshold t else t + + let box t = + match t with + | Int _ -> Int Ref + | _ -> t + + let rec print f t = + match t with + | Top -> Format.fprintf f "top" + | Bot -> Format.fprintf f "bot" + | Int k -> + Format.fprintf + f + "int{%s}" + (match k with + | Ref -> "ref" + | Normalized -> "normalized" + | Unnormalized -> "unnormalized") + | Number Int32 -> Format.fprintf f "int32" + | Number Int64 -> Format.fprintf f "int64" + | Number Nativeint -> Format.fprintf f "nativeint" + | Number Float -> Format.fprintf f "float" + | Tuple t -> + Format.fprintf + f + "(%a)" + (Format.pp_print_list ~pp_sep:(fun f () -> Format.fprintf f ",") print) + (Array.to_list t) +end + +let update_deps st { blocks; _ } = + let add_dep st x y = Var.Tbl.set st.deps y (x :: Var.Tbl.get st.deps y) in + Addr.Map.iter + (fun _ block -> + List.iter block.body ~f:(fun i -> + match i with + | Let (x, Block (_, lst, _, _)) -> Array.iter ~f:(fun y -> add_dep st x y) lst + | Let (x, Prim (Extern ("%int_and" | "%int_or" | "%int_xor"), lst)) -> + (* The return type of these primitives depend on the input type *) + List.iter + ~f:(fun p -> + match p with + | Pc _ -> () + | Pv y -> add_dep st x y) + lst + | _ -> ())) + blocks + +let mark_function_parameters { blocks; _ } = + let function_parameters = Var.Tbl.make () false in + let set x = Var.Tbl.set function_parameters x true in + Addr.Map.iter + (fun _ block -> + List.iter block.body ~f:(fun i -> + match i with + | Let (_, Closure (params, _, _)) -> List.iter ~f:set params + | _ -> ())) + blocks; + function_parameters + +type st = + { state : state + ; info : info + ; function_parameters : bool Var.Tbl.t + } + +let rec constant_type (c : constant) = + match c with + | Int _ -> Int Normalized + | Int32 _ -> Number Int32 + | Int64 _ -> Number Int64 + | NativeInt _ -> Number Nativeint + | Float _ -> Number Float + | Tuple (_, a, _) -> Tuple (Array.map ~f:(fun c' -> Domain.box (constant_type c')) a) + | _ -> Top + +let arg_type ~approx arg = + match arg with + | Pc c -> constant_type c + | Pv x -> Var.Tbl.get approx x + +let prim_type ~approx prim args = + match prim with + | "%int_add" | "%int_sub" | "%int_mul" | "%direct_int_mul" | "%int_lsl" | "%int_neg" -> + Int Unnormalized + | "%int_and" -> ( + match List.map ~f:(fun x -> arg_type ~approx x) args with + | [ (Bot | Int (Ref | Normalized)); _ ] | [ _; (Bot | Int (Ref | Normalized)) ] -> + Int Normalized + | _ -> Int Unnormalized) + | "%int_or" | "%int_xor" -> ( + match List.map ~f:(fun x -> arg_type ~approx x) args with + | [ (Bot | Int (Ref | Normalized)); (Bot | Int (Ref | Normalized)) ] -> + Int Normalized + | _ -> Int Unnormalized) + | "%int_lsr" + | "%int_asr" + | "%int_div" + | "%int_mod" + | "%direct_int_div" + | "%direct_int_mod" -> Int Normalized + | "caml_greaterthan" + | "caml_greaterequal" + | "caml_lessthan" + | "caml_lessequal" + | "caml_equal" + | "caml_compare" -> Int Ref + | "caml_int32_bswap" -> Number Int32 + | "caml_nativeint_bswap" -> Number Nativeint + | "caml_int64_bswap" -> Number Int64 + | "caml_int32_compare" | "caml_nativeint_compare" | "caml_int64_compare" -> Int Ref + | "caml_string_get32" -> Number Int32 + | "caml_string_get64" -> Number Int64 + | "caml_bytes_get32" -> Number Int32 + | "caml_bytes_get64" -> Number Int64 + | "caml_lxm_next" -> Number Int64 + | "caml_ba_uint8_get32" -> Number Int32 + | "caml_ba_uint8_get64" -> Number Int64 + | "caml_nextafter_float" -> Number Float + | "caml_classify_float" -> Int Ref + | "caml_ldexp_float" | "caml_erf_float" | "caml_erfc_float" -> Number Float + | "caml_float_compare" -> Int Ref + | "caml_floatarray_unsafe_get" -> Number Float + | "caml_bytes_unsafe_get" + | "caml_string_unsafe_get" + | "caml_bytes_get" + | "caml_string_get" + | "caml_ml_string_length" + | "caml_ml_bytes_length" -> Int Normalized + | "%direct_obj_tag" -> Int Ref + | "caml_add_float" + | "caml_sub_float" + | "caml_mul_float" + | "caml_div_float" + | "caml_copysign_float" -> Number Float + | "caml_signbit_float" -> Int Normalized + | "caml_neg_float" + | "caml_abs_float" + | "caml_ceil_float" + | "caml_floor_float" + | "caml_trunc_float" + | "caml_round_float" + | "caml_sqrt_float" -> Number Float + | "caml_eq_float" + | "caml_neq_float" + | "caml_ge_float" + | "caml_le_float" + | "caml_gt_float" + | "caml_lt_float" + | "caml_int_of_float" -> Int Unnormalized + | "caml_float_of_int" + | "caml_cos_float" + | "caml_sin_float" + | "caml_tan_float" + | "caml_acos_float" + | "caml_asin_float" + | "caml_atan_float" + | "caml_atan2_float" + | "caml_cosh_float" + | "caml_sinh_float" + | "caml_tanh_float" + | "caml_acosh_float" + | "caml_asinh_float" + | "caml_atanh_float" + | "caml_cbrt_float" + | "caml_exp_float" + | "caml_exp2_float" + | "caml_log_float" + | "caml_expm1_float" + | "caml_log1p_float" + | "caml_log2_float" + | "caml_log10_float" + | "caml_power_float" + | "caml_hypot_float" + | "caml_fmod_float" -> Number Float + | "caml_int32_bits_of_float" -> Number Int32 + | "caml_int32_float_of_bits" -> Number Float + | "caml_int32_of_float" -> Number Int32 + | "caml_int32_to_float" -> Number Float + | "caml_int32_neg" + | "caml_int32_add" + | "caml_int32_sub" + | "caml_int32_mul" + | "caml_int32_and" + | "caml_int32_or" + | "caml_int32_xor" + | "caml_int32_div" + | "caml_int32_mod" + | "caml_int32_shift_left" + | "caml_int32_shift_right" + | "caml_int32_shift_right_unsigned" -> Number Int32 + | "caml_int32_to_int" -> Int Unnormalized + | "caml_int32_of_int" -> Number Int32 + | "caml_nativeint_of_int32" -> Number Nativeint + | "caml_nativeint_to_int32" -> Number Int32 + | "caml_int64_bits_of_float" -> Number Int64 + | "caml_int64_float_of_bits" -> Number Float + | "caml_int64_of_float" -> Number Int64 + | "caml_int64_to_float" -> Number Float + | "caml_int64_neg" + | "caml_int64_add" + | "caml_int64_sub" + | "caml_int64_mul" + | "caml_int64_and" + | "caml_int64_or" + | "caml_int64_xor" + | "caml_int64_div" + | "caml_int64_mod" + | "caml_int64_shift_left" + | "caml_int64_shift_right" + | "caml_int64_shift_right_unsigned" -> Number Int64 + | "caml_int64_to_int" -> Int Unnormalized + | "caml_int64_of_int" -> Number Int64 + | "caml_int64_to_int32" -> Number Int32 + | "caml_int64_of_int32" -> Number Int64 + | "caml_int64_to_nativeint" -> Number Nativeint + | "caml_int64_of_nativeint" -> Number Int64 + | "caml_nativeint_bits_of_float" -> Number Nativeint + | "caml_nativeint_float_of_bits" -> Number Float + | "caml_nativeint_of_float" -> Number Nativeint + | "caml_nativeint_to_float" -> Number Float + | "caml_nativeint_neg" + | "caml_nativeint_add" + | "caml_nativeint_sub" + | "caml_nativeint_mul" + | "caml_nativeint_and" + | "caml_nativeint_or" + | "caml_nativeint_xor" + | "caml_nativeint_div" + | "caml_nativeint_mod" + | "caml_nativeint_shift_left" + | "caml_nativeint_shift_right" + | "caml_nativeint_shift_right_unsigned" -> Number Nativeint + | "caml_nativeint_to_int" -> Int Unnormalized + | "caml_nativeint_of_int" -> Number Nativeint + | "caml_int_compare" -> Int Normalized + | _ -> Top + +let propagate st approx x : Domain.t = + match st.state.defs.(Var.idx x) with + | Phi { known; others; unit } -> + let res = Domain.join_set ~others (fun y -> Var.Tbl.get approx y) known in + let res = if unit then Domain.join (Int Unnormalized) res else res in + if Var.Tbl.get st.function_parameters x then Domain.box res else res + | Expr e -> ( + match e with + | Constant c -> constant_type c + | Closure _ -> Top + | Block (_, lst, _, _) -> + Tuple + (Array.mapi + ~f:(fun i y -> + match st.state.mutable_fields.(Var.idx x) with + | All_fields -> Top + | Some_fields s when IntSet.mem i s -> Top + | Some_fields _ | No_field -> + Domain.limit (Domain.box (Var.Tbl.get approx y))) + lst) + | Field (_, _, Float) -> Number Float + | Field (y, n, Non_float) -> ( + match Var.Tbl.get approx y with + | Tuple t -> if n < Array.length t then t.(n) else Bot + | Top -> Top + | _ -> Bot) + | Prim + ( Extern ("caml_check_bound" | "caml_check_bound_float" | "caml_check_bound_gen") + , [ Pv y; _ ] ) -> Var.Tbl.get approx y + | Prim ((Array_get | Extern "caml_array_unsafe_get"), [ Pv y; _ ]) -> ( + match Var.Tbl.get st.info.info_approximation y with + | Values { known; others } -> + Domain.join_set + ~others + (fun z -> + match st.state.defs.(Var.idx z) with + | Expr (Block (_, lst, _, _)) -> + let m = + match st.state.mutable_fields.(Var.idx z) with + | No_field -> false + | Some_fields _ | All_fields -> true + in + if m + then Top + else + Domain.box + (Array.fold_left + ~f:(fun acc t -> Domain.join (Var.Tbl.get approx t) acc) + ~init:Domain.bot + lst) + | Expr (Closure _) -> Bot + | Phi _ | Expr _ -> assert false) + known + | Top -> Top) + | Prim (Array_get, _) -> Top + | Prim ((Vectlength | Not | IsInt | Eq | Neq | Lt | Le | Ult), _) -> Int Normalized + | Prim (Extern prim, args) -> prim_type ~approx prim args + | Special _ -> Top + | Apply { f; args; _ } -> ( + match Var.Tbl.get st.info.info_approximation f with + | Values { known; others } -> + Domain.join_set + ~others + (fun g -> + match st.state.defs.(Var.idx g) with + | Expr (Closure (params, _, _)) + when List.length args = List.length params -> + Domain.box + (Domain.join_set + (fun y -> Var.Tbl.get approx y) + (Var.Map.find g st.state.return_values)) + | Expr (Closure (_, _, _)) -> + (* The function is partially applied or over applied *) + Top + | Expr (Block _) -> Bot + | Phi _ | Expr _ -> assert false) + known + | Top -> Top)) + +module G = Dgraph.Make_Imperative (Var) (Var.ISet) (Var.Tbl) +module Solver = G.Solver (Domain) + +let solver st = + let associated_list h x = try Var.Hashtbl.find h x with Not_found -> [] in + let g = + { G.domain = st.state.vars + ; G.iter_children = + (fun f x -> + List.iter ~f (Var.Tbl.get st.state.deps x); + List.iter + ~f:(fun g -> List.iter ~f (associated_list st.state.function_call_sites g)) + (associated_list st.state.functions_from_returned_value x)) + } + in + Solver.f () g (propagate st) + +let f ~state ~info ~deadcode_sentinal p = + update_deps state p; + let function_parameters = mark_function_parameters p in + let typ = solver { state; info; function_parameters } in + Var.Tbl.set typ deadcode_sentinal (Int Normalized); + if debug () + then ( + Var.ISet.iter + (fun x -> + match state.defs.(Var.idx x) with + | Expr _ -> () + | Phi _ -> + let t = Var.Tbl.get typ x in + if not (Domain.equal t Top) + then Format.eprintf "%a: %a@." Var.print x Domain.print t) + state.vars; + Print.program + Format.err_formatter + (fun _ i -> + match i with + | Instr (Let (x, _)) -> Format.asprintf "{%a}" Domain.print (Var.Tbl.get typ x) + | _ -> "") + p); + typ diff --git a/compiler/lib-wasm/typing.mli b/compiler/lib-wasm/typing.mli new file mode 100644 index 0000000000..1860b4ac7c --- /dev/null +++ b/compiler/lib-wasm/typing.mli @@ -0,0 +1,28 @@ +module Integer : sig + type kind = + | Ref + | Normalized + | Unnormalized +end + +type boxed_number = + | Int32 + | Int64 + | Nativeint + | Float + +type typ = + | Top + | Int of Integer.kind + | Number of boxed_number + | Tuple of typ array + | Bot + +val constant_type : Code.constant -> typ + +val f : + state:Global_flow.state + -> info:Global_flow.info + -> deadcode_sentinal:Code.Var.t + -> Code.program + -> typ Code.Var.Tbl.t diff --git a/compiler/lib/driver.ml b/compiler/lib/driver.ml index dfb65e5e08..4581c38d5d 100644 --- a/compiler/lib/driver.ml +++ b/compiler/lib/driver.ml @@ -93,16 +93,18 @@ let phi p = let ( +> ) f g x = g (f x) -let map_fst f (x, y, z) = f x, y, z +let map_fst f (x, y) = f x, y -let effects_and_exact_calls ~deadcode_sentinal (profile : Profile.t) p = +let effects_and_exact_calls ~keep_flow_data ~deadcode_sentinal (profile : Profile.t) p = let fast = match Config.effects (), profile with | (`Cps | `Double_translation), _ -> false | _, (O2 | O3) -> false | _, O1 -> true in - let info = Global_flow.f ~fast p in + let global_flow_data = Global_flow.f ~fast p in + let _, info = global_flow_data in + let global_flow_data = if keep_flow_data then Some global_flow_data else None in let pure_fun = Pure_fun.f p in let p, live_vars = if Config.Flag.globaldeadcode () && Config.Flag.deadcode () @@ -114,7 +116,8 @@ let effects_and_exact_calls ~deadcode_sentinal (profile : Profile.t) p = match Config.effects () with | `Cps | `Double_translation -> if debug () then Format.eprintf "Effects...@."; - Effects.f ~flow_info:info ~live_vars p + let p, trampolined_calls, in_cps = Effects.f ~flow_info:info ~live_vars p in + (p, (trampolined_calls, in_cps, None)) |> map_fst (match Config.target () with | `Wasm -> Fun.id @@ -124,8 +127,9 @@ let effects_and_exact_calls ~deadcode_sentinal (profile : Profile.t) p = Specialize.f ~function_arity:(fun f -> Global_flow.function_arity info f) p in ( p - , (Code.Var.Set.empty : Effects.trampolined_calls) - , (Code.Var.Set.empty : Effects.in_cps) ) + , ( (Code.Var.Set.empty : Effects.trampolined_calls) + , (Code.Var.Set.empty : Effects.in_cps) + , global_flow_data ) ) let print p = if debug () then Code.Print.program Format.err_formatter (fun _ _ -> "") p; @@ -613,7 +617,7 @@ let link_and_pack ?(standalone = true) ?(wrap_with_fun = `Iife) ?(link = `No) p |> pack ~wrap_with_fun ~standalone |> check_js -let optimize ~profile p = +let optimize ~profile ~keep_flow_data p = let deadcode_sentinal = (* If deadcode is disabled, this field is just fresh variable *) Code.Var.fresh_n "dummy" @@ -626,7 +630,7 @@ let optimize ~profile p = | O2 -> o2 | O3 -> o3) +> specialize_js_once_after - +> effects_and_exact_calls ~deadcode_sentinal profile + +> effects_and_exact_calls ~keep_flow_data ~deadcode_sentinal profile +> map_fst (match Config.target (), Config.effects () with | `JavaScript, `Disabled -> Generate_closure.f @@ -637,12 +641,20 @@ let optimize ~profile p = in if times () then Format.eprintf "Start Optimizing...@."; let t = Timer.make () in - let (program, variable_uses), trampolined_calls, in_cps = opt p in + let (program, variable_uses), (trampolined_calls, in_cps, global_flow_info) = opt p in let () = if times () then Format.eprintf " optimizations : %a@." Timer.print t in - { program; variable_uses; trampolined_calls; in_cps; deadcode_sentinal } + ( { program; variable_uses; trampolined_calls; in_cps; deadcode_sentinal } + , global_flow_info ) + +let optimize_for_wasm ~profile p = + let optimized_code, global_flow_data = optimize ~profile ~keep_flow_data:true p in + ( optimized_code + , match global_flow_data with + | Some data -> data + | None -> Global_flow.f ~fast:false optimized_code.program ) let full ~standalone ~wrap_with_fun ~profile ~link ~source_map ~formatter p = - let optimized_code = optimize ~profile p in + let optimized_code, _ = optimize ~profile ~keep_flow_data:false p in let exported_runtime = not standalone in let emit formatter = generate ~exported_runtime ~wrap_with_fun ~warn_on_unhandled_effect:standalone diff --git a/compiler/lib/driver.mli b/compiler/lib/driver.mli index 7c274322be..418e6d71e4 100644 --- a/compiler/lib/driver.mli +++ b/compiler/lib/driver.mli @@ -26,7 +26,10 @@ type optimized_result = ; deadcode_sentinal : Code.Var.t } -val optimize : profile:Profile.t -> Code.program -> optimized_result +val optimize_for_wasm : + profile:Profile.t + -> Code.program + -> optimized_result * (Global_flow.state * Global_flow.info) val f : ?standalone:bool diff --git a/compiler/lib/global_flow.ml b/compiler/lib/global_flow.ml index 1e167a47ef..f5bfccb985 100644 --- a/compiler/lib/global_flow.ml +++ b/compiler/lib/global_flow.ml @@ -79,30 +79,36 @@ type def = | Phi of { known : Var.Set.t (* Known arguments *) ; others : bool (* Can there be other arguments *) + ; unit : bool (* Whether we are propagating unit (used for typing) *) } -let undefined = Phi { known = Var.Set.empty; others = false } +let undefined = Phi { known = Var.Set.empty; others = false; unit = false } let is_undefined d = match d with | Expr _ -> false - | Phi { known; others } -> Var.Set.is_empty known && not others + | Phi { known; others; unit } -> Var.Set.is_empty known && (not others) && not unit type escape_status = | Escape | Escape_constant (* Escapes but we know the value is not modified *) | No +type mutable_fields = + | No_field + | Some_fields of IntSet.t + | All_fields + type state = { vars : Var.ISet.t (* Set of all veriables considered *) ; deps : Var.t list Var.Tbl.t (* Dependency between variables *) ; defs : def array (* Definition of each variable *) ; variable_may_escape : escape_status array (* Any value bound to this variable may escape *) - ; variable_possibly_mutable : Var.ISet.t + ; variable_mutable_fields : mutable_fields array (* Any value bound to this variable may be mutable *) ; may_escape : escape_status array (* This value may escape *) - ; possibly_mutable : Var.ISet.t (* This value may be mutable *) + ; mutable_fields : mutable_fields array (* This value may be mutable *) ; return_values : Var.Set.t Var.Map.t (* Set of variables holding return values of each function *) ; functions_from_returned_value : Var.t list Var.Hashtbl.t @@ -136,13 +142,22 @@ let add_assign_def st x y = let idx = Var.idx x in match st.defs.(idx) with | Expr _ -> assert false - | Phi { known; others } -> st.defs.(idx) <- Phi { known = Var.Set.add y known; others } + | Phi { known; others; unit } -> + st.defs.(idx) <- Phi { known = Var.Set.add y known; others; unit } + +let add_unit_def st x = + add_var st x; + let idx = Var.idx x in + match st.defs.(idx) with + | Expr _ -> assert false + | Phi { known; others; _ } -> st.defs.(idx) <- Phi { known; others; unit = true } let add_param_def st x = add_var st x; let idx = Var.idx x in assert (is_undefined st.defs.(idx)); - if st.fast then st.defs.(idx) <- Phi { known = Var.Set.empty; others = true } + if st.fast + then st.defs.(idx) <- Phi { known = Var.Set.empty; others = true; unit = false } let rec arg_deps st ?ignore params args = match params, args with @@ -150,7 +165,7 @@ let rec arg_deps st ?ignore params args = (* This is to deal with the [else] clause of a conditional, where we know that the value of the tested variable is 0. *) (match ignore with - | Some y' when Var.equal y y' -> () + | Some y' when Var.equal y y' -> add_unit_def st x | _ -> add_assign_def st x y); arg_deps st params args | [], [] -> () @@ -162,7 +177,14 @@ let cont_deps blocks st ?ignore (pc, args) = let do_escape st level x = st.variable_may_escape.(Var.idx x) <- level -let possibly_mutable st x = Var.ISet.add st.variable_possibly_mutable x +let possibly_mutable st x = st.variable_mutable_fields.(Var.idx x) <- All_fields + +let field_possibly_mutable st x n = + match st.variable_mutable_fields.(Var.idx x) with + | No_field -> st.variable_mutable_fields.(Var.idx x) <- Some_fields (IntSet.singleton n) + | Some_fields s -> + st.variable_mutable_fields.(Var.idx x) <- Some_fields (IntSet.add n s) + | All_fields -> () let expr_deps blocks st x e = match e with @@ -267,7 +289,10 @@ let program_deps st { start; blocks; _ } = add_expr_def st x e; expr_deps blocks st x e | Assign (x, y) -> add_assign_def st x y - | Set_field (x, _, _, y) | Array_set (x, _, y) -> + | Set_field (x, n, _, y) -> + field_possibly_mutable st x n; + do_escape st Escape y + | Array_set (x, _, y) -> possibly_mutable st x; do_escape st Escape y | Event _ | Offset_ref _ -> ()); @@ -308,7 +333,8 @@ let program_deps st { start; blocks; _ } = | Expr _ | Phi _ -> ()) | Pushtrap (cont, x, cont_h) -> add_var st x; - st.defs.(Var.idx x) <- Phi { known = Var.Set.empty; others = true }; + st.defs.(Var.idx x) <- + Phi { known = Var.Set.empty; others = true; unit = false }; cont_deps blocks st cont_h; cont_deps blocks st cont) blocks @@ -360,14 +386,15 @@ module Domain = struct Array.iter ~f:(fun y -> variable_escape ~update ~st ~approx s y) a; match s, mut with | Escape, Maybe_mutable -> - Var.ISet.add st.possibly_mutable x; + st.mutable_fields.(Var.idx x) <- All_fields; update ~children:true x | (Escape_constant | No), _ | Escape, Immutable -> ()) | Expr (Closure (params, _, _)) -> List.iter ~f:(fun y -> (match st.defs.(Var.idx y) with - | Phi { known; _ } -> st.defs.(Var.idx y) <- Phi { known; others = true } + | Phi { known; _ } -> + st.defs.(Var.idx y) <- Phi { known; others = true; unit = false } | Expr _ -> assert false); update ~children:false y) params; @@ -405,18 +432,28 @@ module Domain = struct s (if o then others else bot) - let mark_mutable ~update ~st a = + let mark_mutable ~update ~st a mutable_fields = match a with | Top -> () | Values { known; _ } -> Var.Set.iter (fun x -> match st.defs.(Var.idx x) with - | Expr (Block (_, _, _, Maybe_mutable)) -> - if not (Var.ISet.mem st.possibly_mutable x) - then ( - Var.ISet.add st.possibly_mutable x; - update ~children:true x) + | Expr (Block (_, _, _, Maybe_mutable)) -> ( + match st.mutable_fields.(Var.idx x), mutable_fields with + | _, No_field -> () + | No_field, _ -> + st.mutable_fields.(Var.idx x) <- mutable_fields; + update ~children:true x + | Some_fields s, Some_fields s' -> + if IntSet.exists (fun i -> not (IntSet.mem i s)) s' + then ( + st.mutable_fields.(Var.idx x) <- Some_fields (IntSet.union s s'); + update ~children:true x) + | Some_fields _, All_fields -> + st.mutable_fields.(Var.idx x) <- All_fields; + update ~children:true x + | All_fields, _ -> ()) | Expr (Block (_, _, _, Immutable)) | Expr (Closure _) -> () | Phi _ | Expr _ -> assert false) known @@ -424,7 +461,7 @@ end let propagate st ~update approx x = match st.defs.(Var.idx x) with - | Phi { known; others } -> + | Phi { known; others; _ } -> Domain.join_set ~update ~st ~approx ~others (fun y -> Var.Tbl.get approx y) known | Expr e -> ( match e with @@ -452,7 +489,12 @@ let propagate st ~update approx x = | Some tags -> List.mem ~eq:Int.equal t tags | None -> true -> let t = a.(n) in - let m = Var.ISet.mem st.possibly_mutable z in + let m = + match st.mutable_fields.(Var.idx z) with + | No_field -> false + | Some_fields s -> IntSet.mem n s + | All_fields -> true + in if not m then add_dep st x z; add_dep st x t; let a = Var.Tbl.get approx t in @@ -480,7 +522,11 @@ let propagate st ~update approx x = (fun z -> match st.defs.(Var.idx z) with | Expr (Block (_, lst, _, _)) -> - let m = Var.ISet.mem st.possibly_mutable z in + let m = + match st.mutable_fields.(Var.idx z) with + | No_field -> false + | Some_fields _ | All_fields -> true + in if not m then add_dep st x z; Array.iter ~f:(fun t -> add_dep st x t) lst; let a = @@ -574,8 +620,9 @@ let propagate st ~update approx x = (match st.variable_may_escape.(Var.idx x) with | (Escape | Escape_constant) as s -> Domain.approx_escape ~update ~st ~approx s res | No -> ()); - if Var.ISet.mem st.variable_possibly_mutable x - then Domain.mark_mutable ~update ~st res; + (match st.variable_mutable_fields.(Var.idx x) with + | No_field -> () + | (Some_fields _ | All_fields) as s -> Domain.mark_mutable ~update ~st res s); res | Top -> Top @@ -653,9 +700,9 @@ let f ~fast p = let deps = Var.Tbl.make () [] in let defs = Array.make nv undefined in let variable_may_escape = Array.make nv No in - let variable_possibly_mutable = Var.ISet.empty () in + let variable_mutable_fields = Array.make nv No_field in let may_escape = Array.make nv No in - let possibly_mutable = Var.ISet.empty () in + let mutable_fields = Array.make nv No_field in let functions_from_returned_value = Var.Hashtbl.create 128 in Var.Map.iter (fun f s -> Var.Set.iter (fun x -> add_to_list functions_from_returned_value x f) s) @@ -667,9 +714,9 @@ let f ~fast p = ; return_values = rets ; functions_from_returned_value ; variable_may_escape - ; variable_possibly_mutable + ; variable_mutable_fields ; may_escape - ; possibly_mutable + ; mutable_fields ; known_cases = Var.Hashtbl.create 16 ; applied_functions = VarPairTbl.create 16 ; fast @@ -698,13 +745,28 @@ let f ~fast p = match a with | Top -> Format.fprintf f "top" | Values _ -> + let print_mutable_fields f s = + match s with + | No_field -> Format.fprintf f "no" + | Some_fields s -> + Format.fprintf + f + "{%a}" + (Format.pp_print_list + ~pp_sep:(fun f () -> Format.fprintf f ", ") + (fun f i -> Format.fprintf f "%d" i)) + (IntSet.elements s) + | All_fields -> Format.fprintf f "yes" + in Format.fprintf f - "%a mut:%b vmut:%b vesc:%s esc:%s" + "%a mut:%a vmut:%a vesc:%s esc:%s" (print_approx st) a - (Var.ISet.mem st.possibly_mutable x) - (Var.ISet.mem st.variable_possibly_mutable x) + print_mutable_fields + st.mutable_fields.(Var.idx x) + print_mutable_fields + st.variable_mutable_fields.(Var.idx x) (match st.variable_may_escape.(Var.idx x) with | Escape -> "Y" | Escape_constant -> "y" @@ -723,12 +785,13 @@ let f ~fast p = | Escape_constant | Escape -> Var.ISet.add info_may_escape (Var.of_idx i) | No -> ()) may_escape; - { info_defs = defs - ; info_approximation = approximation - ; info_variable_may_escape - ; info_may_escape - ; info_return_vals = rets - } + ( st + , { info_defs = defs + ; info_approximation = approximation + ; info_variable_may_escape + ; info_may_escape + ; info_return_vals = rets + } ) let exact_call info f n = match Var.Tbl.get info.info_approximation f with diff --git a/compiler/lib/global_flow.mli b/compiler/lib/global_flow.mli index 61f5dbfb6a..c8b8592167 100644 --- a/compiler/lib/global_flow.mli +++ b/compiler/lib/global_flow.mli @@ -22,6 +22,7 @@ type def = | Phi of { known : Var.Set.t (* Known arguments *) ; others : bool (* Can there be other arguments *) + ; unit : bool (* Whether we are propagating unit (used for typing) *) } type approx = @@ -44,7 +45,40 @@ type info = ; info_return_vals : Var.Set.t Var.Map.t } -val f : fast:bool -> Code.program -> info +type mutable_fields = + | No_field + | Some_fields of Stdlib.IntSet.t + | All_fields + +module VarPairTbl : Hashtbl.S with type key = Var.t * Var.t + +type state = + { vars : Var.ISet.t (* Set of all veriables considered *) + ; deps : Var.t list Var.Tbl.t (* Dependency between variables *) + ; defs : def array (* Definition of each variable *) + ; variable_may_escape : escape_status array + (* Any value bound to this variable may escape *) + ; variable_mutable_fields : mutable_fields array + (* Any value bound to this variable may be mutable *) + ; may_escape : escape_status array (* This value may escape *) + ; mutable_fields : mutable_fields array (* This value may be mutable *) + ; return_values : Var.Set.t Var.Map.t + (* Set of variables holding return values of each function *) + ; functions_from_returned_value : Var.t list Var.Hashtbl.t + (* Functions associated to each return value *) + ; known_cases : int list Var.Hashtbl.t + (* Possible tags for a block after a [switch]. This is used to + get a more precise approximation of the effect of a field + access [Field] *) + ; applied_functions : unit VarPairTbl.t + (* Functions that have been already considered at a call site. + This is to avoid repeated computations *) + ; function_call_sites : Var.t list Var.Hashtbl.t + (* Known call sites of each functions *) + ; fast : bool + } + +val f : fast:bool -> Code.program -> state * info val exact_call : info -> Var.t -> int -> bool