Skip to content

Commit 875e0df

Browse files
committed
Wasm: specialization of number comparisons
1 parent 6732e36 commit 875e0df

File tree

7 files changed

+226
-53
lines changed

7 files changed

+226
-53
lines changed

compiler/lib-wasm/generate.ml

Lines changed: 127 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ module Generate (Target : Target_sig.S) = struct
6767
type repr =
6868
| Value
6969
| Float
70+
| Int
7071
| Int32
7172
| Nativeint
7273
| Int64
@@ -75,24 +76,23 @@ module Generate (Target : Target_sig.S) = struct
7576
match r with
7677
| Value -> Type.value
7778
| Float -> F64
78-
| Int32 -> I32
79-
| Nativeint -> I32
79+
| Int | Int32 | Nativeint -> I32
8080
| Int64 -> I64
8181

8282
let specialized_primitive_type (_, params, result) =
8383
{ W.params = List.map ~f:repr_type params; result = [ repr_type result ] }
8484

8585
let box_value r e =
8686
match r with
87-
| Value -> e
87+
| Value | Int -> e
8888
| Float -> Memory.box_float e
8989
| Int32 -> Memory.box_int32 e
9090
| Nativeint -> Memory.box_nativeint e
9191
| Int64 -> Memory.box_int64 e
9292

9393
let unbox_value r e =
9494
match r with
95-
| Value -> e
95+
| Value | Int -> e
9696
| Float -> Memory.unbox_float e
9797
| Int32 -> Memory.unbox_int32 e
9898
| Nativeint -> Memory.unbox_nativeint e
@@ -105,9 +105,9 @@ module Generate (Target : Target_sig.S) = struct
105105
[ "caml_int32_bswap", (`Pure, [ Int32 ], Int32)
106106
; "caml_nativeint_bswap", (`Pure, [ Nativeint ], Nativeint)
107107
; "caml_int64_bswap", (`Pure, [ Int64 ], Int64)
108-
; "caml_int32_compare", (`Pure, [ Int32; Int32 ], Value)
109-
; "caml_nativeint_compare", (`Pure, [ Nativeint; Nativeint ], Value)
110-
; "caml_int64_compare", (`Pure, [ Int64; Int64 ], Value)
108+
; "caml_int32_compare", (`Pure, [ Int32; Int32 ], Int)
109+
; "caml_nativeint_compare", (`Pure, [ Nativeint; Nativeint ], Int)
110+
; "caml_int64_compare", (`Pure, [ Int64; Int64 ], Int)
111111
; "caml_string_get32", (`Mutator, [ Value; Value ], Int32)
112112
; "caml_string_get64", (`Mutator, [ Value; Value ], Int64)
113113
; "caml_bytes_get32", (`Mutator, [ Value; Value ], Int32)
@@ -124,7 +124,7 @@ module Generate (Target : Target_sig.S) = struct
124124
; "caml_ldexp_float", (`Pure, [ Float; Value ], Float)
125125
; "caml_erf_float", (`Pure, [ Float ], Float)
126126
; "caml_erfc_float", (`Pure, [ Float ], Float)
127-
; "caml_float_compare", (`Pure, [ Float; Float ], Value)
127+
; "caml_float_compare", (`Pure, [ Float; Float ], Int)
128128
];
129129
h
130130

@@ -283,6 +283,38 @@ module Generate (Target : Target_sig.S) = struct
283283
(transl_prim_arg ctx ?typ:tz z)
284284
| _ -> invalid_arity name l ~expected:3)
285285

286+
let register_comparison name cmp_int cmp_boxed_int cmp_float =
287+
register_prim name `Mutable (fun ctx _ l ->
288+
match l with
289+
| [ x; y ] -> (
290+
let x' = transl_prim_arg ctx x in
291+
let y' = transl_prim_arg ctx y in
292+
match get_type ctx x, get_type ctx y with
293+
| Int _, Int _ -> cmp_int ctx x y
294+
| Number Int32, Number Int32 ->
295+
let* x' = Memory.unbox_int32 x' in
296+
let* y' = Memory.unbox_int32 y' in
297+
return (W.BinOp (I32 cmp_boxed_int, x', y'))
298+
| Number Nativeint, Number Nativeint ->
299+
let* x' = Memory.unbox_nativeint x' in
300+
let* y' = Memory.unbox_nativeint y' in
301+
return (W.BinOp (I32 cmp_boxed_int, x', y'))
302+
| Number Int64, Number Int64 ->
303+
let* x' = Memory.unbox_int64 x' in
304+
let* y' = Memory.unbox_int64 y' in
305+
return (W.BinOp (I64 cmp_boxed_int, x', y'))
306+
| Number Float, Number Float -> float_comparison cmp_float x' y'
307+
| _ ->
308+
let* f =
309+
register_import
310+
~name
311+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
312+
in
313+
let* x' = x' in
314+
let* y' = y' in
315+
return (W.Call (f, [ x'; y' ])))
316+
| _ -> invalid_arity name l ~expected:2)
317+
286318
let () =
287319
register_bin_prim
288320
"caml_array_unsafe_get"
@@ -764,7 +796,93 @@ module Generate (Target : Target_sig.S) = struct
764796
l
765797
~init:(return [])
766798
in
767-
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal ~load l)
799+
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal ~load l);
800+
register_comparison
801+
"caml_greaterthan"
802+
(fun ctx x y -> translate_int_comparison ctx (fun y x -> Arith.(x < y)) x y)
803+
(Gt S)
804+
Gt;
805+
register_comparison
806+
"caml_greaterequal"
807+
(fun ctx x y -> translate_int_comparison ctx (fun y x -> Arith.(x <= y)) x y)
808+
(Ge S)
809+
Ge;
810+
register_comparison
811+
"caml_lessthan"
812+
(fun ctx x y -> translate_int_comparison ctx Arith.( < ) x y)
813+
(Lt S)
814+
Lt;
815+
register_comparison
816+
"caml_lessequal"
817+
(fun ctx x y -> translate_int_comparison ctx Arith.( <= ) x y)
818+
(Le S)
819+
Le;
820+
register_comparison
821+
"caml_equal"
822+
(fun ctx x y -> translate_int_equality ctx Arith.( = ) Value.eq x y)
823+
Eq
824+
Eq;
825+
register_comparison
826+
"caml_notequal"
827+
(fun ctx x y -> translate_int_equality ctx Arith.( <> ) Value.neq x y)
828+
Ne
829+
Ne;
830+
register_prim "caml_compare" `Mutable (fun ctx _ l ->
831+
match l with
832+
| [ x; y ] -> (
833+
let x' = transl_prim_arg ctx x in
834+
let y' = transl_prim_arg ctx y in
835+
match get_type ctx x, get_type ctx y with
836+
| Int _, Int _ ->
837+
Arith.(
838+
(Value.int_val y' < Value.int_val x')
839+
- (Value.int_val x' < Value.int_val y'))
840+
| Number Int32, Number Int32 ->
841+
let* f =
842+
register_import
843+
~name:"caml_int32_compare"
844+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
845+
in
846+
let* x' = Memory.unbox_int32 x' in
847+
let* y' = Memory.unbox_int32 y' in
848+
return (W.Call (f, [ x'; y' ]))
849+
| Number Nativeint, Number Nativeint ->
850+
let* f =
851+
register_import
852+
~name:"caml_nativeint_compare"
853+
(Fun (Type.primitive_type 2))
854+
in
855+
let* x' = Memory.unbox_nativeint x' in
856+
let* y' = Memory.unbox_nativeint y' in
857+
return (W.Call (f, [ x'; y' ]))
858+
| Number Int64, Number Int64 ->
859+
let* f =
860+
register_import
861+
~name:"caml_int64_compare"
862+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
863+
in
864+
let* x' = Memory.unbox_int64 x' in
865+
let* y' = Memory.unbox_int64 y' in
866+
return (W.Call (f, [ x'; y' ]))
867+
| Number Float, Number Float ->
868+
let* f =
869+
register_import
870+
~name:"caml_float_compare"
871+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
872+
in
873+
let* x' = Memory.unbox_int64 x' in
874+
let* y' = Memory.unbox_int64 y' in
875+
return (W.Call (f, [ x'; y' ]))
876+
| _ ->
877+
let* f =
878+
register_import
879+
~name:"caml_compare"
880+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
881+
in
882+
let* x' = x' in
883+
let* y' = y' in
884+
return (W.Call (f, [ x'; y' ])))
885+
| _ -> invalid_arity "caml_compare" l ~expected:2)
768886

769887
let rec translate_expr ctx context x e =
770888
match e with

compiler/lib-wasm/typing.ml

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ module Integer = struct
1515
| Unnormalized, _ | _, Unnormalized -> Unnormalized
1616
| Ref, Ref -> Ref
1717
| _ -> Normalized
18+
19+
let sub r r' =
20+
match r, r' with
21+
| _, Unnormalized -> true
22+
| Ref, _ -> true
23+
| Normalized, Normalized -> true
24+
| Unnormalized, (Ref | Normalized) -> false
25+
| Normalized, Ref -> false
1826
end
1927

2028
type boxed_number =
@@ -62,6 +70,21 @@ module Domain = struct
6270
Array.length t = Array.length t' && Array.for_all2 ~f:equal t t'
6371
| (Top | Tuple _ | Int _ | Number _ | Bot), _ -> false
6472

73+
let rec sub t t' =
74+
match t, t' with
75+
| _, Top | Bot, _ -> true
76+
| Top, _ | _, Bot -> false
77+
| Int t, Int t' -> Integer.sub t t'
78+
| Number t, Number t' -> Poly.equal t t'
79+
| Tuple t, Tuple t' ->
80+
Array.length t <= Array.length t'
81+
&&
82+
let rec compare t t' i l =
83+
i = l || (sub t.(i) t'.(i) && compare t t' (i + 1) l)
84+
in
85+
compare t t' 0 (Array.length t)
86+
| (Int _ | Number _ | Tuple _), _ -> false
87+
6588
let bot = Bot
6689

6790
let depth_treshold = 4
@@ -186,11 +209,13 @@ let prim_type ~approx prim args =
186209
| "caml_lessthan"
187210
| "caml_lessequal"
188211
| "caml_equal"
189-
| "caml_compare" -> Int Ref
212+
| "caml_notequal"
213+
| "caml_compare" -> Int Normalized
190214
| "caml_int32_bswap" -> Number Int32
191215
| "caml_nativeint_bswap" -> Number Nativeint
192216
| "caml_int64_bswap" -> Number Int64
193-
| "caml_int32_compare" | "caml_nativeint_compare" | "caml_int64_compare" -> Int Ref
217+
| "caml_int32_compare" | "caml_nativeint_compare" | "caml_int64_compare" ->
218+
Int Normalized
194219
| "caml_string_get32" -> Number Int32
195220
| "caml_string_get64" -> Number Int64
196221
| "caml_bytes_get32" -> Number Int32
@@ -201,7 +226,7 @@ let prim_type ~approx prim args =
201226
| "caml_nextafter_float" -> Number Float
202227
| "caml_classify_float" -> Int Ref
203228
| "caml_ldexp_float" | "caml_erf_float" | "caml_erfc_float" -> Number Float
204-
| "caml_float_compare" -> Int Ref
229+
| "caml_float_compare" -> Int Normalized
205230
| "caml_floatarray_unsafe_get" -> Number Float
206231
| "caml_bytes_unsafe_get"
207232
| "caml_string_unsafe_get"
@@ -414,6 +439,40 @@ let solver st =
414439
in
415440
Solver.f () g (propagate st)
416441

442+
let print_opt typ f e =
443+
match e with
444+
| Prim
445+
( Extern
446+
( "caml_greaterthan"
447+
| "caml_greaterequal"
448+
| "caml_lessthan"
449+
| "caml_lessequal"
450+
| "caml_equal"
451+
| "caml_compare" )
452+
, l ) ->
453+
if
454+
List.exists
455+
~f:(fun t' ->
456+
List.for_all
457+
~f:(fun p ->
458+
let t =
459+
match p with
460+
| Pc c -> constant_type c
461+
| Pv x -> Var.Tbl.get typ x
462+
in
463+
Domain.sub t t')
464+
l)
465+
[ Int Ref
466+
; Int Normalized
467+
; Int Unnormalized
468+
; Number Int32
469+
; Number Int64
470+
; Number Nativeint
471+
; Number Float
472+
]
473+
then Format.fprintf f " OPT"
474+
| _ -> ()
475+
417476
let f ~state ~info ~deadcode_sentinal p =
418477
update_deps state p;
419478
let function_parameters = mark_function_parameters p in
@@ -434,7 +493,8 @@ let f ~state ~info ~deadcode_sentinal p =
434493
Format.err_formatter
435494
(fun _ i ->
436495
match i with
437-
| Instr (Let (x, _)) -> Format.asprintf "{%a}" Domain.print (Var.Tbl.get typ x)
496+
| Instr (Let (x, e)) ->
497+
Format.asprintf "{%a}%a" Domain.print (Var.Tbl.get typ x) (print_opt typ) e
438498
| _ -> "")
439499
p);
440500
typ

runtime/js/compare.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ function caml_compare_val(a, b, total) {
251251
b = b[i];
252252
}
253253
}
254-
//Provides: caml_compare (const, const)
254+
//Provides: caml_compare mutable (const, const)
255255
//Requires: caml_compare_val
256256
function caml_compare(a, b) {
257257
return caml_compare_val(a, b, true);

runtime/wasm/compare.wat

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -556,53 +556,49 @@
556556
(i32.const 0))
557557

558558
(func (export "caml_compare")
559-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
559+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
560560
(local $res i32)
561561
(local.set $res
562562
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 1)))
563563
(if (i32.lt_s (local.get $res) (i32.const 0))
564-
(then (return (ref.i31 (i32.const -1)))))
564+
(then (return (i32.const -1))))
565565
(if (i32.gt_s (local.get $res) (i32.const 0))
566-
(then (return (ref.i31 (i32.const 1)))))
567-
(ref.i31 (i32.const 0)))
566+
(then (return (i32.const 1))))
567+
(i32.const 0))
568568

569569
(func (export "caml_equal")
570-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
571-
(ref.i31
572-
(i32.eqz
573-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
570+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
571+
(i32.eqz
572+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
574573

575574
(func (export "caml_notequal")
576-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
577-
(ref.i31
578-
(i32.ne (i32.const 0)
579-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
575+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
576+
(i32.ne (i32.const 0)
577+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
580578

581579
(func (export "caml_lessthan")
582-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
580+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
583581
(local $res i32)
584582
(local.set $res
585583
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))
586-
(ref.i31
587-
(i32.and (i32.lt_s (local.get $res) (i32.const 0))
588-
(i32.ne (local.get $res) (global.get $unordered)))))
584+
(i32.and (i32.lt_s (local.get $res) (i32.const 0))
585+
(i32.ne (local.get $res) (global.get $unordered))))
589586

590587
(func (export "caml_lessequal")
591-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
588+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
592589
(local $res i32)
593590
(local.set $res
594591
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))
595-
(ref.i31
596-
(i32.and (i32.le_s (local.get $res) (i32.const 0))
597-
(i32.ne (local.get $res) (global.get $unordered)))))
592+
(i32.and (i32.le_s (local.get $res) (i32.const 0))
593+
(i32.ne (local.get $res) (global.get $unordered))))
598594

599595
(func (export "caml_greaterthan")
600-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
601-
(ref.i31 (i32.lt_s (i32.const 0)
602-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
596+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
597+
(i32.lt_s (i32.const 0)
598+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
603599

604600
(func (export "caml_greaterequal")
605-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
606-
(ref.i31 (i32.le_s (i32.const 0)
607-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
601+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
602+
(i32.le_s (i32.const 0)
603+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
608604
)

0 commit comments

Comments
 (0)