Skip to content

Commit a990d71

Browse files
committed
from_seed now takes header string param
1 parent bf05827 commit a990d71

File tree

3 files changed

+65
-57
lines changed

3 files changed

+65
-57
lines changed

src/codex32/codex32.py

Lines changed: 60 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def bech32_hrp_expand(hrp):
7373
return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp]
7474

7575

76-
def ms32_verify_checksum(data, hrp):
76+
def ms32_verify_checksum(hrp, data):
7777
"""Determine long or short checksum and verify it."""
7878
values = bech32_hrp_expand(hrp.lower()) + data
7979
if len(data) >= 96: # See Long codex32 Strings
@@ -85,7 +85,7 @@ def ms32_verify_checksum(data, hrp):
8585
)
8686

8787

88-
def ms32_create_checksum(data, hrp):
88+
def ms32_create_checksum(hrp, data):
8989
"""Determine long or short checksum, create and return it."""
9090
values = bech32_hrp_expand(hrp.lower()) + data
9191
if len(data) > 80: # See Long codex32 Strings
@@ -267,12 +267,17 @@ class ThresholdNotPassed(Codex32Error):
267267
msg = "Threshold not passed"
268268

269269

270-
def bech32_encode(data, hrp):
271-
"""Compute a Bech32 string given HRP and data values."""
270+
def u5_to_bech32(data):
271+
"""Map list of 5-bit integers (0-31) -> bech32 data-part string."""
272272
for i, x in enumerate(data):
273273
if not 0 <= x < 32:
274274
raise InvalidDataValue(f"from 0 to 31 index={i} value={x}")
275-
ret = (hrp + "1" if hrp else "") + "".join(CHARSET[d] for d in data)
275+
return "".join(CHARSET[d] for d in data)
276+
277+
278+
def bech32_encode(hrp, data):
279+
"""Compute a Bech32 string given HRP and data values."""
280+
ret = (hrp + "1" if hrp else "") + u5_to_bech32(data)
276281
if hrp.lower() == hrp:
277282
return ret.lower()
278283
if hrp.upper() == hrp:
@@ -359,44 +364,57 @@ class Codex32String:
359364
"""Class representing a Codex32 string."""
360365

361366
@staticmethod
362-
def parse_header(header_str=""):
367+
def parse_header(s=""):
363368
"""Parse a codex32 header and return its properties."""
364-
hrp = bech32_decode(header_str)[0] if "1" in header_str else header_str
365-
try:
366-
k = int(header_str[len(hrp) + 1 : len(hrp) + 2])
367-
except ValueError as e:
368-
raise InvalidThreshold(f"'{header_str[len(hrp)+1]}' must be a digit") from e
369-
ident = header_str[len(hrp) + 2 : len(hrp) + 6]
369+
hrp, data = s.rsplit("1", 1) if "1" in s else [s, ""]
370+
k = data[0] if data else ""
371+
if k and not k.isdigit():
372+
raise InvalidThreshold(f"'{data[0]}' must be a digit")
373+
ident = data[1:5]
370374
if ident and len(ident) < 4:
371375
raise IdNotLength4(f"{len(ident)}")
372-
share_idx = header_str[len(hrp) + 6 : len(hrp) + 7]
373-
if k == 0 and share_idx.lower() != "s":
376+
share_idx = data[5] if len(data) > 5 else "s" if k == "0" else ""
377+
if k == "0" and share_idx.lower() != "s":
374378
raise InvalidShareIndex(f"'{share_idx}' must be 's' when k=0")
375-
return hrp, k, ident, share_idx
379+
return hrp, k, ident, share_idx, data
376380

377381
def __init__(self, s=""):
378-
self.s = s
379-
self.hrp, data = bech32_decode(s)
380-
_, data_part = s.rsplit("1", 1)
381-
if 44 < len(data_part) < 94:
382+
self.hrp, self.k, self.ident, self.share_idx, data = self.parse_header(s)
383+
if 44 < len(data) < 94:
382384
checksum_len = 13
383-
elif 95 < len(data_part) < 125:
385+
elif 95 < len(data) < 125:
384386
checksum_len = 15
385387
else:
386-
raise InvalidLength(f"{len(data_part)} must be 45-93 or 96-124")
387-
header_str = bech32_encode(data[:6], self.hrp)
388-
_, self.k, self.ident, self.share_idx = self.parse_header(header_str)
389-
self.payload = data_part[6:-checksum_len]
388+
raise InvalidLength(f"{len(data)} must be 45-93 or 96-124")
389+
self.payload = data[6:-checksum_len]
390390
incomplete_group = (len(self.payload) * 5) % 8
391391
if incomplete_group > 4:
392392
raise IncompleteGroup(str(incomplete_group))
393-
if not ms32_verify_checksum(data, self.hrp):
394-
raise InvalidChecksum(f"string={self.s}")
393+
if not ms32_verify_checksum(*bech32_decode(s)):
394+
raise InvalidChecksum(f"string={s}")
395+
396+
@property
397+
def _unchecksummed_s(self):
398+
"""Return the codex32 string without the checksum."""
399+
return self.hrp + "1" + self.k + self.ident + self.share_idx + self.payload
400+
401+
@property
402+
def checksum(self):
403+
"""Calculate the checksum part of the Codex32 string."""
404+
ret = u5_to_bech32(ms32_create_checksum(*bech32_decode(self._unchecksummed_s)))
405+
return ret if self.hrp.islower() else ret.upper()
406+
407+
@property
408+
def data_part_chars(self):
409+
"""Return the data part characters of the Codex32 string."""
410+
return self.k + self.ident + self.share_idx + self.payload + self.checksum
411+
412+
@property
413+
def s(self):
414+
return self.hrp + "1" + self.data_part_chars
395415

396416
def __str__(self):
397-
return self.from_unchecksummed_string(
398-
self.hrp + "1" + str(self.k) + self.ident + self.share_idx + self.payload
399-
).s
417+
return self.s
400418

401419
def __eq__(self, other):
402420
if not isinstance(other, Codex32String):
@@ -406,12 +424,6 @@ def __eq__(self, other):
406424
def __hash__(self):
407425
return hash(self.s)
408426

409-
@property
410-
def checksum(self):
411-
"""Calculate the checksum part of the Codex32 string."""
412-
data = bech32_to_u5(str(self.k) + self.ident + self.share_idx + self.payload)
413-
return bech32_encode(ms32_create_checksum(data, self.hrp), "")
414-
415427
@property
416428
def data(self):
417429
"""Return the payload data bytes."""
@@ -420,8 +432,7 @@ def data(self):
420432
@classmethod
421433
def from_unchecksummed_string(cls, s):
422434
"""Create Codex32String from unchecksummed string."""
423-
hrp, data = bech32_decode(s)
424-
return cls(bech32_encode(data + ms32_create_checksum(data, hrp), hrp))
435+
return cls(s + u5_to_bech32(ms32_create_checksum(*bech32_decode(s))))
425436

426437
@classmethod
427438
def from_string(cls, s, hrp="ms"):
@@ -437,7 +448,7 @@ def interpolate_at(cls, shares, target="s"):
437448
indices = []
438449
ms32_shares = []
439450
s0_parts = shares[0]
440-
if s0_parts.k > len(shares):
451+
if int(s0_parts.k) > len(shares):
441452
raise ThresholdNotPassed(f"threshold={s0_parts.k}, n_shares={len(shares)}")
442453
for share in shares:
443454
if len(shares[0].s) != len(share.s):
@@ -456,24 +467,20 @@ def interpolate_at(cls, shares, target="s"):
456467
if indices[i] == target:
457468
return share
458469
result = ms32_interpolate(ms32_shares, CHARSET.index(target.lower()))
459-
ret = bech32_encode(result, s0_parts.hrp)
470+
ret = bech32_encode(s0_parts.hrp, result)
460471
return cls(ret)
461472

462473
@classmethod
463-
# pylint: disable=too-many-positional-arguments,too-many-arguments
464-
def from_seed(cls, data, ident="", hrp="ms", k=0, share_idx="s", pad_val=None):
465-
"""Create Codex32String from seed bytes."""
474+
def from_seed(cls, data, header="ms10", pad_val=None):
475+
"""Create Codex32String from seed bytes and header."""
476+
hrp, k, ident, share_idx, _ = cls.parse_header(header)
466477
if 16 > len(data) or len(data) > 64:
467-
raise InvalidLength(f"{len(data)} bytes data MUST be 16 to 64 bytes")
468-
if not ident and share_idx == "s":
478+
raise InvalidLength(f"{len(data)} bytes. Data must be 16 to 64 bytes")
479+
share_idx = "s" if not share_idx else share_idx
480+
if not ident:
469481
bip32 = BIP32.from_seed(data)
470-
ident = bech32_encode(convertbits(bip32.get_fingerprint(), 8, 5), "")[:4]
471-
if len(ident) != 4:
472-
raise IdNotLength4(f"{len(ident)}")
473-
if not (1 < k <= 9 or k == 0):
474-
raise InvalidThresholdN(str(k))
482+
ident += u5_to_bech32(convertbits(bip32.get_fingerprint(), 8, 5))[:4]
475483
payload = convertbits(data, 8, 5, pad_val=pad_val)
476-
header = bech32_to_u5(str(k) + ident + share_idx)
477-
combined = header + payload
478-
ret = bech32_encode(combined + ms32_create_checksum(combined, hrp), hrp)
479-
return cls(ret)
484+
k = "0" if not k else k
485+
header = bech32_to_u5(k + ident + share_idx)
486+
return cls.from_unchecksummed_string(bech32_encode(hrp, header + payload))

tests/data/bip93_vectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"secret_s": "ms10testsxxxxxxxxxxxxxxxxxxxxxxxxxx4nzvca9cmczlw",
66
"secret_hex": "318c6318c6318c6318c6318c6318c631",
77
"hrp": "ms",
8-
"k": 0,
8+
"k": "0",
99
"identifier": "test",
1010
"share_index": "s",
1111
"payload": "xxxxxxxxxxxxxxxxxxxxxxxxxx",

tests/test_bip93.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def test_derive_and_recover():
4848
a = Codex32String(VECTOR_2["share_A"])
4949
c = Codex32String(VECTOR_2["share_C"])
5050
# interpolation target is 'D' (uppercase as inputs are uppercase)
51+
print(a.s, c.s)
5152
d = Codex32String.interpolate_at([a, c], "D")
5253
assert str(d) == VECTOR_2["derived_D"]
5354
s = Codex32String.interpolate_at([a, c], "S")
@@ -60,7 +61,7 @@ def test_from_seed_and_interpolate_3_of_5():
6061
seed = bytes.fromhex(VECTOR_3["secret_hex"])
6162
a = Codex32String(VECTOR_3["share_a"])
6263
c = Codex32String(VECTOR_3["share_c"])
63-
s = Codex32String.from_seed(seed, a.ident, a.hrp, a.k, pad_val=0)
64+
s = Codex32String.from_seed(seed, a.hrp + "1" + a.k + a.ident, pad_val=0)
6465
assert str(s) == VECTOR_3["secret_s"]
6566
d = Codex32String.interpolate_at([s, a, c], "d")
6667
e = Codex32String.interpolate_at([s, a, c], "e")
@@ -69,15 +70,15 @@ def test_from_seed_and_interpolate_3_of_5():
6970
assert str(e) == VECTOR_3["derived_e"]
7071
assert str(f) == VECTOR_3["derived_f"]
7172
for pad_val in range(0b11 + 1):
72-
s = Codex32String.from_seed(seed, a.ident, a.hrp, a.k, pad_val=pad_val)
73+
s = Codex32String.from_seed(seed, a.hrp + "1" + a.k + a.ident, pad_val=pad_val)
7374
assert str(s) == VECTOR_3["secret_s_alternates"][pad_val]
7475

7576

7677
def test_from_seed_and_alternates():
7778
"""Test Vector 4: encode secret share from seed"""
7879
seed = bytes.fromhex(VECTOR_4["secret_hex"])
7980
for pad_val in range(0b1111 + 1):
80-
s = Codex32String.from_seed(seed, ident="leet", pad_val=pad_val)
81+
s = Codex32String.from_seed(seed, header="ms10leet", pad_val=pad_val)
8182
assert str(s) == VECTOR_4["secret_s_alternates"][pad_val]
8283
assert s.data == list(seed) or s.data == seed
8384
# confirm all 16 encodings decode to same master data

0 commit comments

Comments
 (0)