|
| 1 | +# Copyright lowRISC contributors (OpenTitan project). |
| 2 | +# Licensed under the Apache License, Version 2.0, see LICENSE for details. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +import math |
| 6 | +from enum import IntEnum |
| 7 | + |
| 8 | + |
| 9 | +class SeedType(IntEnum): |
| 10 | + # This is the regular (standardized for Trivium) method for seeding the |
| 11 | + # cipher whereby an 80-bit key and 80-bit IV are injected into the state |
| 12 | + # before the initialization rounds are executed. |
| 13 | + KEY_IV = 0 |
| 14 | + # The entire state is seeded once with a seed of the same length. No |
| 15 | + # initialization rounds are executed. |
| 16 | + STATE_FULL = 1 |
| 17 | + # Every seed operation fills a chunk of predefined size of the state |
| 18 | + # starting with the least significant region until every bit of the state |
| 19 | + # has been seeded. The seed operations can be interspersed with update |
| 20 | + # invocations such that keystream and seeding can take place concurrently. |
| 21 | + STATE_PARTIAL = 2 |
| 22 | + |
| 23 | + |
| 24 | +class CipherType(IntEnum): |
| 25 | + # Both Trivium and its simpler variant Bivium can be instantiated. Note |
| 26 | + # only Trivium is a standardized cipher while Bivium serves a vehicule to |
| 27 | + # study the cryptanalytic properties of this family of ciphers. Both |
| 28 | + # primitives can be used to instantiate a PRNG. |
| 29 | + TRIVIUM = 0 |
| 30 | + BIVIUM = 1 |
| 31 | + |
| 32 | + |
| 33 | +def i2b(i: int, n: int) -> list[int]: |
| 34 | + """Convert a little endian integer to a bit array with the LSB at idx 0. |
| 35 | + The resulting bit array is padded with 0s if log2(i) < `n` until its size |
| 36 | + is `n` bits.""" |
| 37 | + return [int(b) for b in bin(i)[2:].zfill(n)][::-1] |
| 38 | + |
| 39 | + |
| 40 | +def b2i(b: list[int]) -> int: |
| 41 | + """Convert a bit array with the LSB at idx 0 to a little endian integer.""" |
| 42 | + return int("".join(str(d) for d in b[::-1]), 2) |
| 43 | + |
| 44 | + |
| 45 | +class Trivium: |
| 46 | + """This is a cycle-accurate model of the OpenTitan Trivium primitive. |
| 47 | +
|
| 48 | + Instantiating this class corresponds to the cipher state after the reset. |
| 49 | + Subsequently, two operations can be scheduled in a clock interval, i.e., |
| 50 | + between calls to `clock`. |
| 51 | +
|
| 52 | + - `seed`: Pass a seed to the cipher that will appear in the state after |
| 53 | + the next `clock` call. Depending on the seed type different update |
| 54 | + sequences are necessary to complete the initialization routines. |
| 55 | +
|
| 56 | + - KEY_IV: The entire state needs to be updated 4 times over which |
| 57 | + means ceil(4 * `STATE_SIZE` / `OUTPUT_WIDTH`) calls to `update` |
| 58 | + and `clock`. |
| 59 | + - STATE_FULL: Having called `clock` after `seed` immediately readies |
| 60 | + the cipher for the generation of keystream bits. |
| 61 | + - STATE_PARTIAL: ceil(`STATE_SIZE` / `PART_SEED_SIZE`) clock |
| 62 | + intervals with `seed` are required before the cipher is ready. |
| 63 | + There can intervals without `seed` calls. This models stall in |
| 64 | + the generation of seed bits from entropy complex. |
| 65 | +
|
| 66 | + - `update`: Run the state update function and generate an update state |
| 67 | + that replaces the current state at the end of the clock interval, |
| 68 | + i.e., after the `clock` call. |
| 69 | + """ |
| 70 | + |
| 71 | + TRIVIUM_STATE_SIZE = 288 |
| 72 | + BIVIUM_STATE_SIZE = 177 |
| 73 | + |
| 74 | + PART_SEED_SIZE = 32 |
| 75 | + TRIVIUM_LAST_PART_SEED_SIZE = 32 |
| 76 | + BIVIUM_LAST_PART_SEED_SIZE = 17 |
| 77 | + |
| 78 | + # Initial state after reset (see `prim_trivium_pkg.sv`). |
| 79 | + TRIVIUM_INIT_SEED = i2b( |
| 80 | + 0x758A442031E1C4616EA343EC153282A30C132B5723C5A4CF4743B3C7C32D580F74F1713A, 288 |
| 81 | + ) |
| 82 | + BIVIUM_INIT_SEED = TRIVIUM_INIT_SEED[0:BIVIUM_STATE_SIZE] |
| 83 | + |
| 84 | + def __init__( |
| 85 | + self, |
| 86 | + cipher_type: CipherType, |
| 87 | + seed_type: SeedType, |
| 88 | + output_width: int, |
| 89 | + init_seed: list[int] = [], |
| 90 | + ): |
| 91 | + """The cipher is defined by its cipher type (see `CipherType`), its |
| 92 | + seed type (see `SeedType`), the output (keystream) width and an |
| 93 | + optional initialization seed.""" |
| 94 | + |
| 95 | + if cipher_type == CipherType.TRIVIUM: |
| 96 | + self.state_size = self.TRIVIUM_STATE_SIZE |
| 97 | + self.update_func = self.trivium_update |
| 98 | + self.last_part_seed_size = self.TRIVIUM_LAST_PART_SEED_SIZE |
| 99 | + |
| 100 | + if init_seed != []: |
| 101 | + if len(init_seed) != TRIVIUM_STATE_SIZE: |
| 102 | + raise ValueError( |
| 103 | + "trivium init seed must be the same size as the state" |
| 104 | + ) |
| 105 | + self.state = init_seed |
| 106 | + else: |
| 107 | + self.state = self.TRIVIUM_INIT_SEED[:] |
| 108 | + |
| 109 | + elif cipher_type == CipherType.BIVIUM: |
| 110 | + self.state_size = self.BIVIUM_STATE_SIZE |
| 111 | + self.update_func = self.bivium_update |
| 112 | + self.last_part_seed_size = self.BIVIUM_LAST_PART_SEED_SIZE |
| 113 | + |
| 114 | + if init_seed != []: |
| 115 | + if len(init_seed) != BIVIUM_STATE_SIZE: |
| 116 | + raise ValueError( |
| 117 | + "bivium init seed must be the same size as the state" |
| 118 | + ) |
| 119 | + self.state = init_seed |
| 120 | + else: |
| 121 | + self.state = self.BIVIUM_INIT_SEED[:] |
| 122 | + |
| 123 | + else: |
| 124 | + raise ValueError("unknown cipher type:", cipher_type) |
| 125 | + |
| 126 | + # Depending on cipher and seed type, a different number of seed rounds |
| 127 | + # have to be run. |
| 128 | + if seed_type == SeedType.KEY_IV: |
| 129 | + self.seed_rnd = math.ceil(4 * self.state_size / output_width) |
| 130 | + self.seed_ctr = 0 |
| 131 | + elif seed_type == SeedType.STATE_FULL: |
| 132 | + self.seed_rnd = 1 |
| 133 | + self.seed_ctr = 0 |
| 134 | + elif seed_type == SeedType.STATE_PARTIAL: |
| 135 | + self.seed_rnd = math.ceil(self.state_size / self.PART_SEED_SIZE) |
| 136 | + self.seed_ctr = 0 |
| 137 | + else: |
| 138 | + raise ValueError("unknown seed type:", seed_type) |
| 139 | + |
| 140 | + self.cipher_type = cipher_type |
| 141 | + self.seed_type = seed_type |
| 142 | + self.output_width = output_width |
| 143 | + |
| 144 | + # Scheduled state and seed for the current clock interval. |
| 145 | + self.next_state = [] |
| 146 | + self.next_seed = [] |
| 147 | + |
| 148 | + self.ks = [0] * output_width |
| 149 | + |
| 150 | + def update(self) -> None: |
| 151 | + """Run the state update function `OUTPUT_WIDTH`-many times |
| 152 | + and schedule the new state to replace the current state at |
| 153 | + end of the clock interval.""" |
| 154 | + |
| 155 | + if self.next_state != []: |
| 156 | + raise Exception("cannot update more than once per clock interval") |
| 157 | + |
| 158 | + self.next_state = self.state[:] |
| 159 | + for i in range(self.output_width): |
| 160 | + self.ks[i] = self.update_func(self.next_state) |
| 161 | + |
| 162 | + def seed(self, seed) -> None: |
| 163 | + """Schedule a new seed that depending on the seed type will |
| 164 | + be injected into the state at the end of the clock |
| 165 | + interval.""" |
| 166 | + |
| 167 | + if self.seed_type == SeedType.KEY_IV: |
| 168 | + assert len(seed) == 160 |
| 169 | + |
| 170 | + key = seed[0:80] |
| 171 | + iv = seed[80:160] |
| 172 | + |
| 173 | + if self.cipher_type == CipherType.TRIVIUM: |
| 174 | + self.next_seed = ( |
| 175 | + (key + [0] * 13) + (iv + [0] * 4) + ([0] * 108 + [1, 1, 1]) |
| 176 | + ) |
| 177 | + else: |
| 178 | + self.next_seed = (key + [0] * 13) + (iv + [0] * 4) |
| 179 | + |
| 180 | + elif self.seed_type == SeedType.STATE_FULL: |
| 181 | + assert len(seed) == self.state_size |
| 182 | + self.next_seed = seed |
| 183 | + self.seed_ctr = 0 |
| 184 | + |
| 185 | + else: |
| 186 | + assert len(seed) == self.PART_SEED_SIZE |
| 187 | + self.next_seed = seed |
| 188 | + |
| 189 | + if self.seed_done(): |
| 190 | + self.seed_ctr = 0 |
| 191 | + |
| 192 | + def clock(self): |
| 193 | + """Advance the state by one clock cycle. Depending on the |
| 194 | + scheduled state and seed this will alter the current state.""" |
| 195 | + |
| 196 | + if self.next_state == [] and self.next_seed == []: |
| 197 | + # Do nothing when neither an update nor reseed is scheduled. |
| 198 | + return |
| 199 | + |
| 200 | + if self.seed_type == SeedType.KEY_IV: |
| 201 | + # Seeding takes precedence over updating. |
| 202 | + if self.next_seed != []: |
| 203 | + self.state = self.next_seed |
| 204 | + |
| 205 | + elif self.next_state != []: |
| 206 | + self.state = self.next_state |
| 207 | + if not self.seed_done(): |
| 208 | + self.seed_ctr += 1 |
| 209 | + |
| 210 | + elif self.seed_type == SeedType.STATE_FULL: |
| 211 | + # Seeding takes precedence over updating. |
| 212 | + if self.next_seed != []: |
| 213 | + self.state = self.next_seed |
| 214 | + self.seed_ctr += 1 |
| 215 | + |
| 216 | + elif self.next_state != []: |
| 217 | + self.state = self.next_state |
| 218 | + |
| 219 | + else: |
| 220 | + # Update and seeding in the same clock interval is allowed. In this |
| 221 | + # case the state is first updated, then partially overwritten with |
| 222 | + # the seed bits. |
| 223 | + if self.next_state != []: |
| 224 | + self.state = self.next_state |
| 225 | + if self.next_seed != []: |
| 226 | + if self.seed_ctr == self.seed_rnd - 1: |
| 227 | + self.state[self.state_size - self.last_part_seed_size :] = ( |
| 228 | + self.next_seed[: self.last_part_seed_size] |
| 229 | + ) |
| 230 | + else: |
| 231 | + self.state[32 * self.seed_ctr : 32 * (self.seed_ctr + 1)] = ( |
| 232 | + self.next_seed |
| 233 | + ) |
| 234 | + |
| 235 | + self.seed_ctr += 1 |
| 236 | + |
| 237 | + self.next_state = [] |
| 238 | + self.next_seed = [] |
| 239 | + |
| 240 | + def keystream(self): |
| 241 | + """Returns the generated keystream for the current clock |
| 242 | + interval.""" |
| 243 | + return self.ks |
| 244 | + |
| 245 | + def seed_done(self) -> None: |
| 246 | + """Returns true if the seeding procedure has been completed.""" |
| 247 | + return self.seed_rnd == self.seed_ctr |
| 248 | + |
| 249 | + def trivium_update(self, state): |
| 250 | + mul_90_91 = state[90] & state[91] |
| 251 | + add_65_92 = state[65] ^ state[92] |
| 252 | + |
| 253 | + mul_174_175 = state[174] & state[175] |
| 254 | + add_161_176 = state[161] ^ state[176] |
| 255 | + |
| 256 | + mul_285_286 = state[285] & state[286] |
| 257 | + add_242_287 = state[242] ^ state[287] |
| 258 | + |
| 259 | + t0 = state[68] ^ (mul_285_286 ^ add_242_287) |
| 260 | + t1 = state[170] ^ (add_65_92 ^ mul_90_91) |
| 261 | + t2 = state[263] ^ (mul_174_175 ^ add_161_176) |
| 262 | + |
| 263 | + state[0:93] = [t0] + state[0:92] |
| 264 | + state[93:177] = [t1] + state[93:176] |
| 265 | + state[177:288] = [t2] + state[177:287] |
| 266 | + |
| 267 | + return add_65_92 ^ add_161_176 ^ add_242_287 |
| 268 | + |
| 269 | + def bivium_update(self, state): |
| 270 | + mul_90_91 = state[90] & state[91] |
| 271 | + add_65_92 = state[65] ^ state[92] |
| 272 | + |
| 273 | + mul_174_175 = state[174] & state[175] |
| 274 | + add_161_176 = state[161] ^ state[176] |
| 275 | + |
| 276 | + t0 = state[68] ^ (mul_174_175 ^ add_161_176) |
| 277 | + t1 = state[170] ^ (add_65_92 ^ mul_90_91) |
| 278 | + |
| 279 | + state[0:93] = [t0] + state[0:92] |
| 280 | + state[93:177] = [t1] + state[93:176] |
| 281 | + |
| 282 | + return add_65_92 ^ add_161_176 |
| 283 | + |
| 284 | + |
| 285 | +# Key-IV seed |
| 286 | + |
| 287 | +# AVR cryptolib: Set 1, vector 0 |
| 288 | +ref = i2b( |
| 289 | + int( |
| 290 | + """ |
| 291 | +F980FC5474EFE87BB9626ACCCC20FF98 |
| 292 | +807FCFCE928F6CE0EB21096115F5FBD2 |
| 293 | +649AF249C24120550175C86414657BBB |
| 294 | +0D5420443AF18DAF9C7A0D73FF86EB38""".replace("\n", ""), |
| 295 | + 16, |
| 296 | + ), |
| 297 | + 288, |
| 298 | +) |
| 299 | + |
| 300 | +trivium = Trivium(CipherType.TRIVIUM, SeedType.KEY_IV, 64) |
| 301 | + |
| 302 | +key = i2b(0x01000000000000000000, 80) |
| 303 | +iv = [0] * 80 |
| 304 | + |
| 305 | +trivium.seed(key + iv) |
| 306 | +trivium.clock() |
| 307 | + |
| 308 | +while not trivium.seed_done(): |
| 309 | + trivium.update() |
| 310 | + trivium.clock() |
| 311 | + |
| 312 | +assert trivium.seed_done() |
| 313 | + |
| 314 | +print("%072x" % b2i(trivium.state)) |
| 315 | + |
| 316 | +keystream = [] |
| 317 | +for _ in range(8): |
| 318 | + trivium.update() |
| 319 | + trivium.clock() |
| 320 | + keystream.extend(trivium.keystream()) |
| 321 | + |
| 322 | +assert keystream == ref |
| 323 | +print("%0x" % (b2i(keystream))) |
| 324 | + |
| 325 | + |
| 326 | +# Full state seed |
| 327 | + |
| 328 | +# Seed corresponds to the state after the init rounds from the first test. |
| 329 | +seed = i2b( |
| 330 | + 0xC7D7C89BCC06725B3D94718106F2A0656422AF1FA457B81F0D2516A9D565893A64C1E50E, 288 |
| 331 | +) |
| 332 | + |
| 333 | +trivium = Trivium(CipherType.TRIVIUM, SeedType.STATE_FULL, 64) |
| 334 | +trivium.seed(seed) |
| 335 | +trivium.clock() |
| 336 | + |
| 337 | +assert trivium.seed_done() |
| 338 | + |
| 339 | +keystream = [] |
| 340 | +for _ in range(8): |
| 341 | + trivium.update() |
| 342 | + trivium.clock() |
| 343 | + keystream.extend(trivium.keystream()) |
| 344 | + |
| 345 | +assert keystream == ref |
| 346 | +print("%0x" % (b2i(keystream))) |
| 347 | + |
| 348 | +# Partial state seed |
| 349 | + |
| 350 | +# Seed corresponds to the state after the init rounds from the first test. |
| 351 | +seed = [ |
| 352 | + i2b(0x64C1E50E, Trivium.PART_SEED_SIZE), |
| 353 | + i2b(0xD565893A, Trivium.PART_SEED_SIZE), |
| 354 | + i2b(0x0D2516A9, Trivium.PART_SEED_SIZE), |
| 355 | + i2b(0xA457B81F, Trivium.PART_SEED_SIZE), |
| 356 | + i2b(0x6422AF1F, Trivium.PART_SEED_SIZE), |
| 357 | + i2b(0x06F2A065, Trivium.PART_SEED_SIZE), |
| 358 | + i2b(0x3D947181, Trivium.PART_SEED_SIZE), |
| 359 | + i2b(0xCC06725B, Trivium.PART_SEED_SIZE), |
| 360 | + i2b(0xC7D7C89B, Trivium.PART_SEED_SIZE), |
| 361 | +] |
| 362 | + |
| 363 | +trivium = Trivium(CipherType.TRIVIUM, SeedType.STATE_PARTIAL, 64) |
| 364 | + |
| 365 | +for i in range(9): |
| 366 | + trivium.seed(seed[i]) |
| 367 | + trivium.clock() |
| 368 | + |
| 369 | +assert trivium.seed_done() |
| 370 | + |
| 371 | +keystream = [] |
| 372 | +for _ in range(8): |
| 373 | + trivium.update() |
| 374 | + trivium.clock() |
| 375 | + keystream.extend(trivium.keystream()) |
| 376 | + |
| 377 | +assert keystream == ref |
| 378 | +print("%0x" % (b2i(keystream))) |
0 commit comments