Skip to content

Commit 316cb82

Browse files
committed
Stream: add 'fastexp' example
1 parent 7a92776 commit 316cb82

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
0
2+
0b0
3+
0b0
4+
1
5+
0b1
6+
0b1
7+
2
8+
0b01
9+
0b10
10+
3
11+
0b11
12+
0b11
13+
4
14+
0b001
15+
0b100
16+
123456789
17+
0b101010001011001111011010111
18+
0b111010110111100110100010101
19+
20+
2 ^ 0
21+
1
22+
1
23+
2 ^ 5
24+
32
25+
32
26+
2 ^ 20
27+
1048576
28+
1048576
29+
10 ^ 10
30+
10000000000
31+
10000000000
32+
0 ^ 1
33+
0
34+
0
35+
1 ^ 0
36+
1
37+
1
38+
0 ^ 0
39+
1
40+
1
41+
7 ^ 22
42+
3909821048582988300
43+
3909821048582988300
44+
22 ^ 7
45+
2494357888
46+
2494357888
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import stream
2+
3+
/// Least significant bits
4+
def eachLSB(n: Int): Unit / emit[Bool] = {
5+
if (n <= 0) {
6+
do emit(false)
7+
} else {
8+
var tmp = n
9+
while (tmp > 0) {
10+
do emit(if (tmp.mod(2) == 1) true else false)
11+
tmp = tmp / 2
12+
}
13+
}
14+
}
15+
16+
/// Most significant bits
17+
def eachMSB(n: Int): Unit / emit[Bool] = {
18+
if (n <= 0) {
19+
do emit(false)
20+
} else {
21+
// First, calculate the highest bit position
22+
var highest = 0
23+
var m = n
24+
while (m > 0) {
25+
highest = highest + 1
26+
m = m / 2
27+
}
28+
29+
// Then emit from MSB to LSB
30+
for[Int] { range(0, highest) } { i =>
31+
val index = highest - i
32+
val bit = n.bitwiseShr(index - 1).mod(2) == 1
33+
do emit(bit)
34+
}
35+
}
36+
}
37+
38+
def fastexp(n: Int, k: Int) = product {
39+
stream::zip[Int, Bool] {n.iterate { x => x * x }} {k.eachLSB} {
40+
case res, true => do emit(res)
41+
case res, false => ()
42+
}
43+
}
44+
45+
def main() = {
46+
def prettyBits(bits: List[Bool]): String =
47+
"0b" ++ bits.map { b => if (b) "1" else "0"}.join("")
48+
49+
def testBits(n: Int) = {
50+
println(n)
51+
println(collectList[Bool] {n.eachLSB}.prettyBits)
52+
println(collectList[Bool] {n.eachMSB}.prettyBits)
53+
}
54+
55+
testBits(0)
56+
testBits(1)
57+
testBits(2)
58+
testBits(3)
59+
testBits(4)
60+
testBits(123456789)
61+
62+
println("")
63+
64+
def testExp(n: Int, k: Int) = {
65+
println(show(n) ++ " ^ " ++ show(k))
66+
println(n.toDouble.pow(k))
67+
println(n.fastexp(k))
68+
}
69+
70+
testExp(2, 0)
71+
testExp(2, 5)
72+
testExp(2, 20)
73+
testExp(10, 10)
74+
testExp(0, 1)
75+
testExp(1, 0)
76+
testExp(0, 0)
77+
testExp(7, 22)
78+
testExp(22, 7)
79+
}

0 commit comments

Comments
 (0)