Skip to content

Commit 31a0d91

Browse files
committed
Add RoundingMode sort
1 parent 3e75041 commit 31a0d91

File tree

13 files changed

+87
-53
lines changed

13 files changed

+87
-53
lines changed

src/smtml/bitwuzla_mappings.default.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ module Fresh_bitwuzla (B : Bitwuzla_cxx.S) : M = struct
7979

8080
let float ebits sbits = mk_fp_sort ebits sbits
8181

82+
let roundingMode = mk_rm_sort ()
83+
8284
let ty t = Term.sort t
8385

8486
let to_ety _ = Fmt.failwith "Bitwuzla_mappings: to_ety not implemented"

src/smtml/cvc5_mappings.default.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ module Fresh_cvc5 () = struct
7777

7878
let float ebits sbits = Sort.mk_fp_sort tm ebits sbits
7979

80+
let roundingMode = Sort.mk_rm_sort tm
81+
8082
let ty t = Term.sort t
8183

8284
let to_ety _ = assert false

src/smtml/dolmenexpr_to_expr.mli

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ module DolmenIntf : sig
9090

9191
val float : int -> int -> ty
9292

93+
val roundingMode : ty
94+
9395
val ty : term -> ty
9496

9597
val to_ety : ty -> Ty.t

src/smtml/eval.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1049,7 +1049,8 @@ let op int real bool str lst bv f32 f64 ty op =
10491049
| Ty_bitv _ -> bv op
10501050
| Ty_fp 32 -> f32 op
10511051
| Ty_fp 64 -> f64 op
1052-
| Ty_fp _ | Ty_app | Ty_unit | Ty_none | Ty_regexp -> assert false
1052+
| Ty_fp _ | Ty_app | Ty_unit | Ty_none | Ty_regexp | Ty_roundingMode ->
1053+
assert false
10531054
[@@inline]
10541055

10551056
let unop =

src/smtml/mappings.ml

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct
6565
| Ty_bitv n -> M.Types.bitv n
6666
| Ty_fp 32 -> f32
6767
| Ty_fp 64 -> f64
68+
| Ty_roundingMode -> M.Types.roundingMode
6869
| (Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none | Ty_regexp) as ty ->
6970
Fmt.failwith "Unsupported theory: %a@." Ty.pp ty
7071

@@ -598,7 +599,8 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct
598599
| Ty.Ty_bitv 64 -> I64.unop
599600
| Ty.Ty_fp 32 -> Float32_impl.unop
600601
| Ty.Ty_fp 64 -> Float64_impl.unop
601-
| Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none ->
602+
| Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none
603+
| Ty_roundingMode ->
602604
assert false
603605

604606
let binop = function
@@ -612,7 +614,8 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct
612614
| Ty.Ty_bitv 64 -> I64.binop
613615
| Ty.Ty_fp 32 -> Float32_impl.binop
614616
| Ty.Ty_fp 64 -> Float64_impl.binop
615-
| Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none ->
617+
| Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none
618+
| Ty_roundingMode ->
616619
assert false
617620

618621
let triop = function
@@ -625,7 +628,7 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct
625628
| Ty.Ty_fp 32 -> Float32_impl.triop
626629
| Ty.Ty_fp 64 -> Float64_impl.triop
627630
| Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none
628-
| Ty_regexp ->
631+
| Ty_regexp | Ty_roundingMode ->
629632
assert false
630633

631634
let relop = function
@@ -639,7 +642,7 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct
639642
| Ty.Ty_fp 32 -> Float32_impl.relop
640643
| Ty.Ty_fp 64 -> Float64_impl.relop
641644
| Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none
642-
| Ty_regexp ->
645+
| Ty_regexp | Ty_roundingMode ->
643646
assert false
644647

645648
let cvtop = function
@@ -653,7 +656,7 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct
653656
| Ty.Ty_fp 32 -> Float32_impl.cvtop
654657
| Ty.Ty_fp 64 -> Float64_impl.cvtop
655658
| Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none
656-
| Ty_regexp ->
659+
| Ty_regexp | Ty_roundingMode ->
657660
assert false
658661

659662
let naryop = function
@@ -662,18 +665,19 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct
662665
| Ty.Ty_regexp -> Regexp_impl.naryop
663666
| ty -> Fmt.failwith "Naryop for type \"%a\" not implemented" Ty.pp ty
664667

665-
let get_rounding_mode rm =
668+
let get_rounding_mode ctx rm =
666669
match Expr.view rm with
667670
| Symbol { name = Simple ("roundNearestTiesToEven" | "RNE"); _ } ->
668-
M.Float.Rounding_mode.rne
671+
(ctx, M.Float.Rounding_mode.rne)
669672
| Symbol { name = Simple ("roundNearestTiesToAway" | "RNA"); _ } ->
670-
M.Float.Rounding_mode.rna
673+
(ctx, M.Float.Rounding_mode.rna)
671674
| Symbol { name = Simple ("roundTowardPositive" | "RTP"); _ } ->
672-
M.Float.Rounding_mode.rtp
675+
(ctx, M.Float.Rounding_mode.rtp)
673676
| Symbol { name = Simple ("roundTowardNegative" | "RTN"); _ } ->
674-
M.Float.Rounding_mode.rtn
677+
(ctx, M.Float.Rounding_mode.rtn)
675678
| Symbol { name = Simple ("roundTowardZero" | "RTZ"); _ } ->
676-
M.Float.Rounding_mode.rtz
679+
(ctx, M.Float.Rounding_mode.rtz)
680+
| Symbol rm -> make_symbol ctx rm
677681
| _ -> Fmt.failwith "unknown rouding mode: %a" Expr.pp rm
678682

679683
let rec encode_expr ctx (hte : Expr.t) : symbol_ctx * M.term =
@@ -688,36 +692,36 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct
688692
| App ({ name = Simple "fp.add"; _ }, [ rm; a; b ]) ->
689693
let ctx, a = encode_expr ctx a in
690694
let ctx, b = encode_expr ctx b in
691-
let rm = get_rounding_mode rm in
695+
let ctx, rm = get_rounding_mode ctx rm in
692696
(ctx, M.Float.add ~rm a b)
693697
| App ({ name = Simple "fp.sub"; _ }, [ rm; a; b ]) ->
694698
let ctx, a = encode_expr ctx a in
695699
let ctx, b = encode_expr ctx b in
696-
let rm = get_rounding_mode rm in
700+
let ctx, rm = get_rounding_mode ctx rm in
697701
(ctx, M.Float.sub ~rm a b)
698702
| App ({ name = Simple "fp.mul"; _ }, [ rm; a; b ]) ->
699703
let ctx, a = encode_expr ctx a in
700704
let ctx, b = encode_expr ctx b in
701-
let rm = get_rounding_mode rm in
705+
let ctx, rm = get_rounding_mode ctx rm in
702706
(ctx, M.Float.mul ~rm a b)
703707
| App ({ name = Simple "fp.div"; _ }, [ rm; a; b ]) ->
704708
let ctx, a = encode_expr ctx a in
705709
let ctx, b = encode_expr ctx b in
706-
let rm = get_rounding_mode rm in
710+
let ctx, rm = get_rounding_mode ctx rm in
707711
(ctx, M.Float.div ~rm a b)
708712
| App ({ name = Simple "fp.fma"; _ }, [ rm; a; b; c ]) ->
709713
let ctx, a = encode_expr ctx a in
710714
let ctx, b = encode_expr ctx b in
711715
let ctx, c = encode_expr ctx c in
712-
let rm = get_rounding_mode rm in
716+
let ctx, rm = get_rounding_mode ctx rm in
713717
(ctx, M.Float.fma ~rm a b c)
714718
| App ({ name = Simple "fp.sqrt"; _ }, [ rm; a ]) ->
715719
let ctx, a = encode_expr ctx a in
716-
let rm = get_rounding_mode rm in
720+
let ctx, rm = get_rounding_mode ctx rm in
717721
(ctx, M.Float.sqrt ~rm a)
718722
| App ({ name = Simple "fp.roundToIntegral"; _ }, [ rm; a ]) ->
719723
let ctx, a = encode_expr ctx a in
720-
let rm = get_rounding_mode rm in
724+
let ctx, rm = get_rounding_mode ctx rm in
721725
(ctx, M.Float.round_to_integral ~rm a)
722726
| App (sym, args) ->
723727
let name =
@@ -822,7 +826,7 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct
822826
let float = M.Interp.to_float v 11 53 in
823827
Value.Num (F64 (Int64.bits_of_float float))
824828
| Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none | Ty_regexp
825-
->
829+
| Ty_roundingMode ->
826830
assert false
827831

828832
let value ({ model = m; ctx } : model) (c : Expr.t) : Value.t =

src/smtml/mappings.nop.ml

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ module Nop = struct
1010
let is_available = false
1111
end
1212

13-
type ty = unit
13+
type ty = [ `Ty ]
1414

15-
type term = unit
15+
type term = [ `Term ]
1616

1717
type interp
1818

@@ -24,11 +24,11 @@ module Nop = struct
2424

2525
type optimizer
2626

27-
type func_decl = unit
27+
type func_decl = [ `Func_decl ]
2828

29-
let true_ = ()
29+
let true_ = `Term
3030

31-
let false_ = ()
31+
let false_ = `Term
3232

3333
let int _ = assert false
3434

@@ -61,17 +61,19 @@ module Nop = struct
6161
let exists _ _ = assert false
6262

6363
module Types = struct
64-
let int = ()
64+
let int = `Ty
6565

66-
let real = ()
66+
let real = `Ty
6767

68-
let bool = ()
68+
let bool = `Ty
6969

70-
let string = ()
70+
let string = `Ty
7171

72-
let bitv _ = ()
72+
let bitv _ = `Ty
7373

74-
let float _ _ = ()
74+
let float _ _ = `Ty
75+
76+
let roundingMode = `Ty
7577

7678
let ty _ = assert false
7779

@@ -262,15 +264,15 @@ module Nop = struct
262264

263265
module Float = struct
264266
module Rounding_mode = struct
265-
let rne = ()
267+
let rne = `Term
266268

267-
let rna = ()
269+
let rna = `Term
268270

269-
let rtp = ()
271+
let rtp = `Term
270272

271-
let rtn = ()
273+
let rtn = `Term
272274

273-
let rtz = ()
275+
let rtz = `Term
274276
end
275277

276278
let v _ = assert false
@@ -339,9 +341,9 @@ module Nop = struct
339341
end
340342

341343
module Func = struct
342-
let make _ _ _ = ()
344+
let make _ _ _ = `Func_decl
343345

344-
let apply () _ = ()
346+
let apply `Func_decl _ = `Term
345347
end
346348

347349
module Model = struct

src/smtml/mappings_intf.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ module type M = sig
120120
significand width [s]. *)
121121
val float : int -> int -> ty
122122

123+
val roundingMode : ty
124+
123125
(** [ty t] retrieves the type of the term [t]. *)
124126
val ty : term -> ty
125127

src/smtml/rewrite.ml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,18 @@ let debug fmt k = if debug then k (Fmt.epr fmt)
2121

2222
(* FIXME: This is a very basic way to infer types. I'm surprised it even works *)
2323
let rewrite_ty unknown_ty tys =
24-
debug " rewrite_ty: %a@." (fun k -> k Ty.pp unknown_ty);
2524
match (unknown_ty, tys) with
26-
| Ty_none, [ ty ] -> ty
25+
| Ty.Ty_none, [ ty ] ->
26+
debug " rewrite_ty: %a -> %a@." (fun k -> k Ty.pp unknown_ty Ty.pp ty);
27+
ty
2728
| Ty_none, [ ty1; ty2 ] ->
28-
debug " rewrite_ty: %a %a@." (fun k -> k Ty.pp ty1 Ty.pp ty2);
29+
debug " rewrite_ty: %a -> (%a %a)@." (fun k ->
30+
k Ty.pp unknown_ty Ty.pp ty1 Ty.pp ty2 );
2931
assert (Ty.equal ty1 ty2);
3032
ty1
3133
| Ty_none, ty1 :: ty2 :: [ ty3 ] ->
32-
debug " rewrite_ty: %a %a %a@." (fun k -> k Ty.pp ty1 Ty.pp ty2 Ty.pp ty3);
34+
debug " rewrite_ty: %a ->(%a %a %a)@." (fun k ->
35+
k Ty.pp unknown_ty Ty.pp ty1 Ty.pp ty2 Ty.pp ty3 );
3336
assert (Ty.equal ty1 ty2);
3437
assert (Ty.equal ty2 ty3);
3538
ty1
@@ -55,11 +58,13 @@ let rec rewrite_expr (type_map, expr_map) hte =
5558
| App
5659
( ({ name = Simple ("fp.add" | "fp.sub" | "fp.mul" | "fp.div"); _ } as sym)
5760
, [ rm; a; b ] ) ->
61+
let rm = rewrite_expr (type_map, expr_map) rm in
5862
let a = rewrite_expr (type_map, expr_map) a in
5963
let b = rewrite_expr (type_map, expr_map) b in
6064
let ty = rewrite_ty Ty_none [ Expr.ty a; Expr.ty b ] in
6165
Expr.app { sym with ty } [ rm; a; b ]
6266
| App (({ name = Simple "fp.fma"; _ } as sym), [ rm; a; b; c ]) ->
67+
let rm = rewrite_expr (type_map, expr_map) rm in
6368
let a = rewrite_expr (type_map, expr_map) a in
6469
let b = rewrite_expr (type_map, expr_map) b in
6570
let c = rewrite_expr (type_map, expr_map) c in
@@ -68,6 +73,7 @@ let rec rewrite_expr (type_map, expr_map) hte =
6873
| App
6974
( ({ name = Simple ("fp.sqrt" | "fp.roundToIntegral"); _ } as sym)
7075
, [ rm; a ] ) ->
76+
let rm = rewrite_expr (type_map, expr_map) rm in
7177
let a = rewrite_expr (type_map, expr_map) a in
7278
let ty = rewrite_ty Ty_none [ Expr.ty a ] in
7379
Expr.app { sym with ty } [ rm; a ]

src/smtml/smtlib.ml

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,24 @@ module Term = struct
3030

3131
let const ?loc (id : Symbol.t) : t =
3232
match (Symbol.namespace id, Symbol.name id) with
33-
| Sort, Simple name -> (
33+
| Sort, Simple name -> begin
3434
match name with
3535
| "Int" -> Expr.symbol { id with ty = Ty_int }
3636
| "Real" -> Expr.symbol { id with ty = Ty_real }
3737
| "Bool" -> Expr.symbol { id with ty = Ty_bool }
3838
| "String" -> Expr.symbol { id with ty = Ty_str }
3939
| "Float32" -> Expr.symbol { id with ty = Ty_fp 32 }
4040
| "Float64" -> Expr.symbol { id with ty = Ty_fp 64 }
41-
| _ -> (
41+
| "RoundingMode" -> Expr.symbol { id with ty = Ty_roundingMode }
42+
| _ -> begin
4243
match Hashtbl.find_opt custom_sorts name with
44+
| Some ty -> Expr.symbol { id with ty }
4345
| None ->
44-
Fmt.failwith "%acould not find sort: %a" pp_loc loc Symbol.pp id
45-
| Some ty -> Expr.symbol { id with ty } ) )
46+
Logs.err (fun k ->
47+
k "%acould not find sort: %a" pp_loc loc Symbol.pp id );
48+
Expr.symbol id
49+
end
50+
end
4651
| Sort, Indexed { basename; indices } -> (
4752
match (basename, indices) with
4853
| "BitVec", [ n ] -> (
@@ -78,7 +83,7 @@ module Term = struct
7883
fp_of_size (Float.neg Float.zero) ebits sbits
7984
| "NaN", [ ebits; sbits ] -> fp_of_size Float.nan ebits sbits
8085
| _ ->
81-
Log.debug (fun k -> k "const: Unknown %a making app" Symbol.pp id);
86+
Log.debug (fun k -> k "const: unknown %a making app" Symbol.pp id);
8287
Expr.symbol id
8388
end
8489
| Attr, Simple _ -> Expr.symbol id
@@ -90,12 +95,12 @@ module Term = struct
9095
let int ?loc (x : string) =
9196
match int_of_string_opt x with
9297
| Some x -> Expr.value (Int x)
93-
| None -> Fmt.failwith "%aInvalid int" pp_loc loc
98+
| None -> Fmt.failwith "%ainvalid int" pp_loc loc
9499

95100
let real ?loc (x : string) =
96101
match float_of_string_opt x with
97102
| Some x -> Expr.value (Real x)
98-
| None -> Fmt.failwith "%aInvalid real" pp_loc loc
103+
| None -> Fmt.failwith "%ainvalid real" pp_loc loc
99104

100105
let hexa ?loc:_ (h : string) =
101106
let len = String.length h in

0 commit comments

Comments
 (0)