Skip to content

Commit 812bce1

Browse files
[otbnsim] Cycle-accurate Python model of Trivium/Bivium
This commit introduces a cycle-accurate Python implementation of the Trivium primitive (see `prim_trivium.sv`) for the eventual replacement of the OTBN PRNG. Signed-off-by: Andrea Caforio <andrea.caforio@lowrisc.org>
1 parent 697c06a commit 812bce1

File tree

1 file changed

+378
-0
lines changed

1 file changed

+378
-0
lines changed
Lines changed: 378 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,378 @@
1+
# Copyright lowRISC contributors (OpenTitan project).
2+
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import math
6+
from enum import IntEnum
7+
8+
9+
class SeedType(IntEnum):
10+
# This is the regular (standardized for Trivium) method for seeding the
11+
# cipher whereby an 80-bit key and 80-bit IV are injected into the state
12+
# before the initialization rounds are executed.
13+
KEY_IV = 0
14+
# The entire state is seeded once with a seed of the same length. No
15+
# initialization rounds are executed.
16+
STATE_FULL = 1
17+
# Every seed operation fills a chunk of predefined size of the state
18+
# starting with the least significant region until every bit of the state
19+
# has been seeded. The seed operations can be interspersed with update
20+
# invocations such that keystream and seeding can take place concurrently.
21+
STATE_PARTIAL = 2
22+
23+
24+
class CipherType(IntEnum):
25+
# Both Trivium and its simpler variant Bivium can be instantiated. Note
26+
# only Trivium is a standardized cipher while Bivium serves a vehicule to
27+
# study the cryptanalytic properties of this family of ciphers. Both
28+
# primitives can be used to instantiate a PRNG.
29+
TRIVIUM = 0
30+
BIVIUM = 1
31+
32+
33+
def i2b(i: int, n: int) -> list[int]:
34+
"""Convert a little endian integer to a bit array with the LSB at idx 0.
35+
The resulting bit array is padded with 0s if log2(i) < `n` until its size
36+
is `n` bits."""
37+
return [int(b) for b in bin(i)[2:].zfill(n)][::-1]
38+
39+
40+
def b2i(b: list[int]) -> int:
41+
"""Convert a bit array with the LSB at idx 0 to a little endian integer."""
42+
return int("".join(str(d) for d in b[::-1]), 2)
43+
44+
45+
class Trivium:
46+
"""This is a cycle-accurate model of the OpenTitan Trivium primitive.
47+
48+
Instantiating this class corresponds to the cipher state after the reset.
49+
Subsequently, two operations can be scheduled in a clock interval, i.e.,
50+
between calls to `clock`.
51+
52+
- `seed`: Pass a seed to the cipher that will appear in the state after
53+
the next `clock` call. Depending on the seed type different update
54+
sequences are necessary to complete the initialization routines.
55+
56+
- KEY_IV: The entire state needs to be updated 4 times over which
57+
means ceil(4 * `STATE_SIZE` / `OUTPUT_WIDTH`) calls to `update`
58+
and `clock`.
59+
- STATE_FULL: Having called `clock` after `seed` immediately readies
60+
the cipher for the generation of keystream bits.
61+
- STATE_PARTIAL: ceil(`STATE_SIZE` / `PART_SEED_SIZE`) clock
62+
intervals with `seed` are required before the cipher is ready.
63+
There can intervals without `seed` calls. This models stall in
64+
the generation of seed bits from entropy complex.
65+
66+
- `update`: Run the state update function and generate an update state
67+
that replaces the current state at the end of the clock interval,
68+
i.e., after the `clock` call.
69+
"""
70+
71+
TRIVIUM_STATE_SIZE = 288
72+
BIVIUM_STATE_SIZE = 177
73+
74+
PART_SEED_SIZE = 32
75+
TRIVIUM_LAST_PART_SEED_SIZE = 32
76+
BIVIUM_LAST_PART_SEED_SIZE = 17
77+
78+
# Initial state after reset (see `prim_trivium_pkg.sv`).
79+
TRIVIUM_INIT_SEED = i2b(
80+
0x758A442031E1C4616EA343EC153282A30C132B5723C5A4CF4743B3C7C32D580F74F1713A, 288
81+
)
82+
BIVIUM_INIT_SEED = TRIVIUM_INIT_SEED[0:BIVIUM_STATE_SIZE]
83+
84+
def __init__(
85+
self,
86+
cipher_type: CipherType,
87+
seed_type: SeedType,
88+
output_width: int,
89+
init_seed: list[int] = [],
90+
):
91+
"""The cipher is defined by its cipher type (see `CipherType`), its
92+
seed type (see `SeedType`), the output (keystream) width and an
93+
optional initialization seed."""
94+
95+
if cipher_type == CipherType.TRIVIUM:
96+
self.state_size = self.TRIVIUM_STATE_SIZE
97+
self.update_func = self.trivium_update
98+
self.last_part_seed_size = self.TRIVIUM_LAST_PART_SEED_SIZE
99+
100+
if init_seed != []:
101+
if len(init_seed) != TRIVIUM_STATE_SIZE:
102+
raise ValueError(
103+
"trivium init seed must be the same size as the state"
104+
)
105+
self.state = init_seed
106+
else:
107+
self.state = self.TRIVIUM_INIT_SEED[:]
108+
109+
elif cipher_type == CipherType.BIVIUM:
110+
self.state_size = self.BIVIUM_STATE_SIZE
111+
self.update_func = self.bivium_update
112+
self.last_part_seed_size = self.BIVIUM_LAST_PART_SEED_SIZE
113+
114+
if init_seed != []:
115+
if len(init_seed) != BIVIUM_STATE_SIZE:
116+
raise ValueError(
117+
"bivium init seed must be the same size as the state"
118+
)
119+
self.state = init_seed
120+
else:
121+
self.state = self.BIVIUM_INIT_SEED[:]
122+
123+
else:
124+
raise ValueError("unknown cipher type:", cipher_type)
125+
126+
# Depending on cipher and seed type, a different number of seed rounds
127+
# have to be run.
128+
if seed_type == SeedType.KEY_IV:
129+
self.seed_rnd = math.ceil(4 * self.state_size / output_width)
130+
self.seed_ctr = 0
131+
elif seed_type == SeedType.STATE_FULL:
132+
self.seed_rnd = 1
133+
self.seed_ctr = 0
134+
elif seed_type == SeedType.STATE_PARTIAL:
135+
self.seed_rnd = math.ceil(self.state_size / self.PART_SEED_SIZE)
136+
self.seed_ctr = 0
137+
else:
138+
raise ValueError("unknown seed type:", seed_type)
139+
140+
self.cipher_type = cipher_type
141+
self.seed_type = seed_type
142+
self.output_width = output_width
143+
144+
# Scheduled state and seed for the current clock interval.
145+
self.next_state = []
146+
self.next_seed = []
147+
148+
self.ks = [0] * output_width
149+
150+
def update(self) -> None:
151+
"""Run the state update function `OUTPUT_WIDTH`-many times
152+
and schedule the new state to replace the current state at
153+
end of the clock interval."""
154+
155+
if self.next_state != []:
156+
raise Exception("cannot update more than once per clock interval")
157+
158+
self.next_state = self.state[:]
159+
for i in range(self.output_width):
160+
self.ks[i] = self.update_func(self.next_state)
161+
162+
def seed(self, seed) -> None:
163+
"""Schedule a new seed that depending on the seed type will
164+
be injected into the state at the end of the clock
165+
interval."""
166+
167+
if self.seed_type == SeedType.KEY_IV:
168+
assert len(seed) == 160
169+
170+
key = seed[0:80]
171+
iv = seed[80:160]
172+
173+
if self.cipher_type == CipherType.TRIVIUM:
174+
self.next_seed = (
175+
(key + [0] * 13) + (iv + [0] * 4) + ([0] * 108 + [1, 1, 1])
176+
)
177+
else:
178+
self.next_seed = (key + [0] * 13) + (iv + [0] * 4)
179+
180+
elif self.seed_type == SeedType.STATE_FULL:
181+
assert len(seed) == self.state_size
182+
self.next_seed = seed
183+
self.seed_ctr = 0
184+
185+
else:
186+
assert len(seed) == self.PART_SEED_SIZE
187+
self.next_seed = seed
188+
189+
if self.seed_done():
190+
self.seed_ctr = 0
191+
192+
def clock(self):
193+
"""Advance the state by one clock cycle. Depending on the
194+
scheduled state and seed this will alter the current state."""
195+
196+
if self.next_state == [] and self.next_seed == []:
197+
# Do nothing when neither an update nor reseed is scheduled.
198+
return
199+
200+
if self.seed_type == SeedType.KEY_IV:
201+
# Seeding takes precedence over updating.
202+
if self.next_seed != []:
203+
self.state = self.next_seed
204+
205+
elif self.next_state != []:
206+
self.state = self.next_state
207+
if not self.seed_done():
208+
self.seed_ctr += 1
209+
210+
elif self.seed_type == SeedType.STATE_FULL:
211+
# Seeding takes precedence over updating.
212+
if self.next_seed != []:
213+
self.state = self.next_seed
214+
self.seed_ctr += 1
215+
216+
elif self.next_state != []:
217+
self.state = self.next_state
218+
219+
else:
220+
# Update and seeding in the same clock interval is allowed. In this
221+
# case the state is first updated, then partially overwritten with
222+
# the seed bits.
223+
if self.next_state != []:
224+
self.state = self.next_state
225+
if self.next_seed != []:
226+
if self.seed_ctr == self.seed_rnd - 1:
227+
self.state[self.state_size - self.last_part_seed_size :] = (
228+
self.next_seed[: self.last_part_seed_size]
229+
)
230+
else:
231+
self.state[32 * self.seed_ctr : 32 * (self.seed_ctr + 1)] = (
232+
self.next_seed
233+
)
234+
235+
self.seed_ctr += 1
236+
237+
self.next_state = []
238+
self.next_seed = []
239+
240+
def keystream(self):
241+
"""Returns the generated keystream for the current clock
242+
interval."""
243+
return self.ks
244+
245+
def seed_done(self) -> None:
246+
"""Returns true if the seeding procedure has been completed."""
247+
return self.seed_rnd == self.seed_ctr
248+
249+
def trivium_update(self, state):
250+
mul_90_91 = state[90] & state[91]
251+
add_65_92 = state[65] ^ state[92]
252+
253+
mul_174_175 = state[174] & state[175]
254+
add_161_176 = state[161] ^ state[176]
255+
256+
mul_285_286 = state[285] & state[286]
257+
add_242_287 = state[242] ^ state[287]
258+
259+
t0 = state[68] ^ (mul_285_286 ^ add_242_287)
260+
t1 = state[170] ^ (add_65_92 ^ mul_90_91)
261+
t2 = state[263] ^ (mul_174_175 ^ add_161_176)
262+
263+
state[0:93] = [t0] + state[0:92]
264+
state[93:177] = [t1] + state[93:176]
265+
state[177:288] = [t2] + state[177:287]
266+
267+
return add_65_92 ^ add_161_176 ^ add_242_287
268+
269+
def bivium_update(self, state):
270+
mul_90_91 = state[90] & state[91]
271+
add_65_92 = state[65] ^ state[92]
272+
273+
mul_174_175 = state[174] & state[175]
274+
add_161_176 = state[161] ^ state[176]
275+
276+
t0 = state[68] ^ (mul_174_175 ^ add_161_176)
277+
t1 = state[170] ^ (add_65_92 ^ mul_90_91)
278+
279+
state[0:93] = [t0] + state[0:92]
280+
state[93:177] = [t1] + state[93:176]
281+
282+
return add_65_92 ^ add_161_176
283+
284+
285+
# Key-IV seed
286+
287+
# AVR cryptolib: Set 1, vector 0
288+
ref = i2b(
289+
int(
290+
"""
291+
F980FC5474EFE87BB9626ACCCC20FF98
292+
807FCFCE928F6CE0EB21096115F5FBD2
293+
649AF249C24120550175C86414657BBB
294+
0D5420443AF18DAF9C7A0D73FF86EB38""".replace("\n", ""),
295+
16,
296+
),
297+
288,
298+
)
299+
300+
trivium = Trivium(CipherType.TRIVIUM, SeedType.KEY_IV, 64)
301+
302+
key = i2b(0x01000000000000000000, 80)
303+
iv = [0] * 80
304+
305+
trivium.seed(key + iv)
306+
trivium.clock()
307+
308+
while not trivium.seed_done():
309+
trivium.update()
310+
trivium.clock()
311+
312+
assert trivium.seed_done()
313+
314+
print("%072x" % b2i(trivium.state))
315+
316+
keystream = []
317+
for _ in range(8):
318+
trivium.update()
319+
trivium.clock()
320+
keystream.extend(trivium.keystream())
321+
322+
assert keystream == ref
323+
print("%0x" % (b2i(keystream)))
324+
325+
326+
# Full state seed
327+
328+
# Seed corresponds to the state after the init rounds from the first test.
329+
seed = i2b(
330+
0xC7D7C89BCC06725B3D94718106F2A0656422AF1FA457B81F0D2516A9D565893A64C1E50E, 288
331+
)
332+
333+
trivium = Trivium(CipherType.TRIVIUM, SeedType.STATE_FULL, 64)
334+
trivium.seed(seed)
335+
trivium.clock()
336+
337+
assert trivium.seed_done()
338+
339+
keystream = []
340+
for _ in range(8):
341+
trivium.update()
342+
trivium.clock()
343+
keystream.extend(trivium.keystream())
344+
345+
assert keystream == ref
346+
print("%0x" % (b2i(keystream)))
347+
348+
# Partial state seed
349+
350+
# Seed corresponds to the state after the init rounds from the first test.
351+
seed = [
352+
i2b(0x64C1E50E, Trivium.PART_SEED_SIZE),
353+
i2b(0xD565893A, Trivium.PART_SEED_SIZE),
354+
i2b(0x0D2516A9, Trivium.PART_SEED_SIZE),
355+
i2b(0xA457B81F, Trivium.PART_SEED_SIZE),
356+
i2b(0x6422AF1F, Trivium.PART_SEED_SIZE),
357+
i2b(0x06F2A065, Trivium.PART_SEED_SIZE),
358+
i2b(0x3D947181, Trivium.PART_SEED_SIZE),
359+
i2b(0xCC06725B, Trivium.PART_SEED_SIZE),
360+
i2b(0xC7D7C89B, Trivium.PART_SEED_SIZE),
361+
]
362+
363+
trivium = Trivium(CipherType.TRIVIUM, SeedType.STATE_PARTIAL, 64)
364+
365+
for i in range(9):
366+
trivium.seed(seed[i])
367+
trivium.clock()
368+
369+
assert trivium.seed_done()
370+
371+
keystream = []
372+
for _ in range(8):
373+
trivium.update()
374+
trivium.clock()
375+
keystream.extend(trivium.keystream())
376+
377+
assert keystream == ref
378+
print("%0x" % (b2i(keystream)))

0 commit comments

Comments
 (0)