Skip to content

Commit 6d963fe

Browse files
committed
feat:added encryptionManager class for and generate && recoverKey methods && test files
1 parent b222747 commit 6d963fe

File tree

9 files changed

+505
-2
lines changed

9 files changed

+505
-2
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ charset-normalizer==3.1.0
33
idna==3.4
44
requests==2.31.0
55
urllib3==2.0.2
6-
eth-account==0.13.7
6+
eth-account==0.13.7
7+
cryptography

src/lighthouseweb3/__init__.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import io
5+
from typing import List, Dict, Any
56
from .functions import (
67
upload as d,
78
deal_status,
@@ -16,7 +17,7 @@
1617
remove_ipns_record as removeIpnsRecord,
1718
create_wallet as createWallet
1819
)
19-
20+
from .functions.encryptionManager import generate, recoverKey
2021

2122
class Lighthouse:
2223
def __init__(self, token: str = ""):
@@ -224,3 +225,19 @@ def getTagged(self, tag: str):
224225
except Exception as e:
225226
raise e
226227

228+
class EncryptionManager:
229+
@staticmethod
230+
def generate(threshold: int, keyCount: int):
231+
try:
232+
return generate.generate(threshold, keyCount)
233+
except Exception as e:
234+
raise e
235+
236+
237+
@staticmethod
238+
def recoverKey(keyShards: List[Dict[str, Any]]):
239+
try:
240+
return recoverKey.recoverKey(keyShards)
241+
except Exception as e:
242+
raise e
243+

src/lighthouseweb3/functions/encryptionManager/__init__.py

Whitespace-only changes.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#A 257-bit prime to accommodate 256-bit secrets
2+
PRIME = 2**256 + 297
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import secrets
2+
import logging
3+
from typing import Dict, List, Any
4+
from .config import PRIME
5+
logger = logging.getLogger(__name__)
6+
7+
8+
def evaluate_polynomial(coefficients: List[int], x: int, prime: int) -> int:
9+
"""
10+
Evaluate a polynomial with given coefficients at point x.
11+
msk[0] is constant term (the secret), msk[1] is x coefficient, etc.
12+
13+
Args:
14+
coefficients: List of coefficients where coefficients[0] is the constant term
15+
x: Point at which to evaluate the polynomial
16+
prime: Prime number for the finite field
17+
18+
Returns:
19+
The result of the polynomial evaluation modulo prime
20+
"""
21+
result = 0
22+
x_power = 1 # x^0 = 1
23+
24+
for coefficient in coefficients:
25+
result = (result + coefficient * x_power) % prime
26+
x_power = (x_power * x) % prime
27+
28+
return result
29+
30+
async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]:
31+
"""
32+
Generate threshold cryptography key shards using Shamir's Secret Sharing
33+
34+
Args:
35+
threshold: Minimum number of shards needed to reconstruct the secret
36+
key_count: Total number of key shards to generate
37+
38+
Returns:
39+
{
40+
"masterKey": "<master private key hex string>",
41+
"keyShards": [
42+
{
43+
"key": "<shard value hex string>",
44+
"index": "<shard index hex string>"
45+
}
46+
]
47+
}
48+
"""
49+
logger.info(f"Generating key shards with threshold={threshold}, key_count={key_count}")
50+
51+
msk=[]
52+
idVec=[]
53+
secVec=[]
54+
55+
if threshold > key_count:
56+
raise ValueError("key_count must be greater than or equal to threshold")
57+
if threshold < 1 or key_count < 1:
58+
raise ValueError("threshold and key_count must be positive integers")
59+
60+
61+
msk = [secrets.randbits(256) for _ in range(threshold)]
62+
master_key = msk[0]
63+
64+
used_ids = set()
65+
66+
for i in range(key_count):
67+
while True:
68+
id_vec = secrets.randbits(32)
69+
if id_vec != 0 and id_vec not in used_ids:
70+
idVec.append(id_vec)
71+
used_ids.add(id_vec)
72+
break
73+
74+
for i in range(key_count):
75+
y = evaluate_polynomial(msk, idVec[i], PRIME)
76+
secVec.append(y)
77+
78+
result = {
79+
"masterKey": hex(master_key),
80+
"keyShards": [{"key": hex(secVec[i]), "index": hex(idVec[i])} for i in range(key_count)]
81+
}
82+
return result
83+
84+
if __name__ == "__main__":
85+
import asyncio
86+
result = asyncio.run(generate(threshold=1, key_count=1))
87+
print(result)
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
from typing import List, Dict, Any
2+
import logging
3+
from .config import PRIME
4+
5+
logger = logging.getLogger(__name__)
6+
7+
from typing import Tuple
8+
9+
def extended_gcd(a: int, b: int) -> Tuple[int, int, int]:
10+
"""Extended Euclidean algorithm to find modular inverse.
11+
12+
Args:
13+
a: First integer
14+
b: Second integer
15+
16+
Returns:
17+
A tuple (g, x, y) such that a*x + b*y = g = gcd(a, b)
18+
"""
19+
if a == 0:
20+
return b, 0, 1
21+
else:
22+
g, y, x = extended_gcd(b % a, a)
23+
return g, x - (b // a) * y, y
24+
25+
def modinv(a: int, m: int) -> int:
26+
"""Find the modular inverse of a mod m."""
27+
g, x, y = extended_gcd(a, m)
28+
if g != 1:
29+
raise ValueError('Modular inverse does not exist')
30+
else:
31+
return x % m
32+
33+
def lagrange_interpolation(shares: List[Dict[str, str]], prime: int) -> int:
34+
"""
35+
Reconstruct the secret using Lagrange interpolation.
36+
37+
Args:
38+
shares: List of dictionaries with 'key' and 'index' fields
39+
prime: The prime number used in the finite field
40+
41+
Returns:
42+
The reconstructed secret as integer
43+
44+
Raises:
45+
ValueError: If there are duplicate indices
46+
"""
47+
48+
points = []
49+
seen_indices = set()
50+
51+
for i, share in enumerate(shares):
52+
try:
53+
key_str, index_str = validate_share(share, i)
54+
x = int(index_str, 16)
55+
56+
if x in seen_indices:
57+
raise ValueError(f"Duplicate share index found: 0x{x:x}")
58+
seen_indices.add(x)
59+
60+
y = int(key_str, 16)
61+
points.append((x, y))
62+
except ValueError as e:
63+
raise ValueError(f"Invalid share at position {i}: {e}")
64+
65+
66+
secret = 0
67+
68+
for i, (x_i, y_i) in enumerate(points):
69+
# Calculate the Lagrange basis polynomial L_i(0)
70+
# Evaluate at x=0 to get the constant term
71+
numerator = 1
72+
denominator = 1
73+
74+
for j, (x_j, _) in enumerate(points):
75+
if i != j:
76+
numerator = (numerator * (-x_j)) % prime
77+
denominator = (denominator * (x_i - x_j)) % prime
78+
79+
try:
80+
inv_denominator = modinv(denominator, prime)
81+
except ValueError as e:
82+
raise ValueError(f"Error in modular inverse calculation: {e}")
83+
84+
term = (y_i * numerator * inv_denominator) % prime
85+
secret = (secret + term) % prime
86+
87+
return secret
88+
89+
def validate_share(share: Dict[str, str], index: int) -> Tuple[str, str]:
90+
"""Validate and normalize a single share.
91+
92+
Args:
93+
share: Dictionary containing 'key' and 'index' fields
94+
index: Position of the share in the input list (for error messages)
95+
96+
Returns:
97+
Tuple of (normalized_key, normalized_index) as strings without '0x' prefix
98+
99+
Raises:
100+
ValueError: If the share is invalid
101+
"""
102+
if not isinstance(share, dict):
103+
raise ValueError(f"Share at index {index} must be a dictionary")
104+
105+
if 'key' not in share or 'index' not in share:
106+
raise ValueError(f"Share at index {index} is missing required fields 'key' or 'index'")
107+
108+
key_str = str(share['key']).strip().lower()
109+
index_str = str(share['index']).strip().lower()
110+
111+
if key_str.startswith('0x'):
112+
key_str = key_str[2:]
113+
if index_str.startswith('0x'):
114+
index_str = index_str[2:]
115+
116+
117+
if not key_str:
118+
raise ValueError(f"Empty key in share at index {index}")
119+
if not all(c in '0123456789abcdef' for c in key_str):
120+
raise ValueError(f"Invalid key format in share at index {index}: must be a valid hex string")
121+
122+
if len(key_str) % 2 != 0:
123+
key_str = '0' + key_str
124+
125+
if not index_str:
126+
raise ValueError(f"Empty index in share at index {index}")
127+
if not all(c in '0123456789abcdef' for c in index_str):
128+
raise ValueError(f"Invalid index format in share at index {index}: must be a valid hex string")
129+
130+
index_int = int(index_str, 16)
131+
if not (0 <= index_int <= 0xFFFFFFFF):
132+
raise ValueError(f"Index out of range in share at index {index}: must be between 0 and 2^32-1")
133+
134+
return key_str, index_str
135+
136+
137+
async def recoverKey(keyShards: List[Dict[str, str]]) -> Dict[str, Any]:
138+
"""
139+
Recover the master key from a subset of key shares using Lagrange interpolation.
140+
141+
Args:
142+
keyShards: List of dictionaries containing 'key' and 'index' fields
143+
144+
Returns:
145+
{
146+
"masterKey": "<recovered master key hex string>",
147+
"error": "<error message if any>"
148+
}
149+
"""
150+
logger.info(f"Attempting to recover master key from {len(keyShards)} shares")
151+
152+
try:
153+
for i, share in enumerate(keyShards):
154+
validate_share(share, i)
155+
secret = lagrange_interpolation(keyShards, PRIME)
156+
return {
157+
"masterKey": hex(secret),
158+
"error": None
159+
}
160+
except ValueError as e:
161+
logger.error(f"Validation error during key recovery: {str(e)}")
162+
return {
163+
"masterKey": None,
164+
"error": f"Validation error: {str(e)}"
165+
}
166+
except Exception as e:
167+
logger.error(f"Error during key recovery: {str(e)}")
168+
return {
169+
"masterKey": None,
170+
"error": f"Recovery error: {str(e)}"
171+
}

tests/tests_encryptionEngine/__init__.py

Whitespace-only changes.
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import unittest
2+
import asyncio
3+
import logging
4+
from src.lighthouseweb3 import EncryptionManager
5+
6+
logger = logging.getLogger(__name__)
7+
8+
class TestGenerate(unittest.TestCase):
9+
"""Test cases for the generate module."""
10+
11+
def test_generate_basic(self):
12+
"""Test basic key generation with default parameters."""
13+
async def run_test():
14+
result = await EncryptionManager.generate(threshold=2, keyCount=3)
15+
16+
self.assertIn('masterKey', result)
17+
self.assertIn('keyShards', result)
18+
19+
# Check master key format (hex string with 0x prefix)
20+
self.assertIsInstance(result['masterKey'], str)
21+
self.assertTrue(result['masterKey'].startswith('0x'))
22+
self.assertTrue(all(c in '0123456789abcdef' for c in result['masterKey'][2:]))
23+
24+
# Check key shards
25+
self.assertEqual(len(result['keyShards']), 3)
26+
for shard in result['keyShards']:
27+
self.assertIn('key', shard)
28+
self.assertIn('index', shard)
29+
30+
# Check key format (hex string with 0x prefix)
31+
self.assertTrue(shard['key'].startswith('0x'))
32+
self.assertTrue(all(c in '0123456789abcdef' for c in shard['key'][2:]))
33+
34+
# Check index format (hex string with 0x prefix)
35+
self.assertTrue(shard['index'].startswith('0x'))
36+
self.assertTrue(all(c in '0123456789abcdef' for c in shard['index'][2:]))
37+
38+
return result
39+
40+
return asyncio.run(run_test())
41+
42+
def test_generate_custom_parameters(self):
43+
"""Test key generation with custom parameters."""
44+
async def run_test():
45+
threshold = 3
46+
key_count = 5
47+
48+
result = await EncryptionManager.generate(threshold=threshold, keyCount=key_count)
49+
50+
self.assertEqual(len(result['keyShards']), key_count)
51+
52+
# Check all indices are present and unique
53+
indices = [shard['index'] for shard in result['keyShards']]
54+
self.assertEqual(len(set(indices)), key_count) # All unique
55+
56+
# Verify all indices are valid hex strings with 0x prefix
57+
for index in indices:
58+
self.assertTrue(index.startswith('0x'))
59+
self.assertTrue(all(c in '0123456789abcdef' for c in index[2:]))
60+
61+
return result
62+
63+
return asyncio.run(run_test())
64+
65+
def test_invalid_threshold(self):
66+
"""Test that invalid threshold raises an error."""
67+
async def run_test():
68+
with self.assertRaises(ValueError) as context:
69+
await EncryptionManager.generate(threshold=0, keyCount=3)
70+
self.assertIn("must be positive integers", str(context.exception))
71+
72+
with self.assertRaises(ValueError) as context:
73+
await EncryptionManager.generate(threshold=4, keyCount=3)
74+
self.assertIn("must be greater than or equal to threshold", str(context.exception))
75+
76+
return asyncio.run(run_test())
77+
78+
if __name__ == '__main__':
79+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)