Skip to content

Commit 4142b2f

Browse files
committed
Bigarrays
1 parent 875e0df commit 4142b2f

File tree

5 files changed

+285
-8
lines changed

5 files changed

+285
-8
lines changed

compiler/lib-wasm/gc_target.ml

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,37 @@ module Type = struct
419419
}
420420
])
421421
})
422+
423+
let int_array_type =
424+
register_type "int_array" (fun () ->
425+
return
426+
{ supertype = None
427+
; final = true
428+
; typ = W.Array { mut = true; typ = Value I32 }
429+
})
430+
431+
let bigarray_type =
432+
register_type "bigarray" (fun () ->
433+
let* custom_operations = custom_operations_type in
434+
let* int_array = int_array_type in
435+
let* custom = custom_type in
436+
return
437+
{ supertype = Some custom
438+
; final = true
439+
; typ =
440+
W.Struct
441+
[ { mut = false
442+
; typ = Value (Ref { nullable = false; typ = Type custom_operations })
443+
}
444+
; { mut = true; typ = Value (Ref { nullable = false; typ = Extern }) }
445+
; { mut = false
446+
; typ = Value (Ref { nullable = false; typ = Type int_array })
447+
}
448+
; { mut = false; typ = Packed I8 }
449+
; { mut = false; typ = Packed I8 }
450+
; { mut = false; typ = Packed I8 }
451+
]
452+
})
422453
end
423454

424455
module Value = struct
@@ -1354,6 +1385,56 @@ module Math = struct
13541385
let exp2 x = power (return (W.Const (F64 2.))) x
13551386
end
13561387

1388+
module Bigarray = struct
1389+
let dim1 a =
1390+
let* ty = Type.bigarray_type in
1391+
Memory.wasm_array_get
1392+
~ty:Type.int_array_type
1393+
(Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 2)
1394+
(Arith.const 0l)
1395+
1396+
let get ~kind a i =
1397+
match kind with
1398+
| Typing.Bigarray.Int8_unsigned | Char ->
1399+
let* f =
1400+
register_import
1401+
~import_module:"bindings"
1402+
~name:"ta_get_ui8"
1403+
(Fun
1404+
{ W.params = [ Ref { nullable = false; typ = Extern }; I32 ]
1405+
; result = [ I32 ]
1406+
})
1407+
in
1408+
let* ty = Type.bigarray_type in
1409+
let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 1 in
1410+
let* i = Value.int_val i in
1411+
Value.val_int (return (W.Call (f, [ ta; i ])))
1412+
| _ -> assert false
1413+
1414+
let set ~kind a i v =
1415+
match kind with
1416+
| Typing.Bigarray.Int8_unsigned | Char ->
1417+
let* f =
1418+
register_import
1419+
~import_module:"bindings"
1420+
~name:"ta_set_ui8"
1421+
(Fun
1422+
{ W.params =
1423+
[ Ref { nullable = false; typ = Extern }
1424+
; I32
1425+
; Ref { nullable = false; typ = I31 }
1426+
]
1427+
; result = []
1428+
})
1429+
in
1430+
let* ty = Type.bigarray_type in
1431+
let* ta = Memory.wasm_struct_get ty (Memory.wasm_cast ty a) 1 in
1432+
let* i = Value.int_val i in
1433+
let* v = cast I31 v in
1434+
instr (W.CallInstr (f, [ ta; i; v ]))
1435+
| _ -> assert false
1436+
end
1437+
13571438
module JavaScript = struct
13581439
let anyref = W.Ref { nullable = true; typ = Any }
13591440

compiler/lib-wasm/generate.ml

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,48 @@ module Generate (Target : Target_sig.S) = struct
882882
let* x' = x' in
883883
let* y' = y' in
884884
return (W.Call (f, [ x'; y' ])))
885-
| _ -> invalid_arity "caml_compare" l ~expected:2)
885+
| _ -> invalid_arity "caml_compare" l ~expected:2);
886+
register_prim "caml_ba_get_1" `Mutator (fun ctx context l ->
887+
match l with
888+
| [ x; y ] -> (
889+
let x' = transl_prim_arg ctx x in
890+
let y' = transl_prim_arg ctx y in
891+
match get_type ctx x with
892+
| Bigarray { kind = (Int8_unsigned | Char) as kind; layout = C } ->
893+
seq
894+
(let* cond = Arith.uge (Value.int_val y') (Bigarray.dim1 x') in
895+
instr (W.Br_if (label_index context bound_error_pc, cond)))
896+
(Bigarray.get ~kind x' y')
897+
| _ ->
898+
let* f =
899+
register_import ~name:"caml_ba_get_1" (Fun (Type.primitive_type 2))
900+
in
901+
let* x' = x' in
902+
let* y' = y' in
903+
return (W.Call (f, [ x'; y' ])))
904+
| _ -> invalid_arity "caml_ba_get_1" l ~expected:2);
905+
register_prim "caml_ba_set_1" `Mutator (fun ctx context l ->
906+
match l with
907+
| [ x; y; z ] -> (
908+
let x' = transl_prim_arg ctx x in
909+
let y' = transl_prim_arg ctx y in
910+
let z' = transl_prim_arg ctx z in
911+
match get_type ctx x with
912+
| Bigarray { kind = (Int8_unsigned | Char) as kind; layout = C } ->
913+
seq
914+
(let* cond = Arith.uge (Value.int_val y') (Bigarray.dim1 x') in
915+
let* () = instr (W.Br_if (label_index context bound_error_pc, cond)) in
916+
Bigarray.set ~kind x' y' z')
917+
Value.unit
918+
| _ ->
919+
let* f =
920+
register_import ~name:"caml_ba_set_1" (Fun (Type.primitive_type 3))
921+
in
922+
let* x' = x' in
923+
let* y' = y' in
924+
let* z' = z' in
925+
return (W.Call (f, [ x'; y'; z' ])))
926+
| _ -> invalid_arity "caml_ba_set_1" l ~expected:3)
886927

887928
let rec translate_expr ctx context x e =
888929
match e with
@@ -1147,7 +1188,9 @@ module Generate (Target : Target_sig.S) = struct
11471188
| "caml_bytes_set"
11481189
| "caml_check_bound"
11491190
| "caml_check_bound_gen"
1150-
| "caml_check_bound_float" )
1191+
| "caml_check_bound_float"
1192+
| "caml_ba_get_1"
1193+
| "caml_ba_set_1" )
11511194
, _ ) ) -> fst n, true
11521195
| Let
11531196
( _

compiler/lib-wasm/target_sig.ml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,19 @@ module type S = sig
253253
val round : expression -> expression
254254
end
255255

256+
module Bigarray : sig
257+
val dim1 : expression -> expression
258+
259+
val get : kind:Typing.Bigarray.kind -> expression -> expression -> expression
260+
261+
val set :
262+
kind:Typing.Bigarray.kind
263+
-> expression
264+
-> expression
265+
-> expression
266+
-> unit Code_generation.t
267+
end
268+
256269
val internal_primitives :
257270
(string
258271
* Primitive.kind

compiler/lib-wasm/typing.ml

Lines changed: 118 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,90 @@ type boxed_number =
3131
| Nativeint
3232
| Float
3333

34+
module Bigarray = struct
35+
type kind =
36+
| Float32
37+
| Float64
38+
| Int8_signed
39+
| Int8_unsigned
40+
| Int16_signed
41+
| Int16_unsigned
42+
| Int32
43+
| Int64
44+
| Int
45+
| Nativeint
46+
| Complex32
47+
| Complex64
48+
| Char
49+
| Float16
50+
51+
type layout =
52+
| C
53+
| Fortran
54+
55+
type t =
56+
{ kind : kind
57+
; layout : layout
58+
}
59+
60+
let make ~kind ~layout =
61+
{ kind =
62+
(match kind with
63+
| 0 -> Float32
64+
| 1 -> Float64
65+
| 2 -> Int8_signed
66+
| 3 -> Int8_unsigned
67+
| 4 -> Int16_signed
68+
| 5 -> Int16_unsigned
69+
| 6 -> Int32
70+
| 7 -> Int64
71+
| 8 -> Int
72+
| 9 -> Nativeint
73+
| 10 -> Complex32
74+
| 11 -> Complex64
75+
| 12 -> Char
76+
| 13 -> Float16
77+
| _ -> assert false)
78+
; layout =
79+
(match layout with
80+
| 0 -> C
81+
| 1 -> Fortran
82+
| _ -> assert false)
83+
}
84+
85+
let print f { kind; layout } =
86+
Format.fprintf
87+
f
88+
"bigarray{%s,%s}"
89+
(match kind with
90+
| Float32 -> "float32"
91+
| Float64 -> "float64"
92+
| Int8_signed -> "sint8"
93+
| Int8_unsigned -> "uint8"
94+
| Int16_signed -> "sint16"
95+
| Int16_unsigned -> "uint16"
96+
| Int32 -> "int32"
97+
| Int64 -> "int64"
98+
| Int -> "int"
99+
| Nativeint -> "nativeint"
100+
| Complex32 -> "complex32"
101+
| Complex64 -> "complex64"
102+
| Char -> "char"
103+
| Float16 -> "float16")
104+
(match layout with
105+
| C -> "C"
106+
| Fortran -> "Fortran")
107+
108+
let equal { kind; layout } { kind = kind'; layout = layout' } =
109+
phys_equal kind kind' && phys_equal layout layout'
110+
end
111+
34112
type typ =
35113
| Top
36114
| Int of Integer.kind
37115
| Number of boxed_number
38116
| Tuple of typ array
117+
| Bigarray of Bigarray.t
39118
| Bot
40119

41120
module Domain = struct
@@ -55,8 +134,9 @@ module Domain = struct
55134
else
56135
Array.init (max l l') ~f:(fun i ->
57136
if i < l then if i < l' then join t.(i) t'.(i) else t.(i) else t'.(i)))
137+
| Bigarray b, Bigarray b' when Bigarray.equal b b' -> t
58138
| Top, _ | _, Top -> Top
59-
| (Int _ | Number _ | Tuple _), _ -> Top
139+
| (Int _ | Number _ | Tuple _ | Bigarray _), _ -> Top
60140

61141
let join_set ?(others = false) f s =
62142
if others then Top else Var.Set.fold (fun x a -> join (f x) a) s Bot
@@ -68,7 +148,8 @@ module Domain = struct
68148
| Number t, Number t' -> Poly.equal t t'
69149
| Tuple t, Tuple t' ->
70150
Array.length t = Array.length t' && Array.for_all2 ~f:equal t t'
71-
| (Top | Tuple _ | Int _ | Number _ | Bot), _ -> false
151+
| Bigarray b, Bigarray b' -> Bigarray.equal b b'
152+
| (Top | Tuple _ | Int _ | Number _ | Bigarray _ | Bot), _ -> false
72153

73154
let rec sub t t' =
74155
match t, t' with
@@ -83,20 +164,21 @@ module Domain = struct
83164
i = l || (sub t.(i) t'.(i) && compare t t' (i + 1) l)
84165
in
85166
compare t t' 0 (Array.length t)
86-
| (Int _ | Number _ | Tuple _), _ -> false
167+
| Bigarray b, Bigarray b' -> Bigarray.equal b b'
168+
| (Int _ | Number _ | Tuple _ | Bigarray _), _ -> false
87169

88170
let bot = Bot
89171

90172
let depth_treshold = 4
91173

92174
let rec depth t =
93175
match t with
94-
| Top | Bot | Number _ | Int _ -> 0
176+
| Top | Bot | Number _ | Int _ | Bigarray _ -> 0
95177
| Tuple l -> 1 + Array.fold_left ~f:(fun acc t' -> max (depth t') acc) l ~init:0
96178

97179
let rec truncate depth t =
98180
match t with
99-
| Top | Bot | Number _ | Int _ -> t
181+
| Top | Bot | Number _ | Int _ | Bigarray _ -> t
100182
| Tuple l ->
101183
if depth = 0
102184
then Top
@@ -125,6 +207,7 @@ module Domain = struct
125207
| Number Int64 -> Format.fprintf f "int64"
126208
| Number Nativeint -> Format.fprintf f "nativeint"
127209
| Number Float -> Format.fprintf f "float"
210+
| Bigarray b -> Bigarray.print f b
128211
| Tuple t ->
129212
Format.fprintf
130213
f
@@ -412,7 +495,32 @@ let propagate st approx x : Domain.t =
412495
when List.length args = List.length params ->
413496
Domain.box
414497
(Domain.join_set
415-
(fun y -> Var.Tbl.get approx y)
498+
(fun y ->
499+
match st.state.defs.(Var.idx y) with
500+
| Expr
501+
(Prim (Extern "caml_ba_create", [ Pv kind; Pv layout; _ ]))
502+
-> (
503+
let m =
504+
List.fold_left2
505+
~f:(fun m p a -> Var.Map.add p a m)
506+
~init:Var.Map.empty
507+
params
508+
args
509+
in
510+
try
511+
match
512+
( st.state.defs.(Var.idx (Var.Map.find kind m))
513+
, st.state.defs.(Var.idx (Var.Map.find layout m)) )
514+
with
515+
| ( Expr (Constant (Int kind))
516+
, Expr (Constant (Int layout)) ) ->
517+
Bigarray
518+
(Bigarray.make
519+
~kind:(Targetint.to_int_exn kind)
520+
~layout:(Targetint.to_int_exn layout))
521+
| _ -> raise Not_found
522+
with Not_found -> Var.Tbl.get approx y)
523+
| _ -> Var.Tbl.get approx y)
416524
(Var.Map.find g st.state.return_values))
417525
| Expr (Closure (_, _, _)) ->
418526
(* The function is partially applied or over applied *)
@@ -471,6 +579,10 @@ let print_opt typ f e =
471579
; Number Float
472580
]
473581
then Format.fprintf f " OPT"
582+
| Prim (Extern ("caml_ba_get_1" | "caml_ba_set_1"), Pv x :: _) -> (
583+
match Var.Tbl.get typ x with
584+
| Bigarray _ -> Format.fprintf f " OPT"
585+
| _ -> ())
474586
| _ -> ()
475587

476588
let f ~state ~info ~deadcode_sentinal p =

0 commit comments

Comments
 (0)