From 2c16ed7f39f362d0ab9e670d367e223bb1471fde Mon Sep 17 00:00:00 2001 From: Arthur Carcano Date: Mon, 7 Jul 2025 17:59:26 +0200 Subject: [PATCH] Change implementation of clz and ctz --- src/smtml/mappings.ml | 53 ++++++++++++++++++------------------- test/unit/test_bitvector.ml | 8 +++--- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/src/smtml/mappings.ml b/src/smtml/mappings.ml index c35958ba..26bef59d 100644 --- a/src/smtml/mappings.ml +++ b/src/smtml/mappings.ml @@ -303,8 +303,6 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct module Ixx : sig val of_int : int -> elt - - val shift_left : elt -> int -> elt end end @@ -312,31 +310,35 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct open M include B - (* Stolen from @krtab in OCamlPro/owi#195 *) - let clz n = - let rec loop (lb : int) (ub : int) = - if ub = lb + 1 then v @@ Ixx.of_int (bitwidth - ub) + let clo n = + let rec loop (next : int) expr = + if Prelude.Int.equal next 0 then expr else - let mid = (lb + ub) / 2 in - let pow_two_mid = v Ixx.(shift_left (of_int 1) mid) in - ite (Bitv.lt_u n pow_two_mid) (loop lb mid) (loop mid ub) + let shift = next in + let shifted = Bitv.lshr n (v @@ Ixx.of_int shift) in + let bit = Bitv.rem_u shifted (v @@ Ixx.of_int 2) in + let expr = Bitv.add bit (Bitv.mul bit expr) in + let next = pred next in + loop next expr in - ite - (eq n (v @@ Ixx.of_int 0)) - (v @@ Ixx.of_int bitwidth) - (loop 0 bitwidth) - - (* Stolen from @krtab in OCamlPro/owi #195 *) - let ctz n = - let zero = v (Ixx.of_int 0) in - let rec loop (lb : int) (ub : int) = - if ub = lb + 1 then v (Ixx.of_int lb) + loop bitwidth (v @@ Ixx.of_int 1) + + let clz n = clo @@ Bitv.lognot n + + let cto n = + let rec loop (next : int) expr = + if Prelude.Int.equal next bitwidth then expr else - let mid = (lb + ub) / 2 in - let pow_two_mid = v Ixx.(shift_left (of_int 1) mid) in - M.ite (eq (Bitv.rem n pow_two_mid) zero) (loop mid ub) (loop lb mid) + let shift = bitwidth - next - 1 in + let shifted = Bitv.lshr n (v @@ Ixx.of_int shift) in + let bit = Bitv.rem_u shifted (v @@ Ixx.of_int 2) in + let expr = Bitv.add bit (Bitv.mul bit expr) in + let next = succ next in + loop next expr in - ite (eq n zero) (v @@ Ixx.of_int bitwidth) (loop 0 bitwidth) + loop 0 (v @@ Ixx.of_int 1) + + let ctz n = cto @@ Bitv.lognot n let popcnt n = let rec loop (next : int) count = @@ -422,8 +424,6 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct module Ixx = struct let of_int i = i [@@inline] - - let shift_left v i = v lsl i [@@inline] end end) @@ -810,8 +810,7 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct else ( assert (Z.equal b Z.zero); Value.False ) - | Ty_bitv m -> - Value.Bitv (Bitvector.make (M.Interp.to_bitv v m) m) + | Ty_bitv m -> Value.Bitv (Bitvector.make (M.Interp.to_bitv v m) m) | Ty_fp 32 -> let float = M.Interp.to_float v 8 24 in Value.Num (F32 (Int32.bits_of_float float)) diff --git a/test/unit/test_bitvector.ml b/test/unit/test_bitvector.ml index 426cd542..4ac51428 100644 --- a/test/unit/test_bitvector.ml +++ b/test/unit/test_bitvector.ml @@ -33,12 +33,12 @@ let test_neg _ = check (neg bv) (make (z (-5)) 8) let test_clz _ = - let bv = make (z 1) 8 in - check (clz bv) (make (z 7) 8) + let bv = make (z 2) 8 in + check (clz bv) (make (z 6) 8) let test_ctz _ = - let bv = make (z 128) 8 in - check (ctz bv) (make (z 7) 8) + let bv = make (z 64) 8 in + check (ctz bv) (make (z 6) 8) let test_popcnt _ = let bv = make (z 0b1010_1010) 8 in