diff --git a/README.md b/README.md index 9c327b6d..b1511fb0 100644 --- a/README.md +++ b/README.md @@ -127,8 +127,8 @@ module Z3 : val solver : Z3.t = # let cond = - let a = Expr.symbol (Symbol.make Ty.Ty_bool "a") in - let b = Expr.symbol (Symbol.make Ty.Ty_bool "b") in + let a = Expr.symbol (Symbol.make (Ty.Ty Ty_bool) "a") in + let b = Expr.symbol (Symbol.make (Ty.Ty Ty_bool) "b") in Expr.(binop Ty_bool And a (unop Ty_bool Not b));; val cond : Expr.t = (bool.and a (bool.not b)) diff --git a/doc/examples.mld b/doc/examples.mld index 30e25c79..327b2ca4 100644 --- a/doc/examples.mld +++ b/doc/examples.mld @@ -72,7 +72,7 @@ let read_int () = Scanf.scanf " %d" (fun x -> x) let int x = Expr.value (Int x) -let symbol x = Expr.symbol Symbol.(x @: Ty_int) +let symbol x = Expr.symbol Symbol.(x @: Ty Ty_int) let ( = ) i1 i2 = Expr.relop Ty_bool Eq i1 i2 diff --git a/doc/examples/product_mix.ml b/doc/examples/product_mix.ml index 049a7ccd..46159aa0 100644 --- a/doc/examples/product_mix.ml +++ b/doc/examples/product_mix.ml @@ -5,7 +5,7 @@ let read_int () = Scanf.scanf " %d" (fun x -> x) let int x = Expr.value (Int x) -let symbol x = Expr.symbol Symbol.(x @: Ty_int) +let symbol x = Expr.symbol Symbol.(x @: Ty Ty_int) let ( = ) i1 i2 = Expr.relop Ty_bool Eq i1 i2 diff --git a/doc/index.mld b/doc/index.mld index 2e46a364..5efe6779 100644 --- a/doc/index.mld +++ b/doc/index.mld @@ -93,8 +93,8 @@ Example: Bitvector arithmetic {@ocaml[ # open Smtml;; # let cond = - let x = Expr.symbol (Symbol.make (Ty_bitv 8) "x") in - let y = Expr.symbol (Symbol.make (Ty_bitv 8) "y") in + let x = Expr.symbol (Symbol.make (Ty (Ty_bitv 8)) "x") in + let y = Expr.symbol (Symbol.make (Ty (Ty_bitv 8)) "y") in let sum = Expr.binop (Ty_bitv 8) Add x y in let num = Expr.value (Bitv (Bitvector.of_int8 42)) in Expr.relop Ty_bool Eq sum num;; @@ -123,7 +123,7 @@ val model : Model.t = (model (x i8 9) (y i8 33)) # let x_val = - let x = Symbol.make (Ty_bitv 8) "x" in + let x = Symbol.make (Ty (Ty_bitv 8)) "x" in Model.evaluate model x;; val x_val : Value.t option = Some (Smtml.Value.Bitv ) ]} diff --git a/src/smtml/altergo_mappings.default.ml b/src/smtml/altergo_mappings.default.ml index 16b41605..d762e612 100644 --- a/src/smtml/altergo_mappings.default.ml +++ b/src/smtml/altergo_mappings.default.ml @@ -221,10 +221,10 @@ module M = struct module Model = struct let aety_to_ty (ty : AEL.Ty.t) : Ty.t = match ty with - | Tbool -> Ty_bool - | Tint -> Ty_int - | Treal -> Ty_real - | Tbitv n -> Ty_bitv n + | Tbool -> Ty Ty_bool + | Tint -> Ty Ty_int + | Treal -> Ty Ty_real + | Tbitv n -> Ty (Ty_bitv n) | _ -> assert false let aeid_to_sym ((hs, tyl, ty) : AEL.Id.typed) = diff --git a/src/smtml/dolmenexpr_to_expr.ml b/src/smtml/dolmenexpr_to_expr.ml index c91049c9..50acf610 100644 --- a/src/smtml/dolmenexpr_to_expr.ml +++ b/src/smtml/dolmenexpr_to_expr.ml @@ -150,9 +150,11 @@ module DolmenIntf = struct let to_ety (ty : DTy.t) : Ty.t = match ty with - | { ty_descr = TyApp ({ builtin = DBuiltin.Int; _ }, _); _ } -> Ty_int - | { ty_descr = TyApp ({ builtin = DBuiltin.Real; _ }, _); _ } -> Ty_real - | { ty_descr = TyApp ({ builtin = DBuiltin.Prop; _ }, _); _ } -> Ty_bool + | { ty_descr = TyApp ({ builtin = DBuiltin.Int; _ }, _); _ } -> Ty Ty_int + | { ty_descr = TyApp ({ builtin = DBuiltin.Real; _ }, _); _ } -> + Ty Ty_real + | { ty_descr = TyApp ({ builtin = DBuiltin.Prop; _ }, _); _ } -> + Ty Ty_bool | { ty_descr = TyApp ( { builtin = DBuiltin.Base @@ -162,13 +164,13 @@ module DolmenIntf = struct , _ ) ; _ } -> - Ty_str + Ty Ty_str | { ty_descr = TyApp ({ builtin = DBuiltin.Bitv n; _ }, _); _ } -> - Ty_bitv n + Ty (Ty_bitv n) | { ty_descr = TyApp ({ builtin = DBuiltin.Float (8, 24); _ }, _); _ } -> - Ty_fp 32 + Ty (Ty_fp 32) | { ty_descr = TyApp ({ builtin = DBuiltin.Float (11, 53); _ }, _); _ } -> - Ty_fp 64 + Ty (Ty_fp 64) | _ -> Fmt.failwith {|Unsupported dolmen type "%a"|} DTy.print ty end diff --git a/src/smtml/dune b/src/smtml/dune index 66d05dc5..658e955e 100644 --- a/src/smtml/dune +++ b/src/smtml/dune @@ -51,6 +51,7 @@ symbol ty utils + utils_parse value z3_mappings) (private_modules lexer parser) diff --git a/src/smtml/eval.ml b/src/smtml/eval.ml index ba3fc1fd..550721ec 100644 --- a/src/smtml/eval.ml +++ b/src/smtml/eval.ml @@ -48,6 +48,7 @@ exception Integer_overflow (* FIXME: use snake case instead *) exception Index_out_of_bounds +(* Helpers *) let of_arg f n v op msg = try f v with Value t -> raise (TypeError { index = n; value = v; ty = t; op; msg }) @@ -57,31 +58,39 @@ let err_str n op ty_expected ty_actual = Fmt.str "Argument %d of %a expected type %a but got %a instead." n pp_op_type op Ty.pp ty_expected Ty.pp ty_actual +let is_float fpclass f = + match (Float.classify_float f, fpclass) with + | FP_normal, FP_normal + | FP_subnormal, FP_subnormal + | FP_infinite, FP_infinite + | FP_nan, FP_nan + | FP_zero, FP_zero -> + true + | (FP_normal | FP_subnormal | FP_infinite | FP_nan | FP_zero), _ -> false + module Int = struct let to_value (i : int) : Value.t = Int i [@@inline] let of_value (n : int) (op : op_type) (v : Value.t) : int = of_arg - (function Int i -> i | _ -> raise_notrace (Value Ty_int)) + (function Int i -> i | _ -> raise_notrace (Value (Ty Ty_int))) n v op - (err_str n op Ty_int (Value.type_of v)) + (err_str n op (Ty Ty_int) (Value.type_of v)) [@@inline] let str_value (n : int) (op : op_type) (v : Value.t) : string = of_arg - (function Str str -> str | _ -> raise_notrace (Value Ty_str)) + (function Str str -> str | _ -> raise_notrace (Value (Ty Ty_str))) n v op - (err_str n op Ty_str (Value.type_of v)) + (err_str n op (Ty Ty_str) (Value.type_of v)) - let unop (op : Ty.Unop.t) (v : Value.t) : Value.t = + let unop (op : [ `Ty_int ] Ty.Unop.op) (v : Value.t) : Value.t = let f = - match op with - | Neg -> Int.neg - | Not -> Int.lognot - | Abs -> Int.abs - | _ -> Fmt.failwith {|unop: Unsupported int operator "%a"|} Ty.Unop.pp op + match op with Neg -> Int.neg | Not -> Int.lognot | Abs -> Int.abs + (* | _ -> *) + (* Fmt.failwith {|unop: Unsupported int operator "%a"|} Ty.Unop.pp (U op) *) in - to_value (f (of_value 1 (`Unop op) v)) + to_value (f (of_value 1 (`Unop (U op)) v)) let exp_by_squaring x n = let rec exp_by_squaring2 y x n = @@ -155,13 +164,13 @@ module Real = struct let of_value (n : int) (op : op_type) (v : Value.t) : float = of_arg - (function Real v -> v | _ -> raise_notrace (Value Ty_int)) + (function Real v -> v | _ -> raise_notrace (Value (Ty Ty_int))) n v op - (err_str n op Ty_real (Value.type_of v)) + (err_str n op (Ty Ty_real) (Value.type_of v)) [@@inline] - let unop (op : Ty.Unop.t) (v : Value.t) : Value.t = - let v = of_value 1 (`Unop op) v in + let unop (op : [ `Ty_real ] Ty.Unop.op) (v : Value.t) : Value.t = + let v = of_value 1 (`Unop (U op)) v in match op with | Neg -> to_value @@ Float.neg v | Abs -> to_value @@ Float.abs v @@ -171,7 +180,7 @@ module Real = struct | Floor -> to_value @@ Float.floor v | Trunc -> to_value @@ Float.trunc v | Is_nan -> if Float.is_nan v then Value.True else Value.False - | _ -> Fmt.failwith {|unop: Unsupported real operator "%a"|} Ty.Unop.pp op + (* | _ -> Fmt.failwith {|unop: Unsupported real operator "%a"|} Ty.Unop.pp op *) let binop (op : Ty.Binop.t) (v1 : Value.t) (v2 : Value.t) : Value.t = let f = @@ -208,14 +217,18 @@ module Real = struct match op with | ToString -> Str (Float.to_string (of_value 1 op' v)) | OfString -> - let v = match v with Str v -> v | _ -> raise_notrace (Value Ty_str) in + let v = + match v with Str v -> v | _ -> raise_notrace (Value (Ty Ty_str)) + in begin match Float.of_string_opt v with | None -> raise (Invalid_argument "float_of_int") | Some v -> to_value v end | Reinterpret_int -> - let v = match v with Int v -> v | _ -> raise_notrace (Value Ty_int) in + let v = + match v with Int v -> v | _ -> raise_notrace (Value (Ty Ty_int)) + in to_value (float_of_int v) | Reinterpret_float -> Int (Float.to_int (of_value 1 op' v)) | _ -> Fmt.failwith {|cvtop: Unsupported real operator "%a"|} Ty.Cvtop.pp op @@ -227,16 +240,14 @@ module Bool = struct let of_value (n : int) (op : op_type) (v : Value.t) : bool = of_arg (function - | True -> true | False -> false | _ -> raise_notrace (Value Ty_bool) ) + | True -> true | False -> false | _ -> raise_notrace (Value (Ty Ty_bool)) ) n v op - (err_str n op Ty_bool (Value.type_of v)) + (err_str n op (Ty Ty_bool) (Value.type_of v)) [@@inline] - let unop (op : Ty.Unop.t) v = - let b = of_value 1 (`Unop op) v in - match op with - | Not -> to_value (not b) - | _ -> Fmt.failwith {|unop: Unsupported bool operator "%a"|} Ty.Unop.pp op + let unop (op : [ `Ty_bool ] Ty.Unop.op) v = + let b = of_value 1 (`Unop (U op)) v in + match op with Not -> to_value (not b) let xor b1 b2 = match (b1, b2) with @@ -287,9 +298,9 @@ module Str = struct let of_value (n : int) (op : op_type) (v : Value.t) : string = of_arg - (function Str str -> str | _ -> raise_notrace (Value Ty_str)) + (function Str str -> str | _ -> raise_notrace (Value (Ty Ty_str))) n v op - (err_str n op Ty_str (Value.type_of v)) + (err_str n op (Ty Ty_str) (Value.type_of v)) [@@inline] let replace s t t' = @@ -320,12 +331,13 @@ module Str = struct let contains s sub = if indexof s sub 0 < 0 then false else true - let unop (op : Ty.Unop.t) v = - let str = of_value 1 (`Unop op) v in + let unop (op : [ `Ty_str ] Ty.Unop.op) v = + let str = of_value 1 (`Unop (U op)) v in match op with | Length -> Int.to_value (String.length str) | Trim -> to_value (String.trim str) - | _ -> Fmt.failwith {|unop: Unsupported str operator "%a"|} Ty.Unop.pp op + | Regexp_star | Regexp_loop _ | Regexp_plus | Regexp_opt | Regexp_comp -> + Fmt.failwith {|unop: Unsupported str operator "%a"|} Ty.Unop.pp (U op) let binop (op : Ty.Binop.t) v1 v2 = let op' = `Binop op in @@ -414,13 +426,13 @@ end module Lst = struct let of_value (n : int) (op : op_type) (v : Value.t) : Value.t list = of_arg - (function List lst -> lst | _ -> raise_notrace (Value Ty_list)) + (function List lst -> lst | _ -> raise_notrace (Value (Ty Ty_list))) n v op - (err_str n op Ty_list (Value.type_of v)) + (err_str n op (Ty Ty_list) (Value.type_of v)) [@@inline] - let unop (op : Ty.Unop.t) (v : Value.t) : Value.t = - let lst = of_value 1 (`Unop op) v in + let unop (op : [ `Ty_list ] Ty.Unop.op) (v : Value.t) : Value.t = + let lst = of_value 1 (`Unop (U op)) v in match op with | Head -> begin match lst with hd :: _tl -> hd | [] -> assert false end | Tail -> begin @@ -428,7 +440,6 @@ module Lst = struct end | Length -> Int.to_value (List.length lst) | Reverse -> List (List.rev lst) - | _ -> Fmt.failwith {|unop: Unsupported list operator "%a"|} Ty.Unop.pp op let binop (op : Ty.Binop.t) v1 v2 = let op' = `Binop op in @@ -484,7 +495,7 @@ module Bitv = struct let i64_to_value v = to_value @@ Bitvector.of_int64 v let of_value (n : int) (op : op_type) (v : Value.t) : Bitvector.t = - let todo = Ty.Ty_bitv 32 in + let todo = Ty.Ty (Ty_bitv 32) in of_arg (function Bitv bv -> bv | _ -> raise_notrace (Value todo)) n v op @@ -494,18 +505,16 @@ module Bitv = struct let i64_of_value n op v = of_value n op v |> Bitvector.to_int64 - let unop op bv = - let bv = of_value 1 (`Unop op) bv in + let unop (op : [ `Ty_bitv ] Ty.Unop.op) bv = + let bv = of_value 1 (`Unop (U op)) bv in to_value @@ match op with - | Ty.Unop.Neg -> Bitvector.neg bv + | Neg -> Bitvector.neg bv | Not -> Bitvector.lognot bv | Clz -> Bitvector.clz bv | Ctz -> Bitvector.ctz bv | Popcnt -> Bitvector.popcnt bv - | _ -> - Fmt.failwith {|unop: Unsupported bitvectore operator "%a"|} Ty.Unop.pp op let binop op bv1 bv2 = let bv1 = of_value 1 (`Binop op) bv1 in @@ -558,9 +567,9 @@ module F32 = struct let of_value (i : int) (op : op_type) (v : Value.t) : int32 = of_arg - (function Num (F32 f) -> f | _ -> raise_notrace (Value (Ty_fp 32))) + (function Num (F32 f) -> f | _ -> raise_notrace (Value (Ty (Ty_fp 32)))) i v op - (err_str i op (Ty_fp 32) (Value.type_of v)) + (err_str i op (Ty (Ty_fp 32)) (Value.type_of v)) [@@inline] let of_value' (i : int) (op : op_type) (v : Value.t) : float = @@ -572,18 +581,23 @@ module F32 = struct let neg x = Int32.logxor x Int32.min_int - let unop (op : Ty.Unop.t) (v : Value.t) : Value.t = - let f = to_float @@ of_value 1 (`Unop op) v in + let unop (op : [ `Ty_fp ] Ty.Unop.op) (v : Value.t) : Value.t = + let f = to_float @@ of_value 1 (`Unop (U op)) v in match op with - | Neg -> to_value @@ neg @@ of_value 1 (`Unop op) v - | Abs -> to_value @@ abs @@ of_value 1 (`Unop op) v + | Neg -> to_value @@ neg @@ of_value 1 (`Unop (U op)) v + | Abs -> to_value @@ abs @@ of_value 1 (`Unop (U op)) v | Sqrt -> to_value' @@ Float.sqrt f | Nearest -> to_value' @@ Float.round f | Ceil -> to_value' @@ Float.ceil f | Floor -> to_value' @@ Float.floor f | Trunc -> to_value' @@ Float.trunc f - | Is_nan -> if Float.is_nan f then Value.True else Value.False - | _ -> Fmt.failwith {|unop: Unsupported f32 operator "%a"|} Ty.Unop.pp op + | Is_normal -> Bool.to_value @@ is_float FP_normal f + | Is_subnormal -> Bool.to_value @@ is_float FP_subnormal f + | Is_negative -> assert false + | Is_positive -> assert false + | Is_infinite -> Bool.to_value @@ is_float FP_infinite f + | Is_nan -> Bool.to_value @@ Float.is_nan f + | Is_zero -> Bool.to_value @@ is_float FP_zero f (* Stolen from Owi *) let copy_sign x y = Int32.logor (abs x) (Int32.logand y Int32.min_int) @@ -631,9 +645,9 @@ module F64 = struct let of_value (i : int) (op : op_type) (v : Value.t) : int64 = of_arg - (function Num (F64 f) -> f | _ -> raise_notrace (Value (Ty_fp 64))) + (function Num (F64 f) -> f | _ -> raise_notrace (Value (Ty (Ty_fp 64)))) i v op - (err_str i op (Ty_fp 64) (Value.type_of v)) + (err_str i op (Ty (Ty_fp 64)) (Value.type_of v)) [@@inline] let of_value' (i : int) (op : op_type) (v : Value.t) : float = @@ -645,18 +659,23 @@ module F64 = struct let neg x = Int64.logxor x Int64.min_int - let unop (op : Ty.Unop.t) (v : Value.t) : Value.t = - let f = of_value' 1 (`Unop op) v in + let unop (op : [ `Ty_fp ] Ty.Unop.op) (v : Value.t) : Value.t = + let f = of_value' 1 (`Unop (U op)) v in match op with - | Neg -> to_value @@ neg @@ of_value 1 (`Unop op) v - | Abs -> to_value @@ abs @@ of_value 1 (`Unop op) v + | Neg -> to_value @@ neg @@ of_value 1 (`Unop (U op)) v + | Abs -> to_value @@ abs @@ of_value 1 (`Unop (U op)) v | Sqrt -> to_value' @@ Float.sqrt f | Nearest -> to_value' @@ Float.round f | Ceil -> to_value' @@ Float.ceil f | Floor -> to_value' @@ Float.floor f | Trunc -> to_value' @@ Float.trunc f - | Is_nan -> if Float.is_nan f then Value.True else Value.False - | _ -> Fmt.failwith {|unop: Unsupported f32 operator "%a"|} Ty.Unop.pp op + | Is_normal -> Bool.to_value @@ is_float FP_normal f + | Is_subnormal -> Bool.to_value @@ is_float FP_subnormal f + | Is_negative -> assert false + | Is_positive -> assert false + | Is_infinite -> Bool.to_value @@ is_float FP_infinite f + | Is_nan -> Bool.to_value @@ Float.is_nan f + | Is_zero -> Bool.to_value @@ is_float FP_zero f let copy_sign x y = Int64.logor (abs x) (Int64.logand y Int64.min_int) @@ -901,7 +920,7 @@ module I64CvtOp = struct (TypeError { index = 1 ; value = v - ; ty = Ty_bitv 64 + ; ty = Ty (Ty_bitv 64) ; op = `Cvtop WrapI64 ; msg = "Cannot wrapI64 on an I64" } ) @@ -965,7 +984,7 @@ module F32CvtOp = struct (TypeError { index = 1 ; value = v - ; ty = Ty_fp 32 + ; ty = Ty (Ty_fp 32) ; op = `Cvtop PromoteF32 ; msg = "F64 must promote a F32" } ) @@ -1029,7 +1048,7 @@ module F64CvtOp = struct (TypeError { index = 1 ; value = v - ; ty = Ty_bitv 64 + ; ty = Ty (Ty_bitv 64) ; op = `Cvtop DemoteF64 ; msg = "F32 must demote a F64" } ) @@ -1039,35 +1058,52 @@ end (* Dispatch *) -let op int real bool str lst bv f32 f64 ty op = - match ty with - | Ty.Ty_int -> int op - | Ty_real -> real op - | Ty_bool -> bool op - | Ty_str -> str op - | Ty_list -> lst op - | Ty_bitv _ -> bv op - | Ty_fp 32 -> f32 op - | Ty_fp 64 -> f64 op - | Ty_fp _ | Ty_app | Ty_unit | Ty_none | Ty_regexp | Ty_roundingMode -> +let unop : type a. a Ty.ty -> a Ty.Unop.op -> Value.t -> Value.t = + fun ty op -> + match (ty, op) with + | Ty_int, ((Neg | Not | Abs) as op) -> Int.unop op + | Ty_real, ((Neg | Abs | Sqrt | Is_nan | Ceil | Floor | Trunc | Nearest) as op) + -> + Real.unop op + | Ty_bool, Not -> Bool.unop Not + | ( Ty_str + , ( ( Length | Trim | Regexp_star | Regexp_loop _ | Regexp_plus | Regexp_opt + | Regexp_comp ) as op ) ) -> + Str.unop op + | Ty_list, ((Head | Tail | Reverse | Length) as op) -> Lst.unop op + | Ty_bitv _, ((Neg | Not | Clz | Ctz | Popcnt) as op) -> Bitv.unop op + | ( Ty_fp 32 + , ((Neg | Abs | Sqrt | Is_nan | Ceil | Floor | Trunc | Nearest) as op) ) -> + F32.unop op + | ( Ty_fp 64 + , ((Neg | Abs | Sqrt | Is_nan | Ceil | Floor | Trunc | Nearest) as op) ) -> + F64.unop op + | (Ty_fp _ | Ty_app | Ty_unit | Ty_none | Ty_regexp | Ty_roundingMode), _ -> assert false -[@@inline] - -let unop = - op Int.unop Real.unop Bool.unop Str.unop Lst.unop Bitv.unop F32.unop F64.unop -let binop = - op Int.binop Real.binop Bool.binop Str.binop Lst.binop Bitv.binop F32.binop - F64.binop +let binop : type a. a Ty.ty -> Ty.Binop.t -> Value.t -> Value.t -> Value.t = + function + | Ty_int -> Int.binop + | Ty_real -> Real.binop + | Ty_bool -> Bool.binop + | Ty_str -> Str.binop + | Ty_list -> Lst.binop + | Ty_bitv _ -> Bitv.binop + | Ty_fp 32 -> F32.binop + | Ty_fp 64 -> F64.binop + | Ty_fp _ | Ty_app | Ty_unit | Ty_none | Ty_regexp | Ty_roundingMode -> + assert false -let triop = function - | Ty.Ty_bool -> Bool.triop +let triop : type a. + a Ty.ty -> Ty.Triop.t -> Value.t -> Value.t -> Value.t -> Value.t = function + | Ty_bool -> Bool.triop | Ty_str -> Str.triop | Ty_list -> Lst.triop | _ -> assert false -let relop = function - | Ty.Ty_int -> Int.relop +let relop : type a. a Ty.ty -> Ty.Relop.t -> Value.t -> Value.t -> bool = + function + | Ty_int -> Int.relop | Ty_real -> Real.relop | Ty_bool -> Bool.relop | Ty_str -> Str.relop @@ -1076,8 +1112,8 @@ let relop = function | Ty_fp 64 -> F64.relop | _ -> assert false -let cvtop = function - | Ty.Ty_int -> Int.cvtop +let cvtop : type a. a Ty.ty -> Ty.Cvtop.t -> Value.t -> Value.t = function + | Ty_int -> Int.cvtop | Ty_real -> Real.cvtop | Ty_str -> Str.cvtop | Ty_bitv 32 -> I32CvtOp.cvtop @@ -1086,8 +1122,9 @@ let cvtop = function | Ty_fp 64 -> F64CvtOp.cvtop | _ -> assert false -let naryop = function - | Ty.Ty_bool -> Bool.naryop +let naryop : type a. a Ty.ty -> Ty.Naryop.t -> Value.t list -> Value.t = + function + | Ty_bool -> Bool.naryop | Ty_str -> Str.naryop | Ty_list -> Lst.naryop | _ -> assert false diff --git a/src/smtml/eval.mli b/src/smtml/eval.mli index 5abcc732..29e82b9b 100644 --- a/src/smtml/eval.mli +++ b/src/smtml/eval.mli @@ -49,28 +49,28 @@ exception (** [unop ty op v] applies a unary operation [op] on the value [v] of type [ty]. Raises [TypeError] if the value does not match the expected type. *) -val unop : Ty.t -> Ty.Unop.t -> Value.t -> Value.t +val unop : 'a Ty.ty -> 'a Ty.Unop.op -> Value.t -> Value.t (** [binop ty op v1 v2] applies a binary operation [op] on the values [v1] and [v2] of type [ty]. Raises [DivideByZero] if the operation involves division by zero. Raises [TypeError] if the values do not match the expected type. *) -val binop : Ty.t -> Ty.Binop.t -> Value.t -> Value.t -> Value.t +val binop : 'a Ty.ty -> Ty.Binop.t -> Value.t -> Value.t -> Value.t (** [triop ty op v1 v2 v3] applies a ternary operation [op] on the values [v1], [v2], and [v3] of type [ty]. Raises [TypeError] if any value does not match the expected type. *) -val triop : Ty.t -> Ty.Triop.t -> Value.t -> Value.t -> Value.t -> Value.t +val triop : 'a Ty.ty -> Ty.Triop.t -> Value.t -> Value.t -> Value.t -> Value.t (** [relop ty op v1 v2] applies a relational operation [op] on the values [v1] and [v2] of type [ty]. Returns [true] if the relation holds, otherwise [false]. Raises [TypeError] if the values do not match the expected type. *) -val relop : Ty.t -> Ty.Relop.t -> Value.t -> Value.t -> bool +val relop : 'a Ty.ty -> Ty.Relop.t -> Value.t -> Value.t -> bool (** [cvtop ty op v] applies a conversion operation [op] on the value [v] of type [ty]. Raises [TypeError] if the value does not match the expected type. *) -val cvtop : Ty.t -> Ty.Cvtop.t -> Value.t -> Value.t +val cvtop : 'a Ty.ty -> Ty.Cvtop.t -> Value.t -> Value.t (** [naryop ty op vs] applies an n-ary operation [op] on the list of values [vs] of type [ty]. Raises [TypeError] if any value does not match the expected type. *) -val naryop : Ty.t -> Ty.Naryop.t -> Value.t list -> Value.t +val naryop : 'a Ty.ty -> Ty.Naryop.t -> Value.t list -> Value.t diff --git a/src/smtml/expr.ml b/src/smtml/expr.ml index ee292e10..f33b0077 100644 --- a/src/smtml/expr.ml +++ b/src/smtml/expr.ml @@ -13,7 +13,7 @@ and expr = | Symbol of Symbol.t | List of t list | App of Symbol.t * t list - | Unop of Ty.t * Ty.Unop.t * t + | Unop : 'a Ty.ty * 'a Ty.Unop.op * t -> expr | Binop of Ty.t * Ty.Binop.t * t * t | Triop of Ty.t * Ty.Triop.t * t * t * t | Relop of Ty.t * Ty.Relop.t * t * t @@ -39,7 +39,9 @@ module Expr = struct | List l1, List l2 -> list_eq l1 l2 | App (s1, l1), App (s2, l2) -> Symbol.equal s1 s2 && list_eq l1 l2 | Unop (t1, op1, e1), Unop (t2, op2, e2) -> - Ty.equal t1 t2 && Ty.Unop.equal op1 op2 && phys_equal e1 e2 + Ty.equal (Ty t1) (Ty t2) + && Ty.Unop.equal (U op1) (U op2) + && phys_equal e1 e2 | Binop (t1, op1, e1, e3), Binop (t2, op2, e2, e4) -> Ty.equal t1 t2 && Ty.Binop.equal op1 op2 && phys_equal e1 e2 && phys_equal e3 e4 @@ -106,11 +108,11 @@ let symbol s = make (Symbol s) let rec ty (hte : t) : Ty.t = match view hte with | Val x -> Value.type_of x - | Ptr _ -> Ty_bitv 32 + | Ptr _ -> Ty (Ty_bitv 32) | Symbol x -> Symbol.type_of x - | List _ -> Ty_list - | App (sym, _) -> begin match sym.ty with Ty_none -> Ty_app | ty -> ty end - | Unop (ty, _, _) -> ty + | List _ -> Ty Ty_list + | App (sym, _) -> begin match sym.ty with Ty Ty_none -> Ty Ty_app | ty -> ty end + | Unop (ty, _, _) -> Ty ty | Binop (ty, _, _, _) -> ty | Triop (_, Ite, _, hte1, hte2) -> let ty1 = ty hte1 in @@ -120,13 +122,15 @@ let rec ty (hte : t) : Ty.t = | Triop (ty, _, _, _, _) -> ty | Relop (ty, _, _, _) -> ty | Cvtop (_, (Zero_extend m | Sign_extend m), hte) -> ( - match ty hte with Ty_bitv n -> Ty_bitv (n + m) | _ -> assert false ) + match ty hte with + | Ty (Ty_bitv n) -> Ty (Ty_bitv (n + m)) + | _ -> assert false ) | Cvtop (ty, _, _) -> ty | Naryop (ty, _, _) -> ty - | Extract (_, h, l) -> Ty_bitv ((h - l) * 8) + | Extract (_, h, l) -> Ty (Ty_bitv ((h - l) * 8)) | Concat (e1, e2) -> ( match (ty e1, ty e2) with - | Ty_bitv n1, Ty_bitv n2 -> Ty_bitv (n1 + n2) + | Ty (Ty_bitv n1), Ty (Ty_bitv n2) -> Ty (Ty_bitv (n1 + n2)) | t1, t2 -> Fmt.failwith "Invalid concat of (%a) with (%a)" Ty.pp t1 Ty.pp t2 ) | Binder (_, _, e) -> ty e @@ -233,7 +237,7 @@ module Pp = struct (Fmt.list ~sep:Fmt.comma pp) v | Unop (ty, op, e) -> - Fmt.pf fmt "@[(%a.%a@ %a)@]" Ty.pp ty Ty.Unop.pp op pp e + Fmt.pf fmt "@[(%a.%a@ %a)@]" Ty.pp (Ty ty) Ty.Unop.pp (U op) pp e | Binop (ty, op, e1, e2) -> Fmt.pf fmt "@[(%a.%a@ %a@ %a)@]" Ty.pp ty Ty.Binop.pp op pp e1 pp e2 @@ -301,7 +305,9 @@ let forall vars body = binder Forall vars body let exists vars body = binder Exists vars body -let raw_unop ty op hte = make (Unop (ty, op, hte)) [@@inline] +let raw_unop : type a. a Ty.ty -> a Ty.Unop.op -> t -> t = + fun ty op hte -> make (Unop (ty, op, hte)) +[@@inline] let normalize_eq_or_ne op (ty', e1, e2) = let make_relop lhs rhs = Relop (ty', op, lhs, rhs) in @@ -309,15 +315,15 @@ let normalize_eq_or_ne op (ty', e1, e2) = if not (Ty.equal ty1 ty2) then make_relop e1 e2 else begin match ty1 with - | Ty_bitv m -> + | Ty Ty_bitv m -> let binop = make (Binop (ty1, Sub, e1, e2)) in let zero = make (Val (Bitv (Bitvector.make Z.zero m))) in make_relop binop zero - | Ty_int -> + | Ty Ty_int -> let binop = make (Binop (ty1, Sub, e1, e2)) in let zero = make (Val (Int Int.zero)) in make_relop binop zero - | Ty_real -> + | Ty Ty_real -> let binop = make (Binop (ty1, Sub, e1, e2)) in let zero = make (Val (Real 0.)) in make_relop binop zero @@ -341,24 +347,28 @@ let negate_relop (hte : t) : t = in make e -let unop ty op hte = +let unop : type a. a Ty.ty -> a Ty.Unop.op -> t -> t = + fun ty op hte -> match (op, view hte) with | Ty.Unop.(Regexp_loop _ | Regexp_star), _ -> raw_unop ty op hte | _, Val v -> value (Eval.unop ty op v) | Not, Unop (_, Not, hte') -> hte' - | Not, Relop (Ty_fp _, _, _, _) -> raw_unop ty op hte + | Not, Relop (Ty Ty_fp _, _, _, _) -> raw_unop ty op hte | Not, Relop (_, _, _, _) -> negate_relop hte | Neg, Unop (_, Neg, hte') -> hte' - | Trim, Cvtop (Ty_real, ToString, _) -> hte + | Trim, Cvtop (Ty Ty_real, ToString, _) -> hte | Head, List (hd :: _) -> hd | Tail, List (_ :: tl) -> make (List tl) | Reverse, List es -> make (List (List.rev es)) | Length, List es -> value (Int (List.length es)) | _ -> raw_unop ty op hte -let raw_binop ty op hte1 hte2 = make (Binop (ty, op, hte1, hte2)) [@@inline] +let raw_binop : type a. a Ty.ty -> Ty.Binop.t -> t -> t -> t = + fun ty op hte1 hte2 -> make (Binop (Ty ty, op, hte1, hte2)) +[@@inline] -let rec binop ty op hte1 hte2 = +let rec binop : type a. a Ty.ty -> Ty.Binop.t -> t -> t -> t = + fun ty op hte1 hte2 -> match (op, view hte1, view hte2) with | Ty.Binop.(String_in_re | Regexp_range), _, _ -> raw_binop ty op hte1 hte2 | op, Val v1, Val v2 -> value (Eval.binop ty op v1 v2) @@ -384,21 +394,21 @@ let rec binop ty op hte1 hte2 = hte1 | (Add | Or), _, Val (Bitv bv) when Bitvector.eqz bv -> hte1 | (And | Mul), _, Val (Bitv bv) when Bitvector.eqz bv -> hte2 - | Add, Binop (ty, Add, x, { node = Val v1; _ }), Val v2 -> + | Add, Binop (Ty ty, Add, x, { node = Val v1; _ }), Val v2 -> let v = value (Eval.binop ty Add v1 v2) in raw_binop ty Add x v - | Sub, Binop (ty, Sub, x, { node = Val v1; _ }), Val v2 -> + | Sub, Binop (Ty ty, Sub, x, { node = Val v1; _ }), Val v2 -> let v = value (Eval.binop ty Add v1 v2) in raw_binop ty Sub x v | Mul, Val (Bitv bv), _ when Bitvector.eq_one bv -> hte2 | Mul, _, Val (Bitv bv) when Bitvector.eq_one bv -> hte1 - | Mul, Binop (ty, Mul, x, { node = Val v1; _ }), Val v2 -> + | Mul, Binop (Ty ty, Mul, x, { node = Val v1; _ }), Val v2 -> let v = value (Eval.binop ty Mul v1 v2) in raw_binop ty Mul x v - | Add, Val v1, Binop (ty, Add, x, { node = Val v2; _ }) -> + | Add, Val v1, Binop (Ty ty, Add, x, { node = Val v2; _ }) -> let v = value (Eval.binop ty Add v1 v2) in raw_binop ty Add v x - | Mul, Val v1, Binop (ty, Mul, x, { node = Val v2; _ }) -> + | Mul, Val v1, Binop (Ty ty, Mul, x, { node = Val v2; _ }) -> let v = value (Eval.binop ty Mul v1 v2) in raw_binop ty Mul v x | At, List es, Val (Int n) -> @@ -414,9 +424,12 @@ let rec binop ty op hte1 hte2 = | List_append, List l0, List l1 -> make (List (l0 @ l1)) | _ -> raw_binop ty op hte1 hte2 -let raw_triop ty op e1 e2 e3 = make (Triop (ty, op, e1, e2, e3)) [@@inline] +let raw_triop : type a. a Ty.ty -> Ty.Triop.t -> t -> t -> t -> t = + fun ty op e1 e2 e3 -> make (Triop (Ty ty, op, e1, e2, e3)) +[@@inline] -let triop ty op e1 e2 e3 = +let triop : type a. a Ty.ty -> Ty.Triop.t -> t -> t -> t -> t = + fun ty op e1 e2 e3 -> match (op, view e1, view e2, view e3) with | Ty.Triop.Ite, Val True, _, _ -> e2 | Ite, Val False, _, _ -> e3 @@ -427,9 +440,12 @@ let triop ty op e1 e2 e3 = raw_triop ty Ite cond r1 else_ | _ -> raw_triop ty op e1 e2 e3 -let raw_relop ty op hte1 hte2 = make (Relop (ty, op, hte1, hte2)) [@@inline] +let raw_relop : type a. a Ty.ty -> Ty.Relop.t -> t -> t -> t = + fun ty op hte1 hte2 -> make (Relop (Ty ty, op, hte1, hte2)) +[@@inline] -let rec relop ty op hte1 hte2 = +let rec relop : type a. a Ty.ty -> Ty.Relop.t -> t -> t -> t = + fun ty op hte1 hte2 -> match (op, view hte1, view hte2) with | op, Val v1, Val v2 -> value (if Eval.relop ty op v1 v2 then True else False) | Ty.Relop.Ne, Val (Real v), _ | Ne, _, Val (Real v) -> @@ -447,8 +463,8 @@ let rec relop ty op hte1 hte2 = | Ne, Val (App (`Op "symbol", [ Str _ ])), _ -> value True | ( Eq - , Symbol ({ ty = Ty_fp prec1; _ } as s1) - , Symbol ({ ty = Ty_fp prec2; _ } as s2) ) + , Symbol ({ ty = Ty (Ty_fp prec1); _ } as s1) + , Symbol ({ ty = Ty (Ty_fp prec2); _ } as s2) ) when prec1 = prec2 && Symbol.equal s1 s2 -> raw_unop Ty_bool Not (raw_unop (Ty_fp prec1) Is_nan hte1) | Eq, Ptr { base = b1; offset = os1 }, Ptr { base = b2; offset = os2 } -> @@ -491,24 +507,30 @@ and relop_list op l1 l2 = binop Ty_bool And acc @@ match (ty a, ty b) with - | Ty_real, Ty_real -> relop Ty_real Eq a b + | Ty Ty_real, Ty Ty_real -> relop Ty_real Eq a b | _ -> relop Ty_bool Eq a b ) (value True) l1 l2 | Ne, _, _ -> unop Ty_bool Not @@ relop_list Eq l1 l2 | (Lt | LtU | Gt | GtU | Le | LeU | Ge | GeU), _, _ -> assert false -let raw_cvtop ty op hte = make (Cvtop (ty, op, hte)) [@@inline] +let raw_cvtop : type a. a Ty.ty -> Ty.Cvtop.t -> t -> t = + fun ty op hte -> make (Cvtop (Ty ty, op, hte)) +[@@inline] -let cvtop ty op hte = +let cvtop : type a. a Ty.ty -> Ty.Cvtop.t -> t -> t = + fun ty op hte -> match (op, view hte) with | Ty.Cvtop.String_to_re, _ -> raw_cvtop ty op hte | _, Val v -> value (Eval.cvtop ty op v) - | String_to_float, Cvtop (Ty_real, ToString, real) -> real + | String_to_float, Cvtop (Ty Ty_real, ToString, real) -> real | _ -> raw_cvtop ty op hte -let raw_naryop ty op es = make (Naryop (ty, op, es)) [@@inline] +let raw_naryop : type a. a Ty.ty -> Ty.Naryop.t -> t list -> t = + fun ty op es -> make (Naryop (Ty ty, op, es)) +[@@inline] -let naryop ty op es = +let naryop : type a. a Ty.ty -> Ty.Naryop.t -> t list -> t = + fun ty op es -> if List.for_all (fun e -> match view e with Val _ -> true | _ -> false) es then let vs = @@ -519,11 +541,11 @@ let naryop ty op es = match (ty, op, List.map view es) with | ( Ty_str , Concat - , [ Naryop (Ty_str, Concat, l1); Naryop (Ty_str, Concat, l2) ] ) -> + , [ Naryop (Ty Ty_str, Concat, l1); Naryop (Ty Ty_str, Concat, l2) ] ) -> raw_naryop Ty_str Concat (l1 @ l2) - | Ty_str, Concat, [ Naryop (Ty_str, Concat, htes); hte ] -> + | Ty_str, Concat, [ Naryop (Ty Ty_str, Concat, htes); hte ] -> raw_naryop Ty_str Concat (htes @ [ make hte ]) - | Ty_str, Concat, [ hte; Naryop (Ty_str, Concat, htes) ] -> + | Ty_str, Concat, [ hte; Naryop (Ty Ty_str, Concat, htes) ] -> raw_naryop Ty_str Concat (make hte :: htes) | _ -> raw_naryop ty op es @@ -539,7 +561,7 @@ let extract (hte : t) ~(high : int) ~(low : int) : t = | ( Cvtop ( _ , (Zero_extend 24 | Sign_extend 24) - , ({ node = Symbol { ty = Ty_bitv 8; _ }; _ } as sym) ) + , ({ node = Symbol { ty = Ty (Ty_bitv 8); _ }; _ } as sym) ) , 1 , 0 ) -> sym @@ -574,23 +596,23 @@ let rec simplify_expr ?(in_relop = false) (hte : t) : t = | Unop (ty, op, e) -> let e = simplify_expr ~in_relop e in unop ty op e - | Binop (ty, op, e1, e2) -> + | Binop (Ty ty, op, e1, e2) -> let e1 = simplify_expr ~in_relop e1 in let e2 = simplify_expr ~in_relop e2 in binop ty op e1 e2 - | Relop (ty, op, e1, e2) -> + | Relop (Ty ty, op, e1, e2) -> let e1 = simplify_expr ~in_relop:true e1 in let e2 = simplify_expr ~in_relop:true e2 in relop ty op e1 e2 - | Triop (ty, op, c, e1, e2) -> + | Triop (Ty ty, op, c, e1, e2) -> let c = simplify_expr ~in_relop c in let e1 = simplify_expr ~in_relop e1 in let e2 = simplify_expr ~in_relop e2 in triop ty op c e1 e2 - | Cvtop (ty, op, e) -> + | Cvtop (Ty ty, op, e) -> let e = simplify_expr ~in_relop e in cvtop ty op e - | Naryop (ty, op, es) -> + | Naryop (Ty ty, op, es) -> let es = List.map (simplify_expr ~in_relop) es in naryop ty op es | Extract (s, high, low) -> @@ -684,7 +706,7 @@ end module Make (T : sig type elt - val ty : Ty.t + val ty : [> `Ty_bitv | `Ty_fp ] Ty.ty val value : elt -> Value.t end) = @@ -693,7 +715,7 @@ struct let v i = value (T.value i) - let sym x = symbol Symbol.(x @: T.ty) + let sym x = symbol Symbol.(x @: Ty T.ty) let ( ~- ) e = unop T.ty Neg e diff --git a/src/smtml/expr.mli b/src/smtml/expr.mli index cb5b8fa0..8087dcea 100644 --- a/src/smtml/expr.mli +++ b/src/smtml/expr.mli @@ -23,7 +23,7 @@ and expr = private | Symbol of Symbol.t (** A symbolic variable. *) | List of t list (** A list of expressions. *) | App of Symbol.t * t list (** Function application. *) - | Unop of Ty.t * Ty.Unop.t * t (** Unary operation. *) + | Unop : 'a Ty.ty * 'a Ty.Unop.op * t -> expr (** Unary operation. *) | Binop of Ty.t * Ty.Binop.t * t * t (** Binary operation. *) | Triop of Ty.t * Ty.Triop.t * t * t * t (** Ternary operation. *) | Relop of Ty.t * Ty.Relop.t * t * t (** Relational operation. *) @@ -122,24 +122,24 @@ val exists : t list -> t -> t (** These constructors apply simplifications during construction. *) (** [unop ty op expr] applies a unary operation with simplification. *) -val unop : Ty.t -> Ty.Unop.t -> t -> t +val unop : 'a Ty.ty -> 'a Ty.Unop.op -> t -> t (** [binop ty op expr1 expr2] applies a binary operation with simplification. *) -val binop : Ty.t -> Ty.Binop.t -> t -> t -> t +val binop : 'a Ty.ty -> Ty.Binop.t -> t -> t -> t (** [triop ty op expr1 expr2 expr3] applies a ternary operation with simplification. *) -val triop : Ty.t -> Ty.Triop.t -> t -> t -> t -> t +val triop : 'a Ty.ty -> Ty.Triop.t -> t -> t -> t -> t (** [relop ty op expr1 expr2] applies a relational operation with simplification. *) -val relop : Ty.t -> Ty.Relop.t -> t -> t -> t +val relop : 'a Ty.ty -> Ty.Relop.t -> t -> t -> t (** [cvtop ty op expr] applies a conversion operation with simplification. *) -val cvtop : Ty.t -> Ty.Cvtop.t -> t -> t +val cvtop : 'a Ty.ty -> Ty.Cvtop.t -> t -> t (** [naryop ty op exprs] applies an N-ary operation with simplification. *) -val naryop : Ty.t -> Ty.Naryop.t -> t list -> t +val naryop : 'a Ty.ty -> Ty.Naryop.t -> t list -> t (** [extract expr ~high ~low] extracts a bit range with simplification. *) val extract : t -> high:int -> low:int -> t @@ -173,7 +173,7 @@ val concat : t -> t -> t ]} which would typically be the result of the smart constructor [unop]. *) -val raw_unop : Ty.t -> Ty.Unop.t -> t -> t +val raw_unop : 'a Ty.ty -> 'a Ty.Unop.op -> t -> t (** [raw_binop ty op expr1 expr2] applies a binary operation, creating a node without immediate simplification. @@ -200,7 +200,7 @@ val raw_unop : Ty.t -> Ty.Unop.t -> t -> t ]} which would typically be the result of the smart constructor [binop]. *) -val raw_binop : Ty.t -> Ty.Binop.t -> t -> t -> t +val raw_binop : 'a Ty.ty -> Ty.Binop.t -> t -> t -> t (** [raw_triop ty op expr1 expr2 expr3] applies a ternary operation, creating a node without immediate simplification. @@ -227,7 +227,7 @@ val raw_binop : Ty.t -> Ty.Binop.t -> t -> t -> t ]} which would typically be the result of the smart constructor [triop]. *) -val raw_triop : Ty.t -> Ty.Triop.t -> t -> t -> t -> t +val raw_triop : 'a Ty.ty -> Ty.Triop.t -> t -> t -> t -> t (** [raw_relop ty op expr1 expr2] applies a relational operation, creating a node without immediate simplification. @@ -254,7 +254,7 @@ val raw_triop : Ty.t -> Ty.Triop.t -> t -> t -> t -> t ]} which would typically be the result of the smart constructor [relop]. *) -val raw_relop : Ty.t -> Ty.Relop.t -> t -> t -> t +val raw_relop : 'a Ty.ty -> Ty.Relop.t -> t -> t -> t (** [raw_cvtop ty op expr] applies a conversion operation, creating a node without immediate simplification. @@ -281,11 +281,11 @@ val raw_relop : Ty.t -> Ty.Relop.t -> t -> t -> t ]} which would typically be the result of the smart constructor [cvtop]. *) -val raw_cvtop : Ty.t -> Ty.Cvtop.t -> t -> t +val raw_cvtop : 'a Ty.ty -> Ty.Cvtop.t -> t -> t (** [raw_naryop ty op exprs] applies an N-ary operation without simplification. *) -val raw_naryop : Ty.t -> Ty.Naryop.t -> t list -> t +val raw_naryop : 'a Ty.ty -> Ty.Naryop.t -> t list -> t (** [raw_extract expr ~high ~low] extracts a bit range without simplification. *) diff --git a/src/smtml/expr_raw.ml b/src/smtml/expr_raw.ml index 26bec65b..e4441748 100644 --- a/src/smtml/expr_raw.ml +++ b/src/smtml/expr_raw.ml @@ -12,7 +12,7 @@ include ( | Symbol of Symbol.t | List of t list | App of Symbol.t * t list - | Unop of Ty.t * Ty.Unop.t * t + | Unop : 'a Ty.ty * 'a Ty.Unop.op * t -> expr | Binop of Ty.t * Ty.Binop.t * t * t | Triop of Ty.t * Ty.Triop.t * t * t * t | Relop of Ty.t * Ty.Relop.t * t * t @@ -64,17 +64,17 @@ include ( val exists : t list -> t -> t - val raw_unop : Ty.t -> Ty.Unop.t -> t -> t + val raw_unop : 'a Ty.ty -> 'a Ty.Unop.op -> t -> t - val raw_binop : Ty.t -> Ty.Binop.t -> t -> t -> t + val raw_binop : 'a Ty.ty -> Ty.Binop.t -> t -> t -> t - val raw_triop : Ty.t -> Ty.Triop.t -> t -> t -> t -> t + val raw_triop : 'a Ty.ty -> Ty.Triop.t -> t -> t -> t -> t - val raw_relop : Ty.t -> Ty.Relop.t -> t -> t -> t + val raw_relop : 'a Ty.ty -> Ty.Relop.t -> t -> t -> t - val raw_cvtop : Ty.t -> Ty.Cvtop.t -> t -> t + val raw_cvtop : 'a Ty.ty -> Ty.Cvtop.t -> t -> t - val raw_naryop : Ty.t -> Ty.Naryop.t -> t list -> t + val raw_naryop : 'a Ty.ty -> Ty.Naryop.t -> t list -> t val raw_extract : t -> high:int -> low:int -> t diff --git a/src/smtml/expr_raw.mli b/src/smtml/expr_raw.mli index 00e92fac..4a8c2436 100644 --- a/src/smtml/expr_raw.mli +++ b/src/smtml/expr_raw.mli @@ -19,7 +19,7 @@ and expr = private | Symbol of Symbol.t (** A symbolic variable. *) | List of t list (** A list of expressions. *) | App of Symbol.t * t list (** Function application. *) - | Unop of Ty.t * Ty.Unop.t * t (** Unary operation. *) + | Unop : 'a Ty.ty * 'a Ty.Unop.op * t -> expr (** Unary operation. *) | Binop of Ty.t * Ty.Binop.t * t * t (** Binary operation. *) | Triop of Ty.t * Ty.Triop.t * t * t * t (** Ternary operation. *) | Relop of Ty.t * Ty.Relop.t * t * t (** Relational operation. *) @@ -118,22 +118,22 @@ val exists : t list -> t -> t (** These constructors do NOT apply simplifications during construction. *) (** [unop ty op expr] applies a raw unary. *) -val unop : Ty.t -> Ty.Unop.t -> t -> t +val unop : 'a Ty.ty -> 'a Ty.Unop.op -> t -> t (** [binop ty op expr1 expr2] applies a raw binary. *) -val binop : Ty.t -> Ty.Binop.t -> t -> t -> t +val binop : 'a Ty.ty -> Ty.Binop.t -> t -> t -> t (** [triop ty op expr1 expr2 expr3] applies a raw ternary operation. *) -val triop : Ty.t -> Ty.Triop.t -> t -> t -> t -> t +val triop : 'a Ty.ty -> Ty.Triop.t -> t -> t -> t -> t (** [relop ty op expr1 expr2] applies a raw relational operation. *) -val relop : Ty.t -> Ty.Relop.t -> t -> t -> t +val relop : 'a Ty.ty -> Ty.Relop.t -> t -> t -> t (** [cvtop ty op expr] applies a raw conversion. *) -val cvtop : Ty.t -> Ty.Cvtop.t -> t -> t +val cvtop : 'a Ty.ty -> Ty.Cvtop.t -> t -> t (** [naryop ty op exprs] applies a raw N-ary operation. *) -val naryop : Ty.t -> Ty.Naryop.t -> t list -> t +val naryop : 'a Ty.ty -> Ty.Naryop.t -> t list -> t (** [extract expr ~high ~low] extracts a bit range. *) val extract : t -> high:int -> low:int -> t diff --git a/src/smtml/lexer.mll b/src/smtml/lexer.mll index 8dee6f9c..f89c9905 100644 --- a/src/smtml/lexer.mll +++ b/src/smtml/lexer.mll @@ -13,201 +13,201 @@ let keywords = let tbl = Hashtbl.create 256 in Array.iter (fun (k, v) -> Hashtbl.add tbl k v) - [| ("int" , TYPE (Ty_int)) - ; ("real", TYPE (Ty_real)) - ; ("bool", TYPE (Ty_bool)) - ; ("str" , TYPE (Ty_str)) - ; ("i32" , TYPE (Ty_bitv 32)) - ; ("i64" , TYPE (Ty_bitv 64)) - ; ("f32" , TYPE (Ty_fp 32)) - ; ("f64" , TYPE (Ty_fp 64)) - ; ("not", UNARY (Ty_bool, Not)) - ; ("bool.not", UNARY (Ty_bool, Not)) (* To deprecate *) - ; ("and", BINARY (Ty_bool, And)) - ; ("bool.and", BINARY (Ty_bool, And)) (* To deprecate *) - ; ("or", BINARY (Ty_bool, Or)) - ; ("bool.or", BINARY (Ty_bool, Or)) (* To deprecate *) - ; ("xor", BINARY (Ty_bool, Xor)) - ; ("bool.xor", BINARY (Ty_bool, Xor)) - ; ("=", RELOP (Ty_bool, Eq)) - ; ("bool.eq", RELOP (Ty_bool, Eq)) (* To deprecate *) - ; ("distinct", RELOP (Ty_bool, Ne)) - ; ("bool.ne", RELOP (Ty_bool, Ne)) (* To deprecate *) - ; ("ite", TERNARY (Ty_bool, Ite)) - ; ("bool.ite", TERNARY (Ty_bool, Ite)) (* To deprecate *) - ; ("int.neg", UNARY (Ty_int, Neg)) - ; ("int.add", BINARY (Ty_int, Add)) - ; ("int.sub", BINARY (Ty_int, Sub)) - ; ("int.div", BINARY (Ty_int, Div)) - ; ("int.mul", BINARY (Ty_int, Mul)) - ; ("int.rem", BINARY (Ty_int, Rem)) - ; ("int.pow", BINARY (Ty_int, Pow)) - ; ("int.eq", RELOP (Ty_bool, Eq)) - ; ("int.ne", RELOP (Ty_bool, Ne)) - ; ("int.lt", RELOP (Ty_int, Lt)) - ; ("int.le", RELOP (Ty_int, Le)) - ; ("int.gt", RELOP (Ty_int, Gt)) - ; ("int.ge", RELOP (Ty_int, Ge)) - ; ("int.to_string", CVTOP (Ty_int, ToString)) - ; ("int.of_string", CVTOP (Ty_int, OfString)) - ; ("int.reinterpret_real", CVTOP (Ty_int, Reinterpret_float)) - ; ("real.neg", UNARY (Ty_real, Neg)) - ; ("real.abs", UNARY (Ty_real, Abs)) - ; ("real.sqrt", UNARY (Ty_real, Sqrt)) - ; ("real.nearest", UNARY (Ty_real, Nearest)) - ; ("real.is_nan", UNARY (Ty_real, Is_nan)) - ; ("real.add", BINARY (Ty_real, Add)) - ; ("real.sub", BINARY (Ty_real, Sub)) - ; ("real.div", BINARY (Ty_real, Div)) - ; ("real.mul", BINARY (Ty_real, Mul)) - ; ("real.rem", BINARY (Ty_real, Rem)) - ; ("real.min", BINARY (Ty_real, Min)) - ; ("real.max", BINARY (Ty_real, Max)) - ; ("real.eq", RELOP (Ty_bool, Eq)) - ; ("real.ne", RELOP (Ty_bool, Ne)) - ; ("real.lt", RELOP (Ty_real, Lt)) - ; ("real.le", RELOP (Ty_real, Le)) - ; ("real.gt", RELOP (Ty_real, Gt)) - ; ("real.ge", RELOP (Ty_real, Ge)) - ; ("real.reinterpret_int", CVTOP (Ty_real, Reinterpret_int)) - ; ("real.to_string", CVTOP (Ty_real, ToString)) - ; ("real.of_string", CVTOP (Ty_real, OfString)) - ; ("str.len", UNARY (Ty_str, Length)) - ; ("str.at", BINARY (Ty_str, At)) - ; ("str.++", NARY (Ty_str, Concat)) - ; ("str.prefixof", BINARY (Ty_str, String_prefix)) - ; ("str.suffixof", BINARY (Ty_str, String_suffix)) - ; ("str.contains", BINARY (Ty_str, String_contains)) - ; ("str.substr", TERNARY (Ty_str, String_extract)) - ; ("str.replace", TERNARY (Ty_str, String_replace)) - ; ("str.indexof", TERNARY (Ty_str, String_index)) - ; ("str.to_code", CVTOP (Ty_str, String_to_code)) - ; ("str.from_code", CVTOP (Ty_str, String_from_code)) - ; ("str.to_int", CVTOP (Ty_str, String_to_int)) - ; ("str.from_int", CVTOP (Ty_str, String_from_int)) - ; ("i32.neg", UNARY (Ty_bitv 32, Neg)) - ; ("i32.clz", UNARY (Ty_bitv 32, Clz)) - ; ("i32.not", UNARY (Ty_bitv 32, Not)) - ; ("i32.add", BINARY (Ty_bitv 32, Add)) - ; ("i32.sub", BINARY (Ty_bitv 32, Sub)) - ; ("i32.div", BINARY (Ty_bitv 32, Div)) - ; ("i32.div_u", BINARY (Ty_bitv 32, DivU)) - ; ("i32.and", BINARY (Ty_bitv 32, And)) - ; ("i32.or", BINARY (Ty_bitv 32, Or)) - ; ("i32.xor", BINARY (Ty_bitv 32, Xor)) - ; ("i32.mul", BINARY (Ty_bitv 32, Mul)) - ; ("i32.shl", BINARY (Ty_bitv 32, Shl)) - ; ("i32.shr", BINARY (Ty_bitv 32, ShrA)) - ; ("i32.shr_u", BINARY (Ty_bitv 32, ShrL)) - ; ("i32.rem", BINARY (Ty_bitv 32, Rem)) - ; ("i32.rem_u", BINARY (Ty_bitv 32, RemU)) - ; ("i32.eq", RELOP (Ty_bool, Eq)) - ; ("i32.ne", RELOP (Ty_bool, Ne)) - ; ("i32.lt_u", RELOP (Ty_bitv 32, LtU)) - ; ("i32.lt", RELOP (Ty_bitv 32, Lt)) - ; ("i32.le_u", RELOP (Ty_bitv 32, LeU)) - ; ("i32.le", RELOP (Ty_bitv 32, Le)) - ; ("i32.gt_u", RELOP (Ty_bitv 32, GtU)) - ; ("i32.gt", RELOP (Ty_bitv 32, Gt)) - ; ("i32.ge_u", RELOP (Ty_bitv 32, GeU)) - ; ("i32.ge", RELOP (Ty_bitv 32, Ge)) - ; ("i32.to_bool", CVTOP (Ty_bitv 32, ToBool)) - ; ("i32.of_bool", CVTOP (Ty_bitv 32, OfBool)) - ; ("i32.trunc_f32_s", CVTOP (Ty_bitv 32, TruncSF32)) - ; ("i32.trunc_f32_u", CVTOP (Ty_bitv 32, TruncUF32)) - ; ("i32.trunc_f64_s", CVTOP (Ty_bitv 32, TruncSF64)) - ; ("i32.trunc_f64_u", CVTOP (Ty_bitv 32, TruncUF64)) - ; ("i32.reinterpret_float", CVTOP (Ty_bitv 32, Reinterpret_float)) - ; ("i32.wrap_i64", CVTOP (Ty_bitv 32, WrapI64)) - ; ("i32.extend_i16_s", CVTOP (Ty_bitv 32, Sign_extend 16)) - ; ("i32.extend_i16_u", CVTOP (Ty_bitv 32, Zero_extend 16)) - ; ("i32.extend_i24_s", CVTOP (Ty_bitv 32, Sign_extend 24)) - ; ("i32.extend_i24_u", CVTOP (Ty_bitv 32, Zero_extend 24)) - ; ("i64.neg", UNARY (Ty_bitv 64, Neg)) - ; ("i64.clz", UNARY (Ty_bitv 64, Clz)) - ; ("i64.not", UNARY (Ty_bitv 64, Not)) - ; ("i64.add", BINARY (Ty_bitv 64, Add)) - ; ("i64.sub", BINARY (Ty_bitv 64, Sub)) - ; ("i64.div", BINARY (Ty_bitv 64, Div)) - ; ("i64.div_u", BINARY (Ty_bitv 64, DivU)) - ; ("i64.and", BINARY (Ty_bitv 64, And)) - ; ("i64.or", BINARY (Ty_bitv 64, Or)) - ; ("i64.xor", BINARY (Ty_bitv 64, Xor)) - ; ("i64.mul", BINARY (Ty_bitv 64, Mul)) - ; ("i64.shl", BINARY (Ty_bitv 64, Shl)) - ; ("i64.shr", BINARY (Ty_bitv 64, ShrA)) - ; ("i64.shr_u", BINARY (Ty_bitv 64, ShrL)) - ; ("i64.rem", BINARY (Ty_bitv 64, Rem)) - ; ("i64.rem_u", BINARY (Ty_bitv 64, RemU)) - ; ("i64.eq", RELOP (Ty_bool, Eq)) - ; ("i64.ne", RELOP (Ty_bool, Ne)) - ; ("i64.lt_u", RELOP (Ty_bitv 64, LtU)) - ; ("i64.lt", RELOP (Ty_bitv 64, Lt)) - ; ("i64.le_u", RELOP (Ty_bitv 64, LeU)) - ; ("i64.le", RELOP (Ty_bitv 64, Le)) - ; ("i64.gt_u", RELOP (Ty_bitv 64, GtU)) - ; ("i64.gt", RELOP (Ty_bitv 64, Gt)) - ; ("i64.ge_u", RELOP (Ty_bitv 64, GeU)) - ; ("i64.ge", RELOP (Ty_bitv 64, Ge)) - ; ("i64.trunc_f32_s", CVTOP (Ty_bitv 64, TruncSF32)) - ; ("i64.trunc_f32_u", CVTOP (Ty_bitv 64, TruncUF32)) - ; ("i64.trunc_f64_s", CVTOP (Ty_bitv 64, TruncSF64)) - ; ("i64.trunc_f64_u", CVTOP (Ty_bitv 64, TruncUF64)) - ; ("i64.reinterpret_float", CVTOP (Ty_bitv 64, Reinterpret_float)) - ; ("i64.extend_i32_s", CVTOP (Ty_bitv 64, Sign_extend 32)) - ; ("i64.extend_i32_u", CVTOP (Ty_bitv 64, Zero_extend 32)) - ; ("f32.neg", UNARY (Ty_fp 32, Neg)) - ; ("f32.abs", UNARY (Ty_fp 32, Abs)) - ; ("f32.sqrt", UNARY (Ty_fp 32, Sqrt)) - ; ("f32.nearest",UNARY (Ty_fp 32, Nearest) ) - ; ("f32.is_nan", UNARY (Ty_fp 32, Is_nan)) - ; ("f32.ceil", UNARY (Ty_fp 32, Ceil)) - ; ("f32.floor", UNARY (Ty_fp 32, Floor)) - ; ("f32.trunc", UNARY (Ty_fp 32, Trunc)) - ; ("f32.add", BINARY (Ty_fp 32, Add)) - ; ("f32.sub", BINARY (Ty_fp 32, Sub)) - ; ("f32.mul", BINARY (Ty_fp 32, Mul)) - ; ("f32.div", BINARY (Ty_fp 32, Div)) - ; ("f32.min", BINARY (Ty_fp 32, Min)) - ; ("f32.max", BINARY (Ty_fp 32, Max)) - ; ("f32.rem", BINARY (Ty_fp 32, Rem)) - ; ("f32.eq", RELOP (Ty_fp 32, Eq)) - ; ("f32.ne", RELOP (Ty_fp 32, Ne)) - ; ("f32.lt", RELOP (Ty_fp 32, Lt)) - ; ("f32.le", RELOP (Ty_fp 32, Le)) - ; ("f32.gt", RELOP (Ty_fp 32, Gt)) - ; ("f32.ge", RELOP (Ty_fp 32, Ge)) - ; ("f32.convert_i32_s", CVTOP (Ty_fp 32, ConvertSI32)) - ; ("f32.convert_i32_u", CVTOP (Ty_fp 32, ConvertUI32)) - ; ("f32.convert_i64_s", CVTOP (Ty_fp 32, ConvertSI32)) - ; ("f32.demote_f64", CVTOP (Ty_fp 32, DemoteF64)) - ; ("f32.reinterpret_int", CVTOP (Ty_fp 32, Reinterpret_int)) - ; ("f64.neg", UNARY (Ty_fp 64, Neg)) - ; ("f64.abs", UNARY (Ty_fp 64, Abs)) - ; ("f64.sqrt", UNARY (Ty_fp 64, Sqrt)) - ; ("f64.nearest",UNARY (Ty_fp 64, Nearest) ) - ; ("f64.is_nan", UNARY (Ty_fp 64, Is_nan)) - ; ("f64.ceil", UNARY (Ty_fp 32, Ceil)) - ; ("f64.floor", UNARY (Ty_fp 32, Floor)) - ; ("f64.trunc", UNARY (Ty_fp 32, Trunc)) - ; ("f64.add", BINARY (Ty_fp 64, Add)) - ; ("f64.sub", BINARY (Ty_fp 64, Sub)) - ; ("f64.mul", BINARY (Ty_fp 64, Mul)) - ; ("f64.div", BINARY (Ty_fp 64, Div)) - ; ("f64.min", BINARY (Ty_fp 64, Min)) - ; ("f64.max", BINARY (Ty_fp 64, Max)) - ; ("f64.rem", BINARY (Ty_fp 64, Rem)) - ; ("f64.eq", RELOP (Ty_fp 64, Eq)) - ; ("f64.ne", RELOP (Ty_fp 64, Ne)) - ; ("f64.lt", RELOP (Ty_fp 64, Lt)) - ; ("f64.le", RELOP (Ty_fp 64, Le)) - ; ("f64.gt", RELOP (Ty_fp 64, Gt)) - ; ("f64.ge", RELOP (Ty_fp 64, Ge)) - ; ("f64.convert_i32_s", CVTOP (Ty_fp 64, ConvertSI32)) - ; ("f64.convert_i32_u", CVTOP (Ty_fp 64, ConvertUI32)) - ; ("f64.convert_i64_s", CVTOP (Ty_fp 64, ConvertSI32)) - ; ("f64.promote_f32", CVTOP (Ty_fp 64, PromoteF32)) - ; ("f64.reinterpret_int", CVTOP (Ty_fp 64, Reinterpret_int)) + [| ("int" , TYPE (Ty Ty_int)) + ; ("real", TYPE (Ty Ty_real)) + ; ("bool", TYPE (Ty Ty_bool)) + ; ("str" , TYPE (Ty Ty_str)) + ; ("i32" , TYPE (Ty (Ty_bitv 32))) + ; ("i64" , TYPE (Ty (Ty_bitv 64))) + ; ("f32" , TYPE (Ty (Ty_fp 32))) + ; ("f64" , TYPE (Ty (Ty_fp 64))) + ; ("not", UNARY (Utils_parse.U (Ty_bool, Not))) + ; ("bool.not", UNARY (Utils_parse.U (Ty_bool, Not))) (* To deprecate *) + ; ("and", BINARY (Ty Ty_bool, And)) + ; ("bool.and", BINARY (Ty Ty_bool, And)) (* To deprecate *) + ; ("or", BINARY (Ty Ty_bool, Or)) + ; ("bool.or", BINARY (Ty Ty_bool, Or)) (* To deprecate *) + ; ("xor", BINARY (Ty Ty_bool, Xor)) + ; ("bool.xor", BINARY (Ty Ty_bool, Xor)) + ; ("=", RELOP (Ty Ty_bool, Eq)) + ; ("bool.eq", RELOP (Ty Ty_bool, Eq)) (* To deprecate *) + ; ("distinct", RELOP (Ty Ty_bool, Ne)) + ; ("bool.ne", RELOP (Ty Ty_bool, Ne)) (* To deprecate *) + ; ("ite", TERNARY (Ty Ty_bool, Ite)) + ; ("bool.ite", TERNARY (Ty Ty_bool, Ite)) (* To deprecate *) + ; ("int.neg", UNARY (Utils_parse.U (Ty_int, Neg))) + ; ("int.add", BINARY (Ty Ty_int, Add)) + ; ("int.sub", BINARY (Ty Ty_int, Sub)) + ; ("int.div", BINARY (Ty Ty_int, Div)) + ; ("int.mul", BINARY (Ty Ty_int, Mul)) + ; ("int.rem", BINARY (Ty Ty_int, Rem)) + ; ("int.pow", BINARY (Ty Ty_int, Pow)) + ; ("int.eq", RELOP (Ty Ty_bool, Eq)) + ; ("int.ne", RELOP (Ty Ty_bool, Ne)) + ; ("int.lt", RELOP (Ty Ty_int, Lt)) + ; ("int.le", RELOP (Ty Ty_int, Le)) + ; ("int.gt", RELOP (Ty Ty_int, Gt)) + ; ("int.ge", RELOP (Ty Ty_int, Ge)) + ; ("int.to_string", CVTOP (Ty Ty_int, ToString)) + ; ("int.of_string", CVTOP (Ty Ty_int, OfString)) + ; ("int.reinterpret_real", CVTOP (Ty Ty_int, Reinterpret_float)) + ; ("real.neg", UNARY (Utils_parse.U (Ty_real, Neg))) + ; ("real.abs", UNARY (Utils_parse.U (Ty_real, Abs))) + ; ("real.sqrt", UNARY (Utils_parse.U (Ty_real, Sqrt))) + ; ("real.nearest", UNARY (Utils_parse.U (Ty_real, Nearest))) + ; ("real.is_nan", UNARY (Utils_parse.U (Ty_real, Is_nan))) + ; ("real.add", BINARY (Ty Ty_real, Add)) + ; ("real.sub", BINARY (Ty Ty_real, Sub)) + ; ("real.div", BINARY (Ty Ty_real, Div)) + ; ("real.mul", BINARY (Ty Ty_real, Mul)) + ; ("real.rem", BINARY (Ty Ty_real, Rem)) + ; ("real.min", BINARY (Ty Ty_real, Min)) + ; ("real.max", BINARY (Ty Ty_real, Max)) + ; ("real.eq", RELOP (Ty Ty_bool, Eq)) + ; ("real.ne", RELOP (Ty Ty_bool, Ne)) + ; ("real.lt", RELOP (Ty Ty_real, Lt)) + ; ("real.le", RELOP (Ty Ty_real, Le)) + ; ("real.gt", RELOP (Ty Ty_real, Gt)) + ; ("real.ge", RELOP (Ty Ty_real, Ge)) + ; ("real.reinterpret_int", CVTOP (Ty Ty_real, Reinterpret_int)) + ; ("real.to_string", CVTOP (Ty Ty_real, ToString)) + ; ("real.of_string", CVTOP (Ty Ty_real, OfString)) + ; ("str.len", UNARY (Utils_parse.U (Ty_str, Length))) + ; ("str.at", BINARY (Ty Ty_str, At)) + ; ("str.++", NARY (Ty Ty_str, Concat)) + ; ("str.prefixof", BINARY (Ty Ty_str, String_prefix)) + ; ("str.suffixof", BINARY (Ty Ty_str, String_suffix)) + ; ("str.contains", BINARY (Ty Ty_str, String_contains)) + ; ("str.substr", TERNARY (Ty Ty_str, String_extract)) + ; ("str.replace", TERNARY (Ty Ty_str, String_replace)) + ; ("str.indexof", TERNARY (Ty Ty_str, String_index)) + ; ("str.to_code", CVTOP (Ty Ty_str, String_to_code)) + ; ("str.from_code", CVTOP (Ty Ty_str, String_from_code)) + ; ("str.to_int", CVTOP (Ty Ty_str, String_to_int)) + ; ("str.from_int", CVTOP (Ty Ty_str, String_from_int)) + ; ("i32.neg", UNARY (Utils_parse.U (Ty_bitv 32, Neg))) + ; ("i32.clz", UNARY (Utils_parse.U (Ty_bitv 32, Clz))) + ; ("i32.not", UNARY (Utils_parse.U (Ty_bitv 32, Not))) + ; ("i32.add", BINARY (Ty (Ty_bitv 32), Add)) + ; ("i32.sub", BINARY (Ty (Ty_bitv 32), Sub)) + ; ("i32.div", BINARY (Ty (Ty_bitv 32), Div)) + ; ("i32.div_u", BINARY (Ty (Ty_bitv 32), DivU)) + ; ("i32.and", BINARY (Ty (Ty_bitv 32), And)) + ; ("i32.or", BINARY (Ty (Ty_bitv 32), Or)) + ; ("i32.xor", BINARY (Ty (Ty_bitv 32), Xor)) + ; ("i32.mul", BINARY (Ty (Ty_bitv 32), Mul)) + ; ("i32.shl", BINARY (Ty (Ty_bitv 32), Shl)) + ; ("i32.shr", BINARY (Ty (Ty_bitv 32), ShrA)) + ; ("i32.shr_u", BINARY (Ty (Ty_bitv 32), ShrL)) + ; ("i32.rem", BINARY (Ty (Ty_bitv 32), Rem)) + ; ("i32.rem_u", BINARY (Ty (Ty_bitv 32), RemU)) + ; ("i32.eq", RELOP (Ty Ty_bool, Eq)) + ; ("i32.ne", RELOP (Ty Ty_bool, Ne)) + ; ("i32.lt_u", RELOP (Ty (Ty_bitv 32), LtU)) + ; ("i32.lt", RELOP (Ty (Ty_bitv 32), Lt)) + ; ("i32.le_u", RELOP (Ty (Ty_bitv 32), LeU)) + ; ("i32.le", RELOP (Ty (Ty_bitv 32), Le)) + ; ("i32.gt_u", RELOP (Ty (Ty_bitv 32), GtU)) + ; ("i32.gt", RELOP (Ty (Ty_bitv 32), Gt)) + ; ("i32.ge_u", RELOP (Ty (Ty_bitv 32), GeU)) + ; ("i32.ge", RELOP (Ty (Ty_bitv 32), Ge)) + ; ("i32.to_bool", CVTOP (Ty (Ty_bitv 32), ToBool)) + ; ("i32.of_bool", CVTOP (Ty (Ty_bitv 32), OfBool)) + ; ("i32.trunc_f32_s", CVTOP (Ty (Ty_bitv 32), TruncSF32)) + ; ("i32.trunc_f32_u", CVTOP (Ty (Ty_bitv 32), TruncUF32)) + ; ("i32.trunc_f64_s", CVTOP (Ty (Ty_bitv 32), TruncSF64)) + ; ("i32.trunc_f64_u", CVTOP (Ty (Ty_bitv 32), TruncUF64)) + ; ("i32.reinterpret_float", CVTOP (Ty (Ty_bitv 32), Reinterpret_float)) + ; ("i32.wrap_i64", CVTOP (Ty (Ty_bitv 32), WrapI64)) + ; ("i32.extend_i16_s", CVTOP (Ty (Ty_bitv 32), Sign_extend 16)) + ; ("i32.extend_i16_u", CVTOP (Ty (Ty_bitv 32), Zero_extend 16)) + ; ("i32.extend_i24_s", CVTOP (Ty (Ty_bitv 32), Sign_extend 24)) + ; ("i32.extend_i24_u", CVTOP (Ty (Ty_bitv 32), Zero_extend 24)) + ; ("i64.neg", UNARY (Utils_parse.U (Ty_bitv 64, Neg))) + ; ("i64.clz", UNARY (Utils_parse.U (Ty_bitv 64, Clz))) + ; ("i64.not", UNARY (Utils_parse.U (Ty_bitv 64, Not))) + ; ("i64.add", BINARY (Ty (Ty_bitv 64), Add)) + ; ("i64.sub", BINARY (Ty (Ty_bitv 64), Sub)) + ; ("i64.div", BINARY (Ty (Ty_bitv 64), Div)) + ; ("i64.div_u", BINARY (Ty (Ty_bitv 64), DivU)) + ; ("i64.and", BINARY (Ty (Ty_bitv 64), And)) + ; ("i64.or", BINARY (Ty (Ty_bitv 64), Or)) + ; ("i64.xor", BINARY (Ty (Ty_bitv 64), Xor)) + ; ("i64.mul", BINARY (Ty (Ty_bitv 64), Mul)) + ; ("i64.shl", BINARY (Ty (Ty_bitv 64), Shl)) + ; ("i64.shr", BINARY (Ty (Ty_bitv 64), ShrA)) + ; ("i64.shr_u", BINARY (Ty (Ty_bitv 64), ShrL)) + ; ("i64.rem", BINARY (Ty (Ty_bitv 64), Rem)) + ; ("i64.rem_u", BINARY (Ty (Ty_bitv 64), RemU)) + ; ("i64.eq", RELOP (Ty Ty_bool, Eq)) + ; ("i64.ne", RELOP (Ty Ty_bool, Ne)) + ; ("i64.lt_u", RELOP (Ty (Ty_bitv 64), LtU)) + ; ("i64.lt", RELOP (Ty (Ty_bitv 64), Lt)) + ; ("i64.le_u", RELOP (Ty (Ty_bitv 64), LeU)) + ; ("i64.le", RELOP (Ty (Ty_bitv 64), Le)) + ; ("i64.gt_u", RELOP (Ty (Ty_bitv 64), GtU)) + ; ("i64.gt", RELOP (Ty (Ty_bitv 64), Gt)) + ; ("i64.ge_u", RELOP (Ty (Ty_bitv 64), GeU)) + ; ("i64.ge", RELOP (Ty (Ty_bitv 64), Ge)) + ; ("i64.trunc_f32_s", CVTOP (Ty (Ty_bitv 64), TruncSF32)) + ; ("i64.trunc_f32_u", CVTOP (Ty (Ty_bitv 64), TruncUF32)) + ; ("i64.trunc_f64_s", CVTOP (Ty (Ty_bitv 64), TruncSF64)) + ; ("i64.trunc_f64_u", CVTOP (Ty (Ty_bitv 64), TruncUF64)) + ; ("i64.reinterpret_float", CVTOP (Ty (Ty_bitv 64), Reinterpret_float)) + ; ("i64.extend_i32_s", CVTOP (Ty (Ty_bitv 64), Sign_extend 32)) + ; ("i64.extend_i32_u", CVTOP (Ty (Ty_bitv 64), Zero_extend 32)) + ; ("f32.neg", UNARY (Utils_parse.U (Ty_fp 32, Neg))) + ; ("f32.abs", UNARY (Utils_parse.U (Ty_fp 32, Abs))) + ; ("f32.sqrt", UNARY (Utils_parse.U (Ty_fp 32, Sqrt))) + ; ("f32.nearest",UNARY (Utils_parse.U (Ty_fp 32, Nearest))) + ; ("f32.is_nan", UNARY (Utils_parse.U (Ty_fp 32, Is_nan))) + ; ("f32.ceil", UNARY (Utils_parse.U (Ty_fp 32, Ceil))) + ; ("f32.floor", UNARY (Utils_parse.U (Ty_fp 32, Floor))) + ; ("f32.trunc", UNARY (Utils_parse.U (Ty_fp 32, Trunc))) + ; ("f32.add", BINARY (Ty (Ty_fp 32), Add)) + ; ("f32.sub", BINARY (Ty (Ty_fp 32), Sub)) + ; ("f32.mul", BINARY (Ty (Ty_fp 32), Mul)) + ; ("f32.div", BINARY (Ty (Ty_fp 32), Div)) + ; ("f32.min", BINARY (Ty (Ty_fp 32), Min)) + ; ("f32.max", BINARY (Ty (Ty_fp 32), Max)) + ; ("f32.rem", BINARY (Ty (Ty_fp 32), Rem)) + ; ("f32.eq", RELOP (Ty (Ty_fp 32), Eq)) + ; ("f32.ne", RELOP (Ty (Ty_fp 32), Ne)) + ; ("f32.lt", RELOP (Ty (Ty_fp 32), Lt)) + ; ("f32.le", RELOP (Ty (Ty_fp 32), Le)) + ; ("f32.gt", RELOP (Ty (Ty_fp 32), Gt)) + ; ("f32.ge", RELOP (Ty (Ty_fp 32), Ge)) + ; ("f32.convert_i32_s", CVTOP (Ty (Ty_fp 32), ConvertSI32)) + ; ("f32.convert_i32_u", CVTOP (Ty (Ty_fp 32), ConvertUI32)) + ; ("f32.convert_i64_s", CVTOP (Ty (Ty_fp 32), ConvertSI32)) + ; ("f32.demote_f64", CVTOP (Ty (Ty_fp 32), DemoteF64)) + ; ("f32.reinterpret_int", CVTOP (Ty (Ty_fp 32), Reinterpret_int)) + ; ("f64.neg", UNARY (Utils_parse.U (Ty_fp 64, Neg))) + ; ("f64.abs", UNARY (Utils_parse.U (Ty_fp 64, Abs))) + ; ("f64.sqrt", UNARY (Utils_parse.U (Ty_fp 64, Sqrt))) + ; ("f64.nearest",UNARY (Utils_parse.U (Ty_fp 64, Nearest))) + ; ("f64.is_nan", UNARY (Utils_parse.U (Ty_fp 64, Is_nan))) + ; ("f64.ceil", UNARY (Utils_parse.U (Ty_fp 32, Ceil))) + ; ("f64.floor", UNARY (Utils_parse.U (Ty_fp 32, Floor))) + ; ("f64.trunc", UNARY (Utils_parse.U (Ty_fp 32, Trunc))) + ; ("f64.add", BINARY (Ty (Ty_fp 64), Add)) + ; ("f64.sub", BINARY (Ty (Ty_fp 64), Sub)) + ; ("f64.mul", BINARY (Ty (Ty_fp 64), Mul)) + ; ("f64.div", BINARY (Ty (Ty_fp 64), Div)) + ; ("f64.min", BINARY (Ty (Ty_fp 64), Min)) + ; ("f64.max", BINARY (Ty (Ty_fp 64), Max)) + ; ("f64.rem", BINARY (Ty (Ty_fp 64), Rem)) + ; ("f64.eq", RELOP (Ty (Ty_fp 64), Eq)) + ; ("f64.ne", RELOP (Ty (Ty_fp 64), Ne)) + ; ("f64.lt", RELOP (Ty (Ty_fp 64), Lt)) + ; ("f64.le", RELOP (Ty (Ty_fp 64), Le)) + ; ("f64.gt", RELOP (Ty (Ty_fp 64), Gt)) + ; ("f64.ge", RELOP (Ty (Ty_fp 64), Ge)) + ; ("f64.convert_i32_s", CVTOP (Ty (Ty_fp 64), ConvertSI32)) + ; ("f64.convert_i32_u", CVTOP (Ty (Ty_fp 64), ConvertUI32)) + ; ("f64.convert_i64_s", CVTOP (Ty (Ty_fp 64), ConvertSI32)) + ; ("f64.promote_f32", CVTOP (Ty (Ty_fp 64), PromoteF32)) + ; ("f64.reinterpret_int", CVTOP (Ty (Ty_fp 64), Reinterpret_int)) ; ("extract", EXTRACT) ; ("++", CONCAT) ; ("Ptr", PTR) diff --git a/src/smtml/mappings.ml b/src/smtml/mappings.ml index 2b7d8ba9..4df8e9c4 100644 --- a/src/smtml/mappings.ml +++ b/src/smtml/mappings.ml @@ -55,19 +55,19 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct let f64_to_i64 = M.Func.make "f64_to_i64" [ f64 ] i64 let get_type = function - | Ty_int -> M.Types.int - | Ty_real -> M.Types.real - | Ty_bool -> M.Types.bool - | Ty_str -> M.Types.string - | Ty_bitv 8 -> i8 - | Ty_bitv 32 -> i32 - | Ty_bitv 64 -> i64 - | Ty_bitv n -> M.Types.bitv n - | Ty_fp 32 -> f32 - | Ty_fp 64 -> f64 - | Ty_roundingMode -> M.Types.roundingMode - | Ty_regexp -> M.Types.regexp - | (Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none) as ty -> + | Ty Ty_int -> M.Types.int + | Ty Ty_real -> M.Types.real + | Ty Ty_bool -> M.Types.bool + | Ty Ty_str -> M.Types.string + | Ty (Ty_bitv 8) -> i8 + | Ty (Ty_bitv 32) -> i32 + | Ty (Ty_bitv 64) -> i64 + | Ty (Ty_bitv n) -> M.Types.bitv n + | Ty (Ty_fp 32) -> f32 + | Ty (Ty_fp 64) -> f64 + | Ty Ty_roundingMode -> M.Types.roundingMode + | Ty Ty_regexp -> M.Types.regexp + | Ty (Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none) as ty -> Fmt.failwith "Unsupported theory: %a@." Ty.pp ty let make_symbol (ctx : symbol_ctx) (s : Symbol.t) : symbol_ctx * M.term = @@ -87,10 +87,8 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct let false_ = M.false_ - let unop = function - | Unop.Not -> M.not_ - | op -> - Fmt.failwith {|Bool: Unsupported Z3 unop operator "%a"|} Unop.pp op + let unop (op : [ `Ty_bool ] Ty.Unop.op) e = + match op with Unop.Not -> M.not_ e let binop = function | Binop.And -> M.and_ @@ -126,9 +124,11 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct module Int_impl = struct let v i = M.int i [@@inline] - let unop = function - | Unop.Neg -> M.Int.neg - | op -> Fmt.failwith {|Int: Unsupported unop operator "%a"|} Unop.pp op + let unop (op : [ `Ty_int ] Ty.Unop.op) e = + match op with + | Neg -> M.Int.neg e + | Not | Abs -> + Fmt.failwith {|Int: Unsupported unop operator "%a"|} Unop.pp (U op) let binop = function | Binop.Add -> M.Int.add @@ -161,18 +161,18 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct module Real_impl = struct let v f = M.real f [@@inline] - let unop op e = + let unop (op : [ `Ty_real ] Ty.Unop.op) e = let open M in match op with - | Unop.Neg -> Real.neg e + | Neg -> Real.neg e | Abs -> ite (Real.gt e (real 0.)) e (Real.neg e) | Sqrt -> Real.pow e (v 0.5) | Ceil -> let x_int = M.Real.to_int e in ite (eq (Int.to_real x_int) e) x_int (Int.add x_int (int 1)) | Floor -> Real.to_int e - | Nearest | Is_nan | _ -> - Fmt.failwith {|Real: Unsupported unop operator "%a"|} Unop.pp op + | Nearest | Trunc | Is_nan -> + Fmt.failwith {|Real: Unsupported unop operator "%a"|} Unop.pp (U op) let binop op e1 e2 = match op with @@ -209,12 +209,13 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct module String_impl = struct let v s = M.String.v s [@@inline] - let unop op e = + let unop (op : [ `Ty_str ] Ty.Unop.op) e = match op with - | Unop.Length -> M.String.length e + | Length -> M.String.length e | Trim -> M.Func.apply str_trim [ e ] - | op -> - Fmt.failwith {|String: Unsupported unop operator "%a"|} Unop.pp op + | Regexp_comp | Regexp_opt | Regexp_plus | Regexp_loop _ | Regexp_star + -> + Fmt.failwith {|String: Unsupported unop operator "%a"|} Unop.pp (U op) let binop op e1 e2 = match op with @@ -261,15 +262,13 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct end module Regexp_impl = struct - let unop op e = + let unop (op : [ `Ty_regexp ] Ty.Unop.op) e = match op with - | Unop.Regexp_star -> M.Re.star e + | Regexp_star -> M.Re.star e | Regexp_plus -> M.Re.plus e | Regexp_opt -> M.Re.opt e | Regexp_comp -> M.Re.comp e | Regexp_loop (i1, i2) -> M.Re.loop e i1 i2 - | op -> - Fmt.failwith {|Regexp: Unsupported unop operator "%a"|} Unop.pp op let binop op e1 e2 = match op with @@ -359,14 +358,13 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct in loop 0 (v @@ Ixx.of_int 0) - let unop = function - | Unop.Clz -> clz - | Ctz -> ctz - | Popcnt -> popcnt - | Neg -> Bitv.neg - | Not -> Bitv.lognot - | op -> - Fmt.failwith {|Bitv: Unsupported unary operator "%a"|} Unop.pp op + let unop (op : [ `Ty_bitv ] Ty.Unop.op) e = + match op with + | Unop.Clz -> clz e + | Ctz -> ctz e + | Popcnt -> popcnt e + | Neg -> Bitv.neg e + | Not -> Bitv.lognot e let binop = function | Binop.Add -> Bitv.add @@ -481,9 +479,9 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct open M include F - let unop op e = + let unop (op : [ `Ty_fp ] Ty.Unop.op) e = match op with - | Unop.Neg -> Float.neg e + | Neg -> Float.neg e | Abs -> Float.abs e | Sqrt -> Float.sqrt ~rm:Float.Rounding_mode.rne e | Is_normal -> Float.is_normal e @@ -497,7 +495,6 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct | Floor -> Float.round_to_integral ~rm:Float.Rounding_mode.rtn e | Trunc -> Float.round_to_integral ~rm:Float.Rounding_mode.rtz e | Nearest -> Float.round_to_integral ~rm:Float.Rounding_mode.rne e - | _ -> Fmt.failwith {|Fp: Unsupported unary operator "%a"|} Unop.pp op let binop op e1 e2 = match op with @@ -591,81 +588,106 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct | Bitv bv -> M.Bitv.v (Bitvector.to_string bv) (Bitvector.numbits bv) | List _ | App _ | Unit | Nothing -> assert false - let unop = function - | Ty.Ty_int -> Int_impl.unop - | Ty.Ty_real -> Real_impl.unop - | Ty.Ty_bool -> Bool_impl.unop - | Ty.Ty_str -> String_impl.unop - | Ty.Ty_regexp -> Regexp_impl.unop - | Ty.Ty_bitv 8 -> I8.unop - | Ty.Ty_bitv 32 -> I32.unop - | Ty.Ty_bitv 64 -> I64.unop - | Ty.Ty_fp 32 -> Float32_impl.unop - | Ty.Ty_fp 64 -> Float64_impl.unop - | Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none - | Ty_roundingMode -> + let unop : type a. a Ty.ty -> a Ty.Unop.op -> M.term -> M.term = + fun ty op e -> + match (ty, op) with + | Ty_int, ((Neg | Not | Abs) as op) -> Int_impl.unop op e + | ( Ty_real + , ((Neg | Abs | Sqrt | Is_nan | Ceil | Floor | Trunc | Nearest) as op) ) + -> + Real_impl.unop op e + | Ty_bool, Not -> Bool_impl.unop Not e + | ( Ty_str + , ( ( Length | Trim | Regexp_star | Regexp_loop _ | Regexp_plus + | Regexp_opt | Regexp_comp ) as op ) ) -> + String_impl.unop op e + | ( Ty_regexp + , ( ( Regexp_star | Regexp_loop _ | Regexp_plus | Regexp_opt + | Regexp_comp ) as op ) ) -> + Regexp_impl.unop op e + | Ty_bitv n, ((Neg | Not | Clz | Ctz | Popcnt) as op) -> + if n = 8 then I8.unop op e + else if n = 32 then I32.unop op e + else begin + assert (n = 64); + I64.unop op e + end + | ( Ty_fp n + , ( ( Neg | Abs | Sqrt | Is_normal | Is_subnormal | Is_negative + | Is_positive | Is_infinite | Is_nan | Is_zero | Ceil | Floor + | Trunc | Nearest ) as op ) ) -> + if n = 32 then Float32_impl.unop op e + else begin + assert (n = 64); + Float64_impl.unop op e + end + | (Ty_list | Ty_app | Ty_unit | Ty_roundingMode | Ty_none), _ -> assert false let binop = function - | Ty.Ty_int -> Int_impl.binop - | Ty.Ty_real -> Real_impl.binop - | Ty.Ty_bool -> Bool_impl.binop - | Ty.Ty_str -> String_impl.binop - | Ty.Ty_regexp -> Regexp_impl.binop - | Ty.Ty_bitv 8 -> I8.binop - | Ty.Ty_bitv 32 -> I32.binop - | Ty.Ty_bitv 64 -> I64.binop - | Ty.Ty_fp 32 -> Float32_impl.binop - | Ty.Ty_fp 64 -> Float64_impl.binop - | Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none - | Ty_roundingMode -> + | Ty Ty_int -> Int_impl.binop + | Ty Ty_real -> Real_impl.binop + | Ty Ty_bool -> Bool_impl.binop + | Ty Ty_str -> String_impl.binop + | Ty Ty_regexp -> Regexp_impl.binop + | Ty (Ty_bitv 8) -> I8.binop + | Ty (Ty_bitv 32) -> I32.binop + | Ty (Ty_bitv 64) -> I64.binop + | Ty (Ty_fp 32) -> Float32_impl.binop + | Ty (Ty_fp 64) -> Float64_impl.binop + | Ty + ( Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_roundingMode + | Ty_none ) -> assert false let triop = function - | Ty.Ty_int | Ty.Ty_real -> assert false - | Ty.Ty_bool -> Bool_impl.triop - | Ty.Ty_str -> String_impl.triop - | Ty.Ty_bitv 8 -> I8.triop - | Ty.Ty_bitv 32 -> I32.triop - | Ty.Ty_bitv 64 -> I64.triop - | Ty.Ty_fp 32 -> Float32_impl.triop - | Ty.Ty_fp 64 -> Float64_impl.triop - | Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none - | Ty_regexp | Ty_roundingMode -> + | Ty (Ty_int | Ty_real) -> assert false + | Ty Ty_bool -> Bool_impl.triop + | Ty Ty_str -> String_impl.triop + | Ty (Ty_bitv 8) -> I8.triop + | Ty (Ty_bitv 32) -> I32.triop + | Ty (Ty_bitv 64) -> I64.triop + | Ty (Ty_fp 32) -> Float32_impl.triop + | Ty (Ty_fp 64) -> Float64_impl.triop + | Ty + ( Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none + | Ty_regexp | Ty_roundingMode ) -> assert false let relop = function - | Ty.Ty_int -> Int_impl.relop - | Ty.Ty_real -> Real_impl.relop - | Ty.Ty_bool -> Bool_impl.relop - | Ty.Ty_str -> String_impl.relop - | Ty.Ty_bitv 8 -> I8.relop - | Ty.Ty_bitv 32 -> I32.relop - | Ty.Ty_bitv 64 -> I64.relop - | Ty.Ty_fp 32 -> Float32_impl.relop - | Ty.Ty_fp 64 -> Float64_impl.relop - | Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none - | Ty_regexp | Ty_roundingMode -> + | Ty Ty_int -> Int_impl.relop + | Ty Ty_real -> Real_impl.relop + | Ty Ty_bool -> Bool_impl.relop + | Ty Ty_str -> String_impl.relop + | Ty (Ty_bitv 8) -> I8.relop + | Ty (Ty_bitv 32) -> I32.relop + | Ty (Ty_bitv 64) -> I64.relop + | Ty (Ty_fp 32) -> Float32_impl.relop + | Ty (Ty_fp 64) -> Float64_impl.relop + | Ty + ( Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none + | Ty_regexp | Ty_roundingMode ) -> assert false let cvtop = function - | Ty.Ty_int -> Int_impl.cvtop - | Ty.Ty_real -> Real_impl.cvtop - | Ty.Ty_bool -> Bool_impl.cvtop - | Ty.Ty_str -> String_impl.cvtop - | Ty.Ty_bitv 8 -> I8.cvtop - | Ty.Ty_bitv 32 -> I32.cvtop - | Ty.Ty_bitv 64 -> I64.cvtop - | Ty.Ty_fp 32 -> Float32_impl.cvtop - | Ty.Ty_fp 64 -> Float64_impl.cvtop - | Ty.Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none - | Ty_regexp | Ty_roundingMode -> + | Ty Ty_int -> Int_impl.cvtop + | Ty Ty_real -> Real_impl.cvtop + | Ty Ty_bool -> Bool_impl.cvtop + | Ty Ty_str -> String_impl.cvtop + | Ty (Ty_bitv 8) -> I8.cvtop + | Ty (Ty_bitv 32) -> I32.cvtop + | Ty (Ty_bitv 64) -> I64.cvtop + | Ty (Ty_fp 32) -> Float32_impl.cvtop + | Ty (Ty_fp 64) -> Float64_impl.cvtop + | Ty + ( Ty_bitv _ | Ty_fp _ | Ty_list | Ty_app | Ty_unit | Ty_none + | Ty_regexp | Ty_roundingMode ) -> assert false let naryop = function - | Ty.Ty_str -> String_impl.naryop - | Ty.Ty_bool -> Bool_impl.naryop - | Ty.Ty_regexp -> Regexp_impl.naryop + | Ty Ty_str -> String_impl.naryop + | Ty Ty_bool -> Bool_impl.naryop + | Ty Ty_regexp -> Regexp_impl.naryop | ty -> Fmt.failwith "Naryop for type \"%a\" not implemented" Ty.pp ty let get_rounding_mode ctx rm = @@ -797,7 +819,7 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct (ctx, e :: es) ) (ctx, []) es - let value_of_term ?ctx model ty term = + let value_of_term ?ctx model (Ty ty) term = let v = match M.Model.eval ?ctx ~completion:true model term with | None -> assert false diff --git a/src/smtml/num.ml b/src/smtml/num.ml index 75ec87e7..9730778d 100644 --- a/src/smtml/num.ml +++ b/src/smtml/num.ml @@ -12,7 +12,7 @@ type printer = ] let type_of (n : t) = - match n with F32 _ -> Ty.(Ty_fp 32) | F64 _ -> Ty.(Ty_fp 64) + match n with F32 _ -> Ty.Ty (Ty_fp 32) | F64 _ -> Ty.Ty (Ty_fp 64) let compare n1 n2 = match (n1, n2) with @@ -47,7 +47,7 @@ let pp fmt v = !printer fmt v let to_string (n : t) : string = Fmt.str "%a" pp n -let of_string (cast : Ty.t) value = +let of_string (Ty cast : Ty.t) value = match cast with | Ty_fp 32 -> ( match float_of_string_opt value with diff --git a/src/smtml/parser.mly b/src/smtml/parser.mly index 1be25f87..bbe08598 100644 --- a/src/smtml/parser.mly +++ b/src/smtml/parser.mly @@ -29,7 +29,7 @@ let get_bind x = Hashtbl.find_opt varmap x %token BOOL %token STR %token SYMBOL -%token UNARY +%token UNARY %token BINARY %token TERNARY %token RELOP @@ -68,18 +68,19 @@ let s_expr := let paren_op := | PTR; LPAREN; _ = TYPE; x = NUM; RPAREN; offset = s_expr; { Expr.ptr (Int32.of_int x) offset } - | (ty, op) = UNARY; e = s_expr; - { Expr.unop ty op e } + | op = UNARY; e = s_expr; + { let U (ty, op) = op in Expr.unop ty op e } | (ty, op) = BINARY; e1 = s_expr; e2 = s_expr; - { Expr.binop ty op e1 e2 } + { let Ty ty = ty in Expr.binop ty op e1 e2 } | (ty, op) = TERNARY; e1 = s_expr; e2 = s_expr; e3 = s_expr; - { Expr.triop ty op e1 e2 e3 } + { let Ty ty = ty in Expr.triop ty op e1 e2 e3 } | (ty, op) = CVTOP; e = s_expr; - { Expr.cvtop ty op e } + { let Ty ty = ty in Expr.cvtop ty op e } | (ty, op) = RELOP; e1 = s_expr; e2 = s_expr; - { Expr.relop ty op e1 e2 } + { let Ty ty = ty in Expr.relop ty op e1 e2 } | (ty, op) = NARY; es = list(s_expr); - { Expr.naryop ty op es } + { let Ty ty = ty in + Expr.naryop ty op es } | EXTRACT; ~ = s_expr; l = NUM; h = NUM; { Expr.extract s_expr ~high:h ~low:l } | CONCAT; e1 = s_expr; e2 = s_expr; @@ -93,14 +94,14 @@ let spec_constant := | LPAREN; ty = TYPE; x = NUM; RPAREN; { match ty with - | Ty_bitv 32 -> Bitv (Bitvector.of_int32 (Int32.of_int x)) - | Ty_bitv 64 -> Bitv (Bitvector.of_int64 (Int64.of_int x)) + | Ty (Ty_bitv 32) -> Bitv (Bitvector.of_int32 (Int32.of_int x)) + | Ty (Ty_bitv 64) -> Bitv (Bitvector.of_int64 (Int64.of_int x)) | _ -> Fmt.failwith "invalid bitv type" } | LPAREN; ty = TYPE; x = DEC; RPAREN; { match ty with - | Ty_fp 32 -> Num (F32 (Int32.bits_of_float x)) - | Ty_fp 64 -> Num (F64 (Int64.bits_of_float x)) + | Ty (Ty_fp 32) -> Num (F32 (Int32.bits_of_float x)) + | Ty (Ty_fp 64) -> Num (F64 (Int64.bits_of_float x)) | _ -> Fmt.failwith "invalid fp type" } diff --git a/src/smtml/rewrite.ml b/src/smtml/rewrite.ml index 8be6e79b..394c59c5 100644 --- a/src/smtml/rewrite.ml +++ b/src/smtml/rewrite.ml @@ -22,21 +22,21 @@ let debug fmt k = if debug then k (Fmt.epr fmt) (* FIXME: This is a very basic way to infer types. I'm surprised it even works *) let rewrite_ty unknown_ty tys = match (unknown_ty, tys) with - | Ty.Ty_none, [ ty ] -> + | Ty.Ty Ty_none, [ ty ] -> debug " rewrite_ty: %a -> %a@." (fun k -> k Ty.pp unknown_ty Ty.pp ty); ty - | Ty_none, [ ty1; ty2 ] -> + | Ty Ty_none, [ ty1; ty2 ] -> debug " rewrite_ty: %a -> (%a %a)@." (fun k -> k Ty.pp unknown_ty Ty.pp ty1 Ty.pp ty2 ); assert (Ty.equal ty1 ty2); ty1 - | Ty_none, ty1 :: ty2 :: [ ty3 ] -> + | Ty Ty_none, ty1 :: ty2 :: [ ty3 ] -> debug " rewrite_ty: %a ->(%a %a %a)@." (fun k -> k Ty.pp unknown_ty Ty.pp ty1 Ty.pp ty2 Ty.pp ty3 ); assert (Ty.equal ty1 ty2); assert (Ty.equal ty2 ty3); ty1 - | Ty_none, _ -> assert false + | Ty Ty_none, _ -> assert false | ty, _ -> ty (** Propagates types in [type_map] and inlines [Let_in] binders *) @@ -48,7 +48,7 @@ let rec rewrite_expr (type_map, expr_map) hte = Expr.ptr base (rewrite_expr (type_map, expr_map) offset) | Symbol sym -> begin (* Avoid rewriting well-typed symbols already *) - if not (Ty.equal Ty_none (Symbol.type_of sym)) then hte + if not (Ty.equal (Ty Ty_none) (Symbol.type_of sym)) then hte else match Symb_map.find_opt sym type_map with | None -> ( @@ -64,21 +64,21 @@ let rec rewrite_expr (type_map, expr_map) hte = let rm = rewrite_expr (type_map, expr_map) rm in let a = rewrite_expr (type_map, expr_map) a in let b = rewrite_expr (type_map, expr_map) b in - let ty = rewrite_ty Ty_none [ Expr.ty a; Expr.ty b ] in + let ty = rewrite_ty (Ty Ty_none) [ Expr.ty a; Expr.ty b ] in Expr.app { sym with ty } [ rm; a; b ] | App (({ name = Simple "fp.fma"; _ } as sym), [ rm; a; b; c ]) -> let rm = rewrite_expr (type_map, expr_map) rm in let a = rewrite_expr (type_map, expr_map) a in let b = rewrite_expr (type_map, expr_map) b in let c = rewrite_expr (type_map, expr_map) c in - let ty = rewrite_ty Ty_none [ Expr.ty a; Expr.ty b; Expr.ty c ] in + let ty = rewrite_ty (Ty Ty_none) [ Expr.ty a; Expr.ty b; Expr.ty c ] in Expr.app { sym with ty } [ rm; a; b; c ] | App ( ({ name = Simple ("fp.sqrt" | "fp.roundToIntegral"); _ } as sym) , [ rm; a ] ) -> let rm = rewrite_expr (type_map, expr_map) rm in let a = rewrite_expr (type_map, expr_map) a in - let ty = rewrite_ty Ty_none [ Expr.ty a ] in + let ty = rewrite_ty (Ty Ty_none) [ Expr.ty a ] in Expr.app { sym with ty } [ rm; a ] | App (sym, htes) -> let sym = @@ -87,34 +87,61 @@ let rec rewrite_expr (type_map, expr_map) hte = | Some ty -> { sym with ty } in Expr.app sym (List.map (rewrite_expr (type_map, expr_map)) htes) - | Unop (ty, op, hte) -> + | Unop (ty, op, hte) -> begin let hte = rewrite_expr (type_map, expr_map) hte in - let ty = rewrite_ty ty [ Expr.ty hte ] in - Expr.unop ty op hte + let (Ty ty) = rewrite_ty (Ty ty) [ Expr.ty hte ] in + match (ty, op) with + | Ty_bool, Not -> Expr.unop Ty_bool Not hte + | Ty_int, ((Neg | Not | Abs) as op) -> Expr.unop Ty_int op hte + | Ty_real, ((Neg | Abs | Sqrt | Ceil | Floor | Trunc | Nearest) as op) -> + Expr.unop Ty_real op hte + | ( Ty_str + , ( ( Length | Trim | Regexp_star | Regexp_loop _ | Regexp_plus + | Regexp_opt | Regexp_comp ) as op ) ) -> + Expr.unop Ty_str op hte + | (Ty_bitv _ as ty), ((Neg | Not | Clz | Ctz | Popcnt) as op) -> + Expr.unop ty op hte + | ( (Ty_fp _ as ty) + , ( ( Neg | Abs | Sqrt | Is_normal | Is_subnormal | Is_negative + | Is_positive | Is_infinite | Is_nan | Is_zero | Ceil | Floor | Trunc + | Nearest ) as op ) ) -> + Expr.unop ty op hte + | Ty_list, ((Head | Tail | Reverse | Length) as op) -> + Expr.unop Ty_list op hte + | ( Ty_regexp + , ( (Regexp_star | Regexp_loop _ | Regexp_plus | Regexp_opt | Regexp_comp) + as op ) ) -> + Expr.unop Ty_regexp op hte + | (Ty_app | Ty_none | Ty_unit | Ty_roundingMode), _ -> + Fmt.failwith "TypeError: Unop typed with theory with no opeartors" + | _ -> + Fmt.failwith "TypeError: Unop (%a, %a)" Ty.pp (Ty ty) Ty.Unop.pp (U op) + end | Binop (ty, op, hte1, hte2) -> let hte1 = rewrite_expr (type_map, expr_map) hte1 in let hte2 = rewrite_expr (type_map, expr_map) hte2 in - let ty = rewrite_ty ty [ Expr.ty hte1; Expr.ty hte2 ] in + let (Ty ty) = rewrite_ty ty [ Expr.ty hte1; Expr.ty hte2 ] in Expr.binop ty op hte1 hte2 - | Triop (ty, op, hte1, hte2, hte3) -> + | Triop (Ty ty, op, hte1, hte2, hte3) -> let hte1 = rewrite_expr (type_map, expr_map) hte1 in let hte2 = rewrite_expr (type_map, expr_map) hte2 in let hte3 = rewrite_expr (type_map, expr_map) hte3 in Expr.triop ty op hte1 hte2 hte3 - | Relop (ty, ((Eq | Ne) as op), hte1, hte2) when not (Ty.equal Ty_none ty) -> + | Relop ((Ty ty as ty_), ((Eq | Ne) as op), hte1, hte2) + when not (Ty.equal (Ty Ty_none) ty_) -> let hte1 = rewrite_expr (type_map, expr_map) hte1 in let hte2 = rewrite_expr (type_map, expr_map) hte2 in Expr.relop ty op hte1 hte2 | Relop (ty, op, hte1, hte2) -> let hte1 = rewrite_expr (type_map, expr_map) hte1 in let hte2 = rewrite_expr (type_map, expr_map) hte2 in - let ty = rewrite_ty ty [ Expr.ty hte1; Expr.ty hte2 ] in + let (Ty ty) = rewrite_ty ty [ Expr.ty hte1; Expr.ty hte2 ] in Expr.relop ty op hte1 hte2 | Cvtop (ty, op, hte) -> let hte = rewrite_expr (type_map, expr_map) hte in - let ty = rewrite_ty ty [ Expr.ty hte ] in + let (Ty ty) = rewrite_ty ty [ Expr.ty hte ] in Expr.cvtop ty op hte - | Naryop (ty, op, htes) -> + | Naryop (Ty ty, op, htes) -> let htes = List.map (rewrite_expr (type_map, expr_map)) htes in Expr.naryop ty op htes | Extract (hte, h, l) -> diff --git a/src/smtml/smtlib.ml b/src/smtml/smtlib.ml index ece504eb..866ab86a 100644 --- a/src/smtml/smtlib.ml +++ b/src/smtml/smtlib.ml @@ -32,13 +32,13 @@ module Term = struct match (Symbol.namespace id, Symbol.name id) with | Sort, Simple name -> begin match name with - | "Int" -> Expr.symbol { id with ty = Ty_int } - | "Real" -> Expr.symbol { id with ty = Ty_real } - | "Bool" -> Expr.symbol { id with ty = Ty_bool } - | "String" -> Expr.symbol { id with ty = Ty_str } - | "Float32" -> Expr.symbol { id with ty = Ty_fp 32 } - | "Float64" -> Expr.symbol { id with ty = Ty_fp 64 } - | "RoundingMode" -> Expr.symbol { id with ty = Ty_roundingMode } + | "Int" -> Expr.symbol { id with ty = Ty Ty_int } + | "Real" -> Expr.symbol { id with ty = Ty Ty_real } + | "Bool" -> Expr.symbol { id with ty = Ty Ty_bool } + | "String" -> Expr.symbol { id with ty = Ty Ty_str } + | "Float32" -> Expr.symbol { id with ty = Ty (Ty_fp 32) } + | "Float64" -> Expr.symbol { id with ty = Ty (Ty_fp 64) } + | "RoundingMode" -> Expr.symbol { id with ty = Ty Ty_roundingMode } | _ -> begin match Hashtbl.find_opt custom_sorts name with | Some ty -> Expr.symbol { id with ty } @@ -52,11 +52,11 @@ module Term = struct match (basename, indices) with | "BitVec", [ n ] -> ( match int_of_string_opt n with - | Some n -> Expr.symbol { id with ty = Ty_bitv n } + | Some n -> Expr.symbol { id with ty = Ty (Ty_bitv n) } | None -> Fmt.failwith "Invalid bitvector size" ) | "FloatingPoint", [ e; s ] -> ( match (int_of_string_opt e, int_of_string_opt s) with - | Some e, Some s -> Expr.symbol { id with ty = Ty_fp (e + s) } + | Some e, Some s -> Expr.symbol { id with ty = Ty (Ty_fp (e + s)) } | _ -> Fmt.failwith "Invalid floating point size" ) | _ -> Fmt.failwith "%acould not parse indexed sort:%a %a@." pp_loc loc @@ -71,9 +71,9 @@ module Term = struct | "roundNearestTiesToEven" | "RNE" | "roundNearestTiesToAway" | "RNA" | "roundTowardPositive" | "RTP" | "roundTowardNegative" | "RTN" | "roundTowardZero" | "RTZ" -> - Expr.symbol { id with ty = Ty_roundingMode } + Expr.symbol { id with ty = Ty Ty_roundingMode } | "re.all" | "re.allchar" | "re.none" -> - Expr.symbol { id with ty = Ty_regexp } + Expr.symbol { id with ty = Ty Ty_regexp } | _ -> Expr.symbol id end | Term, Indexed { basename = base; indices } -> begin diff --git a/src/smtml/symbol.ml b/src/smtml/symbol.ml index 77353aff..3640bd3f 100644 --- a/src/smtml/symbol.ml +++ b/src/smtml/symbol.ml @@ -71,10 +71,10 @@ let make ty name = name @: ty let make3 ty name namespace = { ty; name; namespace } -let mk namespace name = { ty = Ty_none; name = Name.simple name; namespace } +let mk namespace name = { ty = Ty Ty_none; name = Name.simple name; namespace } let indexed namespace basename indices = - { ty = Ty_none; name = Name.indexed basename indices; namespace } + { ty = Ty Ty_none; name = Name.indexed basename indices; namespace } let pp_namespace fmt = function | Attr -> Fmt.string fmt "attr" diff --git a/src/smtml/ty.ml b/src/smtml/ty.ml index 44a8ea29..ff8de31d 100644 --- a/src/smtml/ty.ml +++ b/src/smtml/ty.ml @@ -7,21 +7,23 @@ type _ cast = | C32 : int32 cast | C64 : int64 cast -type t = - | Ty_app - | Ty_bitv of int - | Ty_bool - | Ty_fp of int - | Ty_int - | Ty_list - | Ty_none - | Ty_real - | Ty_str - | Ty_unit - | Ty_regexp - | Ty_roundingMode +type _ ty = + | Ty_app : [> `Ty_app ] ty + | Ty_bitv : int -> [> `Ty_bitv ] ty + | Ty_bool : [> `Ty_bool ] ty + | Ty_fp : int -> [> `Ty_fp ] ty + | Ty_int : [> `Ty_int ] ty + | Ty_list : [> `Ty_list ] ty + | Ty_none : [> `Ty_none ] ty + | Ty_real : [> `Ty_real ] ty + | Ty_str : [> `Ty_str ] ty + | Ty_unit : [> `Ty_unit ] ty + | Ty_regexp : [> `Ty_regexp ] ty + | Ty_roundingMode : [> `Ty_roundingMode ] ty -let discr = function +type t = Ty : 'a ty -> t + +let discr : type a. a ty -> int = function | Ty_app -> 0 | Ty_bool -> 1 | Ty_int -> 2 @@ -35,11 +37,12 @@ let discr = function | Ty_bitv n -> 10 + n | Ty_fp n -> 11 + n -let compare t1 t2 = compare (discr t1) (discr t2) +let compare (Ty t1) (Ty t2) = compare (discr t1) (discr t2) let equal t1 t2 = compare t1 t2 = 0 -let pp fmt = function +let pp fmt (Ty ty) = + match ty with | Ty_int -> Fmt.string fmt "int" | Ty_real -> Fmt.string fmt "real" | Ty_bool -> Fmt.string fmt "bool" @@ -56,16 +59,16 @@ let pp fmt = function let string_of_type (ty : t) : string = Fmt.str "%a" pp ty let of_string = function - | "int" -> Ok Ty_int - | "real" -> Ok Ty_real - | "bool" -> Ok Ty_bool - | "str" -> Ok Ty_str - | "list" -> Ok Ty_list - | "app" -> Ok Ty_app - | "unit" -> Ok Ty_unit - | "none" -> Ok Ty_none - | "regexp" -> Ok Ty_regexp - | "RoundingMode" -> Ok Ty_roundingMode + | "int" -> Ok (Ty Ty_int) + | "real" -> Ok (Ty Ty_real) + | "bool" -> Ok (Ty Ty_bool) + | "str" -> Ok (Ty Ty_str) + | "list" -> Ok (Ty Ty_list) + | "app" -> Ok (Ty Ty_app) + | "unit" -> Ok (Ty Ty_unit) + | "none" -> Ok (Ty Ty_none) + | "regexp" -> Ok (Ty Ty_regexp) + | "RoundingMode" -> Ok (Ty Ty_roundingMode) | s -> if String.starts_with ~prefix:"i" s then begin let s = String.sub s 1 (String.length s - 1) in @@ -73,7 +76,7 @@ let of_string = function | None -> Fmt.error_msg "can not parse type %s" s | Some n when n < 0 -> Fmt.error_msg "size of bitvectors must be a positive integer" - | Some n -> Ok (Ty_bitv n) + | Some n -> Ok (Ty (Ty_bitv n)) end else if String.starts_with ~prefix:"f" s then begin let s = String.sub s 1 (String.length s - 1) in @@ -81,11 +84,11 @@ let of_string = function | None -> Fmt.error_msg "can not parse type %s" s | Some n when n < 0 -> Fmt.error_msg "size of fp must be a positive integer" - | Some n -> Ok (Ty_fp n) + | Some n -> Ok (Ty (Ty_fp n)) end else Fmt.error_msg "can not parse type %s" s -let size (ty : t) : int = +let size (Ty ty : t) : int = match ty with | Ty_bitv n | Ty_fp n -> n / 8 | Ty_int | Ty_bool -> 4 @@ -94,40 +97,42 @@ let size (ty : t) : int = assert false module Unop = struct - type t = - | Neg - | Not - | Clz - | Ctz - | Popcnt + type t = U : 'a op -> t + + and _ op = + | Neg : [< `Ty_int | `Ty_real | `Ty_bitv | `Ty_fp | `Ty_none ] op + | Not : [< `Ty_bool | `Ty_int | `Ty_bitv | `Ty_none ] op + | Clz : [ `Ty_bitv ] op + | Ctz : [ `Ty_bitv ] op + | Popcnt : [ `Ty_bitv ] op (* Float *) - | Abs - | Sqrt - | Is_normal - | Is_subnormal - | Is_negative - | Is_positive - | Is_infinite - | Is_nan - | Is_zero - | Ceil - | Floor - | Trunc - | Nearest - | Head - | Tail - | Reverse - | Length + | Abs : [< `Ty_int | `Ty_real | `Ty_fp | `Ty_none ] op + | Sqrt : [< `Ty_real | `Ty_fp | `Ty_none ] op + | Is_normal : [< `Ty_fp | `Ty_none ] op + | Is_subnormal : [< `Ty_fp | `Ty_none ] op + | Is_negative : [< `Ty_fp | `Ty_none ] op + | Is_positive : [< `Ty_fp | `Ty_none ] op + | Is_infinite : [< `Ty_fp | `Ty_none ] op + | Is_nan : [< `Ty_real | `Ty_fp | `Ty_none ] op + | Is_zero : [< `Ty_fp | `Ty_none ] op + | Ceil : [< `Ty_real | `Ty_fp | `Ty_none ] op + | Floor : [< `Ty_real | `Ty_fp | `Ty_none ] op + | Trunc : [< `Ty_real | `Ty_fp | `Ty_none ] op + | Nearest : [< `Ty_real | `Ty_fp | `Ty_none ] op + | Head : [ `Ty_list ] op + | Tail : [ `Ty_list ] op + | Reverse : [ `Ty_list ] op + | Length : [< `Ty_list | `Ty_str ] op (* String *) - | Trim + | Trim : [ `Ty_str ] op (* RegExp *) - | Regexp_star - | Regexp_loop of (int * int) - | Regexp_plus - | Regexp_opt - | Regexp_comp + | Regexp_star : [< `Ty_str | `Ty_regexp ] op + | Regexp_loop : (int * int) -> [< `Ty_str | `Ty_regexp ] op + | Regexp_plus : [< `Ty_str | `Ty_regexp ] op + | Regexp_opt : [< `Ty_str | `Ty_regexp ] op + | Regexp_comp : [< `Ty_str | `Ty_regexp ] op - let equal o1 o2 = + let equal (U o1) (U o2) = match (o1, o2) with | Neg, Neg | Not, Not @@ -166,7 +171,8 @@ module Unop = struct , _ ) -> false - let pp fmt = function + let pp fmt (U op) = + match op with | Neg -> Fmt.string fmt "neg" | Not -> Fmt.string fmt "not" | Clz -> Fmt.string fmt "clz" diff --git a/src/smtml/ty.mli b/src/smtml/ty.mli index 4ca62c13..f992ed5d 100644 --- a/src/smtml/ty.mli +++ b/src/smtml/ty.mli @@ -17,19 +17,23 @@ type _ cast = (** {1 Type Definitions} *) (** The type [t] represents smtml types. *) -type t = - | Ty_app (** Application type. *) - | Ty_bitv of int (** Bitvector type with a specified bit width. *) - | Ty_bool (** Boolean type. *) - | Ty_fp of int (** Floating-point type with a specified bit width. *) - | Ty_int (** Integer type. *) - | Ty_list (** List type. *) - | Ty_none (** None type. *) - | Ty_real (** Real number type. *) - | Ty_str (** String type. *) - | Ty_unit (** Unit type. *) - | Ty_regexp (** Regular expression type. *) - | Ty_roundingMode +type _ ty = + | Ty_app : [> `Ty_app ] ty (** Application type. *) + | Ty_bitv : int -> [> `Ty_bitv ] ty + (** Bitvector type with a specified bit width. *) + | Ty_bool : [> `Ty_bool ] ty (** Boolean type. *) + | Ty_fp : int -> [> `Ty_fp ] ty + (** Floating-point type with a specified bit width. *) + | Ty_int : [> `Ty_int ] ty (** Integer type. *) + | Ty_list : [> `Ty_list ] ty (** List type. *) + | Ty_none : [> `Ty_none ] ty (** None type. *) + | Ty_real : [> `Ty_real ] ty (** Real number type. *) + | Ty_str : [> `Ty_str ] ty (** String type. *) + | Ty_unit : [> `Ty_unit ] ty (** Unit type. *) + | Ty_regexp : [> `Ty_regexp ] ty (** Regular expression type. *) + | Ty_roundingMode : [> `Ty_roundingMode ] ty + +type t = Ty : 'a ty -> t (** {1 Type Comparison} *) @@ -62,38 +66,45 @@ val size : t -> int module Unop : sig (** The type [t] represents unary operations. *) - type t = - | Neg (** Negation. *) - | Not (** Logical NOT. *) - | Clz (** Count leading zeros. *) - | Ctz (** Count trailing zeros. *) - | Popcnt (** Count bits set to 1. *) + type t = U : 'a op -> t + + and _ op = + | Neg : [< `Ty_int | `Ty_real | `Ty_bitv | `Ty_fp | `Ty_none ] op + (** Negation. *) + | Not : [< `Ty_bool | `Ty_int | `Ty_bitv | `Ty_none ] op + (** Logical NOT. *) + | Clz : [ `Ty_bitv ] op (** Count leading zeros. *) + | Ctz : [ `Ty_bitv ] op (** Count trailing zeros. *) + | Popcnt : [ `Ty_bitv ] op (** Count bits set to 1. *) (* Float operations *) - | Abs (** Absolute value. *) - | Sqrt (** Square root. *) - | Is_normal - | Is_subnormal - | Is_negative - | Is_positive - | Is_infinite - | Is_nan (** Check if NaN. *) - | Is_zero - | Ceil (** Ceiling. *) - | Floor (** Floor. *) - | Trunc (** Truncate. *) - | Nearest (** Round to nearest integer. *) - | Head (** Get the head of a list. *) - | Tail (** Get the tail of a list. *) - | Reverse (** Reverse a list. *) - | Length (** Get the length of a list. *) + | Abs : [< `Ty_int | `Ty_real | `Ty_fp | `Ty_none ] op + (** Absolute value. *) + | Sqrt : [< `Ty_real | `Ty_fp | `Ty_none ] op (** Square root. *) + | Is_normal : [< `Ty_fp | `Ty_none ] op + | Is_subnormal : [< `Ty_fp | `Ty_none ] op + | Is_negative : [< `Ty_fp | `Ty_none ] op + | Is_positive : [< `Ty_fp | `Ty_none ] op + | Is_infinite : [< `Ty_fp | `Ty_none ] op + | Is_nan : [< `Ty_real | `Ty_fp | `Ty_none ] op (** Check if NaN. *) + | Is_zero : [< `Ty_fp | `Ty_none ] op + | Ceil : [< `Ty_real | `Ty_fp | `Ty_none ] op (** Ceiling. *) + | Floor : [< `Ty_real | `Ty_fp | `Ty_none ] op (** Floor. *) + | Trunc : [< `Ty_real | `Ty_fp | `Ty_none ] op (** Truncate. *) + | Nearest : [< `Ty_real | `Ty_fp | `Ty_none ] op + (** Round to nearest integer. *) + | Head : [ `Ty_list ] op (** Get the head of a list. *) + | Tail : [ `Ty_list ] op (** Get the tail of a list. *) + | Reverse : [ `Ty_list ] op (** Reverse a list. *) + | Length : [< `Ty_list | `Ty_str ] op (** Get the length of a list. *) (* String operations *) - | Trim (** Trim whitespace (uninterpreted). *) + | Trim : [ `Ty_str ] op (** Trim whitespace (uninterpreted). *) (* Regexp operations *) - | Regexp_star (** Kleene star. *) - | Regexp_loop of (int * int) (** Loop with a range. *) - | Regexp_plus (** Kleene plus. *) - | Regexp_opt (** Optional. *) - | Regexp_comp (** Complement. *) + | Regexp_star : [< `Ty_str | `Ty_regexp ] op (** Kleene star. *) + | Regexp_loop : (int * int) -> [< `Ty_str | `Ty_regexp ] op + (** Loop with a range. *) + | Regexp_plus : [< `Ty_str | `Ty_regexp ] op (** Kleene plus. *) + | Regexp_opt : [< `Ty_str | `Ty_regexp ] op (** Optional. *) + | Regexp_comp : [< `Ty_str | `Ty_regexp ] op (** Complement. *) (** [equal op1 op2] checks if unary operations [op1] and [op2] are equal. *) val equal : t -> t -> bool diff --git a/src/smtml/utils_parse.ml b/src/smtml/utils_parse.ml new file mode 100644 index 00000000..bffa018a --- /dev/null +++ b/src/smtml/utils_parse.ml @@ -0,0 +1 @@ +type unop = U : 'a Ty.ty * 'a Ty.Unop.op -> unop diff --git a/src/smtml/value.ml b/src/smtml/value.ml index d538b708..5ba52544 100644 --- a/src/smtml/value.ml +++ b/src/smtml/value.ml @@ -19,16 +19,16 @@ type t = let type_of (v : t) : Ty.t = match v with - | True | False -> Ty_bool - | Unit -> Ty_unit - | Int _ -> Ty_int - | Real _ -> Ty_real - | Str _ -> Ty_str + | True | False -> Ty Ty_bool + | Unit -> Ty Ty_unit + | Int _ -> Ty Ty_int + | Real _ -> Ty Ty_real + | Str _ -> Ty Ty_str | Num n -> Num.type_of n - | Bitv bv -> Ty_bitv (Bitvector.numbits bv) - | List _ -> Ty_list - | App _ -> Ty_app - | Nothing -> Ty_none + | Bitv bv -> Ty (Ty_bitv (Bitvector.numbits bv)) + | List _ -> Ty Ty_list + | App _ -> Ty Ty_app + | Nothing -> Ty Ty_none let discr = function | True -> 0 @@ -100,12 +100,12 @@ let rec pp fmt = function let to_string (v : t) : string = Fmt.str "%a" pp v -let of_string (cast : Ty.t) v = +let of_string (Ty cast as ty : Ty.t) v = let open Result in match cast with | Ty_bitv m -> Ok (Bitv (Bitvector.make (Z.of_string v) m)) | Ty_fp _ -> - let+ n = Num.of_string cast v in + let+ n = Num.of_string ty v in Num n | Ty_bool -> ( match v with @@ -122,7 +122,7 @@ let of_string (cast : Ty.t) v = | Some n -> Ok (Real n) ) | Ty_str -> Ok (Str v) | Ty_app | Ty_list | Ty_none | Ty_unit | Ty_regexp | Ty_roundingMode -> - Fmt.error_msg "unsupported parsing values of type %a" Ty.pp cast + Fmt.error_msg "unsupported parsing values of type %a" Ty.pp ty let rec to_json (v : t) : Yojson.Basic.t = match v with diff --git a/src/smtml/z3_mappings.default.ml b/src/smtml/z3_mappings.default.ml index b770d6d4..9bfb2f72 100644 --- a/src/smtml/z3_mappings.default.ml +++ b/src/smtml/z3_mappings.default.ml @@ -89,15 +89,15 @@ module M = struct let to_ety sort = match Z3.Sort.get_sort_kind sort with - | Z3enums.INT_SORT -> Ty.Ty_int - | REAL_SORT -> Ty.Ty_real - | BOOL_SORT -> Ty.Ty_bool - | SEQ_SORT -> Ty.Ty_str - | BV_SORT -> Ty.Ty_bitv (Z3.BitVector.get_size sort) + | Z3enums.INT_SORT -> Ty.Ty Ty_int + | REAL_SORT -> Ty.Ty Ty_real + | BOOL_SORT -> Ty.Ty Ty_bool + | SEQ_SORT -> Ty.Ty Ty_str + | BV_SORT -> Ty.Ty (Ty_bitv (Z3.BitVector.get_size sort)) | FLOATING_POINT_SORT -> let ebits = Z3.FloatingPoint.get_ebits ctx sort in let sbits = Z3.FloatingPoint.get_sbits ctx sort in - Ty.Ty_fp (ebits + sbits) + Ty.Ty (Ty_fp (ebits + sbits)) | _ -> assert false end diff --git a/test/integration/test_solver.ml b/test/integration/test_solver.ml index 48aa1915..01c4ef66 100644 --- a/test/integration/test_solver.ml +++ b/test/integration/test_solver.ml @@ -71,7 +71,7 @@ module Make (M : Mappings_intf.S_with_fresh) = struct let open Infix in let module Solver = (val solver_module : Solver_intf.S) in let solver = Solver.create ~logic:LIA () in - let symbol_x = Symbol.("x" @: Ty_int) in + let symbol_x = Symbol.("x" @: Ty Ty_int) in let x = Expr.symbol symbol_x in assert_sat ~f:"test" (Solver.check solver []); @@ -121,8 +121,8 @@ module Make (M : Mappings_intf.S_with_fresh) = struct let module Solver = (val solver_module : Solver_intf.S) in let solver = Solver.create () in assert_sat ~f:"test_lra" - (let x = Expr.symbol Symbol.("x" @: Ty_real) in - let y = Expr.symbol Symbol.("y" @: Ty_real) in + (let x = Expr.symbol Symbol.("x" @: Ty Ty_real) in + let y = Expr.symbol Symbol.("y" @: Ty Ty_real) in let c0 = Expr.relop Ty_bool Eq x y in let c1 = Expr.relop Ty_bool Eq diff --git a/test/test_harness.ml b/test/test_harness.ml index 673f01de..3783c012 100644 --- a/test/test_harness.ml +++ b/test/test_harness.ml @@ -45,7 +45,8 @@ module Infix = struct let app x = value (App (x, [])) - let symbol name ty = symbol (Symbol.make ty name) + let symbol : type a. string -> a Ty.ty -> Expr.t = + fun name ty -> symbol (Symbol.make (Ty ty) name) let ( = ) i1 i2 = relop Ty_bool Eq i1 i2 diff --git a/test/unit/test_eval.ml b/test/unit/test_eval.ml index 871e1c23..30bc20d4 100644 --- a/test/unit/test_eval.ml +++ b/test/unit/test_eval.ml @@ -507,7 +507,7 @@ module Str_test = struct end module Float_test (FXX : sig - val ty : Ty.t + val ty : [ `Ty_fp ] Ty.ty val v : float -> Value.t end) = diff --git a/test/unit/test_expr.ml b/test/unit/test_expr.ml index 618b163a..ff4c01e6 100644 --- a/test/unit/test_expr.ml +++ b/test/unit/test_expr.ml @@ -65,7 +65,7 @@ let test_unop_string _ = let test_unop_bool _ = let ty = Ty.Ty_bool in check (Expr.unop ty Not Expr.Bool.true_) Expr.Bool.false_; - let x = Expr.symbol (Symbol.make ty "x") in + let x = Expr.symbol (Symbol.make (Ty ty) "x") in check (Expr.unop ty Not (Expr.unop ty Not x)) x let test_unop_list _ = @@ -464,11 +464,11 @@ let test_cvtop_i32 _ = let open Infix in check (Expr.cvtop (Ty_bitv 32) TruncSF32 (float32 8.5)) (int32 8l); check (Expr.cvtop (Ty_bitv 32) TruncSF64 (float64 8.5)) (int32 8l); - let x = Expr.symbol (Symbol.make (Ty_bitv 32) "x") in + let x = Expr.symbol (Symbol.make (Ty (Ty_bitv 32)) "x") in let x = Expr.extract x ~high:2 ~low:0 in - assert (Ty.equal (Expr.ty x) (Ty_bitv 16)); + assert (Ty.equal (Expr.ty x) (Ty (Ty_bitv 16))); let x = Expr.cvtop (Ty_bitv 32) (Sign_extend 16) x in - assert (Ty.equal (Expr.ty x) (Ty_bitv 32)) + assert (Ty.equal (Expr.ty x) (Ty (Ty_bitv 32))) let test_cvtop_i64 _ = let open Infix in diff --git a/test/unit/test_model.ml b/test/unit/test_model.ml index 41d99b6a..18b453fb 100644 --- a/test/unit/test_model.ml +++ b/test/unit/test_model.ml @@ -8,16 +8,16 @@ let assert_equal = assert_equal ~cmp:String.equal ~pp_diff let test_to_json _ = - let x = Symbol.make Ty_int "x" in - let y = Symbol.make Ty_real "y" in - let z = Symbol.make Ty_bool "z" in - let u = Symbol.make Ty_str "u" in + let x = Symbol.make (Ty Ty_int) "x" in + let y = Symbol.make (Ty Ty_real) "y" in + let z = Symbol.make (Ty Ty_bool) "z" in + let u = Symbol.make (Ty Ty_str) "u" in let expected = {|{ "model": { "x": { "ty": "int", "value": 1 }, - "u": { "ty": "str", "value": "abc" }, "y": { "ty": "real", "value": 2.0 }, + "u": { "ty": "str", "value": "abc" }, "z": { "ty": "bool", "value": true } } }|} @@ -51,9 +51,9 @@ let test_of_json _ = assert_bool "cannot parse model" (match model with Ok _ -> true | _ -> false) let test_rt_json _ = - let x = Symbol.make (Ty_bitv 32) "x" in - let y = Symbol.make (Ty_bitv 64) "y" in - let z = Symbol.make (Ty_fp 32) "y" in + let x = Symbol.make (Ty (Ty_bitv 32)) "x" in + let y = Symbol.make (Ty (Ty_bitv 64)) "y" in + let z = Symbol.make (Ty (Ty_fp 32)) "y" in let orig_model : Model.t = let tbl = Hashtbl.create 16 in List.iter