3131_HASH = hashlib .sha512
3232"""Hash function used for HKDF and matching."""
3333
34- _SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION = 2
35- """Minimum number of shares required to reconstruct a Shamir secret."""
36-
3734def _hkdf_extract (salt : bytes , input_key : bytes ) -> bytes :
3835 """
3936 Extracts a pseudorandom key (PRK) using HMAC with the given salt and input key material.
@@ -138,7 +135,7 @@ def _shamirs_eval(poly, x, prime):
138135def _shamirs_shares (
139136 secret ,
140137 total_shares ,
141- minimum_shares = _SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION ,
138+ minimum_shares ,
142139 prime = _SECRET_SHARED_SIGNED_INTEGER_MODULUS
143140):
144141 """
@@ -162,16 +159,11 @@ def _shamirs_recover(shares, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS):
162159 """
163160 Recover the secret value from the supplied share instances.
164161
165- >>> _shamirs_recover([123])
166- Traceback (most recent call last):
167- ...
168- ValueError: at least 2 shares are required
162+ >>> _shamirs_recover([[0, 123] ])
163+ 123
164+ >>> _shamirs_recover([[0, 123], [1, 123], [2, 123]])
165+ 123
169166 """
170- if len (shares ) < _SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION :
171- raise ValueError (
172- f'at least { _SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION } shares are required'
173- )
174-
175167 return lagrange (shares , prime )
176168
177169def _shamirs_add (shares_a , shares_b , prime = _SECRET_SHARED_SIGNED_INTEGER_MODULUS ):
@@ -804,7 +796,7 @@ def encrypt(
804796 for i in range (len (key ['cluster' ]['nodes' ]))
805797 ]
806798 num_nodes = len (key ['cluster' ]['nodes' ])
807- shares = _shamirs_shares (plaintext , num_nodes )
799+ shares = _shamirs_shares (plaintext , num_nodes , key [ 'threshold' ] )
808800 for (i , share ) in enumerate (shares ):
809801 share [1 ] = (masks [i ] * share [1 ]) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS
810802
@@ -846,6 +838,15 @@ def decrypt(
846838 >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}, threshold=2)
847839 >>> decrypt(key, encrypt(key, 123))
848840 123
841+ >>> key = SecretKey.generate({'nodes': [{}, {}, {}, {}]}, {'sum': True}, threshold=3)
842+ >>> decrypt(key, encrypt(key, 123)[:-1])
843+ 123
844+ >>> key = SecretKey.generate({'nodes': [{}, {}, {}, {}]}, {'sum': True}, threshold=2)
845+ >>> decrypt(key, encrypt(key, 123)[2:])
846+ 123
847+ >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}, threshold=1)
848+ >>> decrypt(key, encrypt(key, 123)[1:])
849+ 123
849850 >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}, threshold=2)
850851 >>> decrypt(key, encrypt(key, -10))
851852 -10
0 commit comments