Skip to content

Commit 971f33f

Browse files
committed
chore: cleanup, add sqrt impl
1 parent d35772b commit 971f33f

File tree

4 files changed

+33
-16
lines changed

4 files changed

+33
-16
lines changed

Fp.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import Fp.Division
55
import Fp.FMA
66
import Fp.Multiplication
77
import Fp.Negation
8-
import Fp.Remainder
8+
-- import Fp.Remainder
99
import Fp.Rounding
1010
import Fp.Sqrt
1111
import Fp.Subtraction

Fp/Remainder.lean

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,6 @@ import Fp.Rounding
33
import Fp.Division
44
import Fp.UnpackedRound
55

6-
def UnpackedFloat.remainder (a b : UnpackedFloat e s) : UnpackedFloat e s :=
7-
let rnd := a.div b
8-
let rne := rnd.round .RNE
9-
sorry
10-
11-
def EUnpackedFloat.remainder
12-
(a b : EUnpackedFloat (exponentWidth e s) (s + 1)) (rm : RoundingMode) :
13-
EUnpackedFloat (exponentWidth e s) (s + 1) :=
14-
if a.isNaN || b.isNaN || a.isInfinite || b.isZero then .mkNaN
15-
else if a.isZero || b.isInfinite then a
16-
else
17-
let uf := UnpackedFloat.remainder a.num b.num
18-
uf.round rm
196

207
@[bv_normalize]
218
def remainderFixed (a b : PackedFloat e s) : PackedFloat e s :=
@@ -60,4 +47,4 @@ def remainderFixed (a b : PackedFloat e s) : PackedFloat e s :=
6047
EFixedPoint.round e s .RTZ result
6148

6249
/-- info: { sign := -, ex := 0x00#5, sig := 0x1#2 } -/
63-
#guard_msgs in #eval remainder (PackedFloat.ofBits 5 2 0b00000011) (PackedFloat.ofBits 5 2 0b00000100)
50+
#guard_msgs in #eval remainderFixed (PackedFloat.ofBits 5 2 0b00000011) (PackedFloat.ofBits 5 2 0b00000100)

Fp/Tests/ExhaustiveEnumerationRat.lean

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import Fp.Basic
55
import Fp.Addition
66
import Fp.Multiplication
77
import Fp.Division
8+
import Fp.Sqrt
89
import Fp.Tests.PackedFloatEnumeration
910

1011
namespace Fp
@@ -172,5 +173,32 @@ def testDiv (ein sin : Nat) : IO (UnpackedRatTestSummary ein sin) := do
172173
results := results.push result
173174
return summarizeUnpackedRatTestResults "div" results
174175

176+
/--
177+
Compute a rational approximation of sqrt(r) using Newton's method.
178+
-/
179+
def ratSqrt (r : Rat) (iters : Nat := 20) : Rat :=
180+
if r ≤ 0 then 0
181+
else Id.run do
182+
let mut x : Rat := r
183+
for _ in [:iters] do
184+
x := (x + r / x) / 2
185+
return x
186+
187+
/--
188+
Produce the results from taking the square root of all packed floats
189+
for the given exponent and significand sizes.
190+
-/
191+
def testSqrt (ein sin : Nat) : IO (UnpackedRatTestSummary ein sin) := do
192+
let mut results : Array (UnpackedRatTestResult ein sin) := #[]
193+
let enum : PackedFloatEnumeration ein sin := PackedFloatEnumeration.mk ein sin
194+
for (pf, r) in enum.enumeration do
195+
if r ≤ 0 then continue
196+
let uf := pf.unpack.num
197+
let produced := UnpackedFloat.sqrt uf
198+
let expected := ratSqrt r
199+
let result := UnpackedRatTestResult.mk (Array.mk [(pf, r)]) produced expected
200+
results := results.push result
201+
return summarizeUnpackedRatTestResults "sqrt" results
202+
175203
end ExhaustiveEnumerationRat
176204
end Fp

Main.lean

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import Fp
2+
import Fp.Remainder
23
import Fp.Tests
34

45
structure OpResult where
@@ -141,7 +142,7 @@ def test_rem (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult :=
141142
{
142143
oper := "rem"
143144
mode := m
144-
result := [a, b, f.h.mp (remainder a' b').toBits'].map toDigits
145+
result := [a, b, f.h.mp (remainderFixed a' b').toBits'].map toDigits
145146
}
146147

147148
def test_binop (f : RoundingMode → BitVec 8 → BitVec 8 → OpResult) : Thunk (List OpResult) :=
@@ -275,6 +276,7 @@ def get_long_operation (args : List String) : IO Unit := do
275276
| ["addRat"] => IO.println (← Fp.ExhaustiveEnumerationRat.testAdd 3 4).toFormat
276277
| ["mulRat"] => IO.println (← Fp.ExhaustiveEnumerationRat.testMul 3 4).toFormat
277278
| ["divRat"] => IO.println (← Fp.ExhaustiveEnumerationRat.testDiv 3 4).toFormat
279+
| ["sqrtRat"] => IO.println (← Fp.ExhaustiveEnumerationRat.testSqrt 3 4).toFormat
278280
| ["roundCircuitAgainstSmtLib"] =>
279281
test_roundCircuitAgainstSmtlib (ein := 3) (sin := 6) (eout := 3) (sout := 4)
280282
test_roundCircuitAgainstSmtlib (ein := 3) (sin := 6) (eout := 3) (sout := 4)

0 commit comments

Comments
 (0)