Skip to content

Commit fe8a7e8

Browse files
committed
fix: allow any threshold for Shamir's secret sharing w/ sum
1 parent e936f74 commit fe8a7e8

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

src/nilql/nilql.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@
3131
_HASH = hashlib.sha512
3232
"""Hash function used for HKDF and matching."""
3333

34-
_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION = 2
35-
"""Minimum number of shares required to reconstruct a Shamir secret."""
36-
3734
def _hkdf_extract(salt: bytes, input_key: bytes) -> bytes:
3835
"""
3936
Extracts a pseudorandom key (PRK) using HMAC with the given salt and input key material.
@@ -138,7 +135,7 @@ def _shamirs_eval(poly, x, prime):
138135
def _shamirs_shares(
139136
secret,
140137
total_shares,
141-
minimum_shares=_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION,
138+
minimum_shares,
142139
prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS
143140
):
144141
"""
@@ -162,16 +159,11 @@ def _shamirs_recover(shares, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS):
162159
"""
163160
Recover the secret value from the supplied share instances.
164161
165-
>>> _shamirs_recover([123])
166-
Traceback (most recent call last):
167-
...
168-
ValueError: at least 2 shares are required
162+
>>> _shamirs_recover([[0, 123]])
163+
123
164+
>>> _shamirs_recover([[0, 123], [1, 123], [2, 123]])
165+
123
169166
"""
170-
if len(shares) < _SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION:
171-
raise ValueError(
172-
f'at least {_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION} shares are required'
173-
)
174-
175167
return lagrange(shares, prime)
176168

177169
def _shamirs_add(shares_a, shares_b, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS):
@@ -804,7 +796,7 @@ def encrypt(
804796
for i in range(len(key['cluster']['nodes']))
805797
]
806798
num_nodes = len(key['cluster']['nodes'])
807-
shares = _shamirs_shares(plaintext, num_nodes)
799+
shares = _shamirs_shares(plaintext, num_nodes, key['threshold'])
808800
for (i, share) in enumerate(shares):
809801
share[1] = (masks[i] * share[1]) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS
810802

@@ -846,6 +838,15 @@ def decrypt(
846838
>>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}, threshold=2)
847839
>>> decrypt(key, encrypt(key, 123))
848840
123
841+
>>> key = SecretKey.generate({'nodes': [{}, {}, {}, {}]}, {'sum': True}, threshold=3)
842+
>>> decrypt(key, encrypt(key, 123)[:-1])
843+
123
844+
>>> key = SecretKey.generate({'nodes': [{}, {}, {}, {}]}, {'sum': True}, threshold=2)
845+
>>> decrypt(key, encrypt(key, 123)[2:])
846+
123
847+
>>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}, threshold=1)
848+
>>> decrypt(key, encrypt(key, 123)[1:])
849+
123
849850
>>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}, threshold=2)
850851
>>> decrypt(key, encrypt(key, -10))
851852
-10

0 commit comments

Comments
 (0)