Skip to content

Commit 065a877

Browse files
committed
Wasm: specialization of number comparisons
1 parent edde368 commit 065a877

File tree

7 files changed

+189
-53
lines changed

7 files changed

+189
-53
lines changed

compiler/lib-wasm/generate.ml

Lines changed: 127 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ module Generate (Target : Target_sig.S) = struct
6868
type repr =
6969
| Value
7070
| Float
71+
| Int
7172
| Int32
7273
| Nativeint
7374
| Int64
@@ -76,24 +77,23 @@ module Generate (Target : Target_sig.S) = struct
7677
match r with
7778
| Value -> Type.value
7879
| Float -> F64
79-
| Int32 -> I32
80-
| Nativeint -> I32
80+
| Int | Int32 | Nativeint -> I32
8181
| Int64 -> I64
8282

8383
let specialized_primitive_type (_, params, result) =
8484
{ W.params = List.map ~f:repr_type params; result = [ repr_type result ] }
8585

8686
let box_value r e =
8787
match r with
88-
| Value -> e
88+
| Value | Int -> e
8989
| Float -> Memory.box_float e
9090
| Int32 -> Memory.box_int32 e
9191
| Nativeint -> Memory.box_nativeint e
9292
| Int64 -> Memory.box_int64 e
9393

9494
let unbox_value r e =
9595
match r with
96-
| Value -> e
96+
| Value | Int -> e
9797
| Float -> Memory.unbox_float e
9898
| Int32 -> Memory.unbox_int32 e
9999
| Nativeint -> Memory.unbox_nativeint e
@@ -106,9 +106,9 @@ module Generate (Target : Target_sig.S) = struct
106106
[ "caml_int32_bswap", (`Pure, [ Int32 ], Int32)
107107
; "caml_nativeint_bswap", (`Pure, [ Nativeint ], Nativeint)
108108
; "caml_int64_bswap", (`Pure, [ Int64 ], Int64)
109-
; "caml_int32_compare", (`Pure, [ Int32; Int32 ], Value)
110-
; "caml_nativeint_compare", (`Pure, [ Nativeint; Nativeint ], Value)
111-
; "caml_int64_compare", (`Pure, [ Int64; Int64 ], Value)
109+
; "caml_int32_compare", (`Pure, [ Int32; Int32 ], Int)
110+
; "caml_nativeint_compare", (`Pure, [ Nativeint; Nativeint ], Int)
111+
; "caml_int64_compare", (`Pure, [ Int64; Int64 ], Int)
112112
; "caml_string_get32", (`Mutator, [ Value; Value ], Int32)
113113
; "caml_string_get64", (`Mutator, [ Value; Value ], Int64)
114114
; "caml_bytes_get32", (`Mutator, [ Value; Value ], Int32)
@@ -125,7 +125,7 @@ module Generate (Target : Target_sig.S) = struct
125125
; "caml_ldexp_float", (`Pure, [ Float; Value ], Float)
126126
; "caml_erf_float", (`Pure, [ Float ], Float)
127127
; "caml_erfc_float", (`Pure, [ Float ], Float)
128-
; "caml_float_compare", (`Pure, [ Float; Float ], Value)
128+
; "caml_float_compare", (`Pure, [ Float; Float ], Int)
129129
];
130130
h
131131

@@ -300,6 +300,38 @@ module Generate (Target : Target_sig.S) = struct
300300
(transl_prim_arg ctx ?typ:tz z)
301301
| _ -> invalid_arity name l ~expected:3)
302302

303+
let register_comparison name cmp_int cmp_boxed_int cmp_float =
304+
register_prim name `Mutable (fun ctx _ l ->
305+
match l with
306+
| [ x; y ] -> (
307+
let x' = transl_prim_arg ctx x in
308+
let y' = transl_prim_arg ctx y in
309+
match get_type ctx x, get_type ctx y with
310+
| Int _, Int _ -> cmp_int ctx x y
311+
| Number Int32, Number Int32 ->
312+
let* x' = Memory.unbox_int32 x' in
313+
let* y' = Memory.unbox_int32 y' in
314+
return (W.BinOp (I32 cmp_boxed_int, x', y'))
315+
| Number Nativeint, Number Nativeint ->
316+
let* x' = Memory.unbox_nativeint x' in
317+
let* y' = Memory.unbox_nativeint y' in
318+
return (W.BinOp (I32 cmp_boxed_int, x', y'))
319+
| Number Int64, Number Int64 ->
320+
let* x' = Memory.unbox_int64 x' in
321+
let* y' = Memory.unbox_int64 y' in
322+
return (W.BinOp (I64 cmp_boxed_int, x', y'))
323+
| Number Float, Number Float -> float_comparison cmp_float x' y'
324+
| _ ->
325+
let* f =
326+
register_import
327+
~name
328+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
329+
in
330+
let* x' = x' in
331+
let* y' = y' in
332+
return (W.Call (f, [ x'; y' ])))
333+
| _ -> invalid_arity name l ~expected:2)
334+
303335
let () =
304336
register_bin_prim
305337
"caml_array_unsafe_get"
@@ -781,7 +813,93 @@ module Generate (Target : Target_sig.S) = struct
781813
l
782814
~init:(return [])
783815
in
784-
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal ~load l)
816+
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal ~load l);
817+
register_comparison
818+
"caml_greaterthan"
819+
(fun ctx x y -> translate_int_comparison ctx (fun y x -> Arith.(x < y)) x y)
820+
(Gt S)
821+
Gt;
822+
register_comparison
823+
"caml_greaterequal"
824+
(fun ctx x y -> translate_int_comparison ctx (fun y x -> Arith.(x <= y)) x y)
825+
(Ge S)
826+
Ge;
827+
register_comparison
828+
"caml_lessthan"
829+
(fun ctx x y -> translate_int_comparison ctx Arith.( < ) x y)
830+
(Lt S)
831+
Lt;
832+
register_comparison
833+
"caml_lessequal"
834+
(fun ctx x y -> translate_int_comparison ctx Arith.( <= ) x y)
835+
(Le S)
836+
Le;
837+
register_comparison
838+
"caml_equal"
839+
(fun ctx x y -> translate_int_equality ctx ~negate:false x y)
840+
Eq
841+
Eq;
842+
register_comparison
843+
"caml_notequal"
844+
(fun ctx x y -> translate_int_equality ctx ~negate:true x y)
845+
Ne
846+
Ne;
847+
register_prim "caml_compare" `Mutable (fun ctx _ l ->
848+
match l with
849+
| [ x; y ] -> (
850+
let x' = transl_prim_arg ctx x in
851+
let y' = transl_prim_arg ctx y in
852+
match get_type ctx x, get_type ctx y with
853+
| Int _, Int _ ->
854+
Arith.(
855+
(Value.int_val y' < Value.int_val x')
856+
- (Value.int_val x' < Value.int_val y'))
857+
| Number Int32, Number Int32 ->
858+
let* f =
859+
register_import
860+
~name:"caml_int32_compare"
861+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
862+
in
863+
let* x' = Memory.unbox_int32 x' in
864+
let* y' = Memory.unbox_int32 y' in
865+
return (W.Call (f, [ x'; y' ]))
866+
| Number Nativeint, Number Nativeint ->
867+
let* f =
868+
register_import
869+
~name:"caml_nativeint_compare"
870+
(Fun (Type.primitive_type 2))
871+
in
872+
let* x' = Memory.unbox_nativeint x' in
873+
let* y' = Memory.unbox_nativeint y' in
874+
return (W.Call (f, [ x'; y' ]))
875+
| Number Int64, Number Int64 ->
876+
let* f =
877+
register_import
878+
~name:"caml_int64_compare"
879+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
880+
in
881+
let* x' = Memory.unbox_int64 x' in
882+
let* y' = Memory.unbox_int64 y' in
883+
return (W.Call (f, [ x'; y' ]))
884+
| Number Float, Number Float ->
885+
let* f =
886+
register_import
887+
~name:"caml_float_compare"
888+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
889+
in
890+
let* x' = Memory.unbox_int64 x' in
891+
let* y' = Memory.unbox_int64 y' in
892+
return (W.Call (f, [ x'; y' ]))
893+
| _ ->
894+
let* f =
895+
register_import
896+
~name:"caml_compare"
897+
(Fun { W.params = [ Type.value; Type.value ]; result = [ I32 ] })
898+
in
899+
let* x' = x' in
900+
let* y' = y' in
901+
return (W.Call (f, [ x'; y' ])))
902+
| _ -> invalid_arity "caml_compare" l ~expected:2)
785903

786904
let rec translate_expr ctx context x e =
787905
match e with

compiler/lib-wasm/typing.ml

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,13 @@ let prim_type ~approx prim args =
191191
| "caml_lessthan"
192192
| "caml_lessequal"
193193
| "caml_equal"
194-
| "caml_compare" -> Int Ref
194+
| "caml_notequal"
195+
| "caml_compare" -> Int Normalized
195196
| "caml_int32_bswap" -> Number Int32
196197
| "caml_nativeint_bswap" -> Number Nativeint
197198
| "caml_int64_bswap" -> Number Int64
198-
| "caml_int32_compare" | "caml_nativeint_compare" | "caml_int64_compare" -> Int Ref
199+
| "caml_int32_compare" | "caml_nativeint_compare" | "caml_int64_compare" ->
200+
Int Normalized
199201
| "caml_string_get32" -> Number Int32
200202
| "caml_string_get64" -> Number Int64
201203
| "caml_bytes_get32" -> Number Int32
@@ -206,7 +208,7 @@ let prim_type ~approx prim args =
206208
| "caml_nextafter_float" -> Number Float
207209
| "caml_classify_float" -> Int Ref
208210
| "caml_ldexp_float" | "caml_erf_float" | "caml_erfc_float" -> Number Float
209-
| "caml_float_compare" -> Int Ref
211+
| "caml_float_compare" -> Int Normalized
210212
| "caml_floatarray_unsafe_get" -> Number Float
211213
| "caml_bytes_unsafe_get"
212214
| "caml_string_unsafe_get"
@@ -419,6 +421,26 @@ let solver st =
419421
in
420422
Solver.f () g (propagate st)
421423

424+
let print_opt typ f e =
425+
match e with
426+
| Prim
427+
( Extern
428+
( "caml_greaterthan"
429+
| "caml_greaterequal"
430+
| "caml_lessthan"
431+
| "caml_lessequal"
432+
| "caml_equal"
433+
| "caml_compare" )
434+
, l ) -> (
435+
match List.map ~f:(arg_type ~approx:typ) l with
436+
| [ Int _; Int _ ]
437+
| [ Number Int32; Number Int32 ]
438+
| [ Number Int64; Number Int64 ]
439+
| [ Number Nativeint; Number Nativeint ]
440+
| [ Number Float; Number Float ] -> Format.fprintf f " OPT"
441+
| _ -> ())
442+
| _ -> ()
443+
422444
let f ~state ~info ~deadcode_sentinal p =
423445
update_deps state p;
424446
let function_parameters = mark_function_parameters p in
@@ -439,7 +461,8 @@ let f ~state ~info ~deadcode_sentinal p =
439461
Format.err_formatter
440462
(fun _ i ->
441463
match i with
442-
| Instr (Let (x, _)) -> Format.asprintf "{%a}" Domain.print (Var.Tbl.get typ x)
464+
| Instr (Let (x, e)) ->
465+
Format.asprintf "{%a}%a" Domain.print (Var.Tbl.get typ x) (print_opt typ) e
443466
| _ -> "")
444467
p);
445468
typ

runtime/js/compare.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ function caml_compare_val(a, b, total) {
251251
b = b[i];
252252
}
253253
}
254-
//Provides: caml_compare (const, const)
254+
//Provides: caml_compare mutable (const, const)
255255
//Requires: caml_compare_val
256256
function caml_compare(a, b) {
257257
return caml_compare_val(a, b, true);

runtime/wasm/compare.wat

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -556,53 +556,49 @@
556556
(i32.const 0))
557557

558558
(func (export "caml_compare")
559-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
559+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
560560
(local $res i32)
561561
(local.set $res
562562
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 1)))
563563
(if (i32.lt_s (local.get $res) (i32.const 0))
564-
(then (return (ref.i31 (i32.const -1)))))
564+
(then (return (i32.const -1))))
565565
(if (i32.gt_s (local.get $res) (i32.const 0))
566-
(then (return (ref.i31 (i32.const 1)))))
567-
(ref.i31 (i32.const 0)))
566+
(then (return (i32.const 1))))
567+
(i32.const 0))
568568

569569
(func (export "caml_equal")
570-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
571-
(ref.i31
572-
(i32.eqz
573-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
570+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
571+
(i32.eqz
572+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
574573

575574
(func (export "caml_notequal")
576-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
577-
(ref.i31
578-
(i32.ne (i32.const 0)
579-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
575+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
576+
(i32.ne (i32.const 0)
577+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
580578

581579
(func (export "caml_lessthan")
582-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
580+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
583581
(local $res i32)
584582
(local.set $res
585583
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))
586-
(ref.i31
587-
(i32.and (i32.lt_s (local.get $res) (i32.const 0))
588-
(i32.ne (local.get $res) (global.get $unordered)))))
584+
(i32.and (i32.lt_s (local.get $res) (i32.const 0))
585+
(i32.ne (local.get $res) (global.get $unordered))))
589586

590587
(func (export "caml_lessequal")
591-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
588+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
592589
(local $res i32)
593590
(local.set $res
594591
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))
595-
(ref.i31
596-
(i32.and (i32.le_s (local.get $res) (i32.const 0))
597-
(i32.ne (local.get $res) (global.get $unordered)))))
592+
(i32.and (i32.le_s (local.get $res) (i32.const 0))
593+
(i32.ne (local.get $res) (global.get $unordered))))
598594

599595
(func (export "caml_greaterthan")
600-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
601-
(ref.i31 (i32.lt_s (i32.const 0)
602-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
596+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
597+
(i32.lt_s (i32.const 0)
598+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
603599

604600
(func (export "caml_greaterequal")
605-
(param $v1 (ref eq)) (param $v2 (ref eq)) (result (ref eq))
606-
(ref.i31 (i32.le_s (i32.const 0)
607-
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0)))))
601+
(param $v1 (ref eq)) (param $v2 (ref eq)) (result i32)
602+
(i32.le_s (i32.const 0)
603+
(call $compare_val (local.get $v1) (local.get $v2) (i32.const 0))))
608604
)

runtime/wasm/float.wat

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,13 +1132,12 @@
11321132
(struct.new $float (local.get $y)))
11331133

11341134
(func (export "caml_float_compare")
1135-
(param $x f64) (param $y f64) (result (ref eq))
1136-
(ref.i31
1137-
(i32.add
1138-
(i32.sub (f64.gt (local.get $x) (local.get $y))
1139-
(f64.lt (local.get $x) (local.get $y)))
1140-
(i32.sub (f64.eq (local.get $x) (local.get $x))
1141-
(f64.eq (local.get $y) (local.get $y))))))
1135+
(param $x f64) (param $y f64) (result i32)
1136+
(i32.add
1137+
(i32.sub (f64.gt (local.get $x) (local.get $y))
1138+
(f64.lt (local.get $x) (local.get $y)))
1139+
(i32.sub (f64.eq (local.get $x) (local.get $x))
1140+
(f64.eq (local.get $y) (local.get $y)))))
11421141

11431142
(func (export "caml_round") (param $x f64) (result f64)
11441143
(local $y f64)

runtime/wasm/int32.wat

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@
126126

127127
(export "caml_nativeint_compare" (func $caml_int32_compare))
128128
(func $caml_int32_compare (export "caml_int32_compare")
129-
(param $i1 i32) (param $i2 i32) (result (ref eq))
130-
(ref.i31 (i32.sub (i32.gt_s (local.get $i1) (local.get $i2))
131-
(i32.lt_s (local.get $i1) (local.get $i2)))))
129+
(param $i1 i32) (param $i2 i32) (result i32)
130+
(i32.sub (i32.gt_s (local.get $i1) (local.get $i2))
131+
(i32.lt_s (local.get $i1) (local.get $i2))))
132132

133133
(global $nativeint_ops (export "nativeint_ops") (ref $custom_operations)
134134
(struct.new $custom_operations

runtime/wasm/int64.wat

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@
124124
(i64.const 8)))))
125125

126126
(func (export "caml_int64_compare")
127-
(param $i1 i64) (param $i2 i64) (result (ref eq))
128-
(ref.i31 (i32.sub (i64.gt_s (local.get $i1) (local.get $i2))
129-
(i64.lt_s (local.get $i1) (local.get $i2)))))
127+
(param $i1 i64) (param $i2 i64) (result i32)
128+
(i32.sub (i64.gt_s (local.get $i1) (local.get $i2))
129+
(i64.lt_s (local.get $i1) (local.get $i2))))
130130

131131
(@string $INT64_ERRMSG "Int64.of_string")
132132

0 commit comments

Comments
 (0)