-- examples/arithmetic-coding.dx
'# Lossless compression
Based on the implementation of [rANS](https://github.com/j-towns/ans-notes/blob/master/rans.py) by James Townsend.

-- prelude additions

-- Converts a 64-bit word to an 8-bit word via `internalCast`.
-- NOTE(review): presumably this keeps only the low 8 bits of `x`, as the name
-- suggests — confirm against the definition of `internalCast`.
def lowWord' (x : Word64) : Word8 = internalCast _ x
-- Example usage:
-- W8ToI (lowWord' (IToW64 100))

-- Integer division and remainder for 64-bit words, forwarded directly to the
-- `%idiv` and `%irem` primitives.
instance Integral Word64
  idiv = \a b. %idiv a b
  rem = \a b. %irem a b

-- Multiplication for 64-bit words, forwarded to the `%imul` primitive, with
-- the multiplicative identity built from the Int literal 1.
instance Mul Word64
  mul = \a b. %imul a b
  one = IToW64 1

-- Equality for 64-bit words: the `%ieq` primitive yields a Word8 flag, which
-- `W8ToB` turns into a Bool.
instance Eq Word64
  (==) = \a b. W8ToB (%ieq a b)

-- Ordering for 64-bit words: the `%igt` / `%ilt` primitives yield Word8
-- flags, which `W8ToB` turns into Bools.
-- NOTE(review): confirm these primitives compare Word64 as unsigned; a signed
-- comparison would misorder values with the top bit set.
instance Ord Word64
  (>) = \a b. W8ToB (%igt a b)
  (<) = \a b. W8ToB (%ilt a b)

-- Bit widths used by the coder:
--   p_prec: the low p_prec bits of the state are masked off with p_mask in `pop`,
--           and 1 << p_prec is the size of the `Interval` index set.
--   s_prec: bit width used for the state bounds below.
--   t_prec: number of bits moved between the state and the stack per step
--           (see the shifts by t_prec in `push` and `pop`).
p_prec = 3
s_prec = 64
t_prec = 32
-- 1 << p_prec, i.e. 8 with the value above.
p_int: Int = %shl 1 p_prec
-- Mask selecting the low p_prec bits of a word.
p_mask: Word64 = (one .<<. p_prec) - one
-- Mask selecting the low t_prec bits of a word.
t_mask: Word64 = (one .<<. t_prec) - one
-- Lower bound of the state: `pop` refills the state from the stack while it is below this.
s_min: Word64 = one .<<. (s_prec - t_prec)
-- NOTE(review): with s_prec = 64 this shifts a 64-bit word by its full width, so
-- the result is unlikely to be the intended 2^64 — confirm. Not referenced
-- elsewhere in this file.
s_max: Word64 = one .<<. s_prec
-- Index set of symbols; `charToIdx` maps characters to ordinals relative to 'a'.
Alphabet = Fin 26
-- Index set with p_int slots, used as the domain of the slot-to-symbol table.
Interval = Fin p_int
-- A message is a head state word paired with a stack (list) of words.
Message = (Word64 & List Word64)

'Utilities

-- Remainder of `x` with respect to `y`, computed as `rem (y + rem x y) y`.
def mod' (x: Word64) (y: Word64) : Word64 =
  r = rem x y
  rem (y + r) y

-- Ordinal of character `c` relative to 'a' (so 'a' maps to 0).
def charToIdx (c: Word8) : Int =
  base = W8ToI 'a'
  W8ToI c - base
-- Character whose ordinal relative to 'a' is `i` (so 0 maps to 'a').
def idxToChar (i: Int) : Word8 =
  base = W8ToI 'a'
  IToW8 (i + base)

-- Cumulative start offsets for the symbol counts `ps`.
-- Walks the alphabet in order keeping a running total: a symbol with a
-- non-zero count gets the total accumulated before it, and the total then
-- grows by its count. Symbols with a zero count get zero.
def get_cs (ps: Alphabet=>Word64) : Alphabet=>Word64 =
  withState zero \acc.
    for sym.
      n = ps.sym
      if n > zero
        then
          start = get acc
          acc := start + n
          start
        else zero

-- Counts how many times each symbol occurs in `str`.
-- Each character is mapped to its alphabet index with `charToIdx`, and the
-- matching entry of a zero-initialised count table is incremented by one.
-- The result is the final count table.
def get_ps (str: (Fin l)=>Word8) : Alphabet=>Word64 =
  a: Alphabet => Word64 = zero
  yieldState a \ref. for i.
    -- Review note: this one can also be computed with a parallel Accum.
    i' = (charToIdx str.i)@_
    ref!i' := (get ref).i' + one

-- Builds the table from interval slot to symbol character.
-- For every symbol, in alphabet order, its character is appended to a list
-- once per unit of its count (the count is narrowed through `lowWord'`).
-- The list is then read back slot by slot over `Interval`.
-- NOTE(review): the final indexing uses `unsafeFromOrdinal`, so this presumably
-- relies on the counts adding up to at least p_int entries — confirm.
def get_cs_map (ps: Alphabet=>Word64) : Interval=>Word8 =
  empty: List Word8 = (AsList 0 [])
  filled = yieldState empty \acc.
    for sym.
      reps = W8ToI $ lowWord' ps.sym
      boundedIter reps 0 \_.
        acc := (get acc) <> (AsList 1 [idxToChar (ordinal sym)])
        Continue
  (AsList _ slots) = filled
  for j:Interval. slots.(unsafeFromOrdinal _ (ordinal j))

-- a string to prep the statistics
-- Its 8 characters match p_int (1 << p_prec = 8), so the symbol counts in `ps`
-- fill the `Interval`-indexed table built by get_cs_map exactly.
xs' = "abbccddc"
(AsList l xs) = xs'
-- Per-symbol occurrence counts of the prep string.
ps = get_ps xs
-- Per-symbol cumulative start offsets derived from the counts.
cs = get_cs ps
-- Lookup table from interval slot to symbol character.
cs_map = get_cs_map ps

-- Looks up the (cumulative start, count) pair of character `x` in the
-- global `cs` and `ps` tables.
def g (x: Word8) : (Word64 & Word64) =
  sym = (charToIdx x)@_
  (cs.sym, ps.sym)

-- Maps a slot value `s'` to the symbol stored at that slot of `cs_map`,
-- paired with that symbol's (cumulative start, count) from `g`.
def f (s': Word64) : (Word8 & (Word64 & Word64)) =
  slot = W8ToI (lowWord' s')
  sym = cs_map.(slot@_)
  (sym, g sym)

-- Splits a list into its first element and the list of remaining elements.
-- NOTE(review): an empty input is not handled here; callers are presumably
-- expected to pass a non-empty list — confirm.
def stack_pop ((AsList l' t'): List Word64) : (Word64 & List Word64) =
  top = t'.(0@_)
  restLen = l' - 1
  rest = slice t' 1 (Fin restLen)
  (top, (AsList _ rest))

-- Returns the list `t` with `t_top` placed in front of it.
def stack_push (t_top: Word64) (t: List Word64) : (List Word64) =
  front: List Word64 = AsList 1 [t_top]
  front <> t

'Coding Interface

-- Decodes one symbol from the message (s, t).
-- The low p_prec bits of the state select the symbol via `f`; the state is
-- then updated from that symbol's count and cumulative start. If the new
-- state is below s_min, words are taken off the stack and shifted into the
-- state (t_prec bits at a time) until it is no longer below s_min.
-- Returns the remaining message together with the decoded symbol.
def pop ((s, t): Message) : (Message & Word8) =
  low = s .&. p_mask
  (sym, (start, freq)) = f low
  decoded = freq * (s .>>. p_prec) + low - start
  -- TODO: use a while loop, not a do-while
  rest = case decoded < s_min of
    False -> (decoded, t)
    True ->
      yieldState (decoded, t) \msg.
        stateRef = fstRef msg
        stackRef = sndRef msg
        while do
          (word, remaining) = stack_pop (get stackRef)
          stackRef := remaining
          stateRef := ((get stateRef) .<<. t_prec) + word
          -- keep refilling while the state is still below the lower bound
          (get stateRef) < s_min
  (rest, sym)

-- Encodes symbol `x` onto the message (s, t).
-- While the state is at or above `p << (s_prec - p_prec)` for the symbol's
-- count p, its low t_prec bits are moved onto the stack and the state is
-- shifted right by t_prec. The symbol is then folded into the state using
-- its count and cumulative start from `g`.
def push ((s, t): Message) (x: Word8) : Message =
  (start, freq) = g x
  -- the threshold depends only on the symbol, so name it once
  limit = freq .<<. (s_prec - p_prec)
  (state, words) = case s >= limit of
    False -> (s, t)
    True ->
      yieldState (s, t) \msg.
        stateRef = fstRef msg
        stackRef = sndRef msg
        while do
          stackRef := stack_push ((get stateRef) .&. t_mask) (get stackRef)
          stateRef := (get stateRef) .>>. t_prec
          get stateRef >= limit
  encoded = ((idiv state freq) .<<. p_prec) + (mod' state freq) + start
  (encoded, words)


'Demo

-- initialize message: state starts at s_min with an empty stack
m_init: Message = (s_min, AsList 0 [])
xs_init' = "abbcbcdccdccacbbacccdabbbaccccd"
(AsList l' xs_init) = xs_init'

-- Encode: push every input symbol onto the message, in order.
m' = yieldState m_init \msgRef.
  for i:(Fin l').
    msgRef := push (get msgRef) xs_init.i

-- Decode: pop l' symbols, placing each one in front of the symbols decoded so far.
init_args: (Message & List Word8) = (m', AsList 0 [])
(_, xs_decoded) = yieldState init_args \ref.
  msgRef = fstRef ref
  outRef = sndRef ref
  for i:(Fin l').
    (rest, sym) = pop (get msgRef)
    msgRef := rest
    outRef := (AsList 1 [sym]) <> (get outRef)
  get ref

-- Round-trip check: the decoded symbols should equal the original input.
(AsList _ xs'') = xs_decoded
:p xs'' == xs_init