Skip to content

Commit fcc5d57

Browse files
committed
feat: add strandlock protocol for messages
1 parent b5f1760 commit fcc5d57

File tree

2 files changed

+96
-70
lines changed

2 files changed

+96
-70
lines changed

core/crypto.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def verify_signature(algorithm: str, message: bytes, signature: bytes, public_ke
6262
with oqs.Signature(algorithm) as verifier:
6363
return verifier.verify(message, signature[:ALGOS_BUFFER_LIMITS[algorithm]["SIGN_LEN"]], public_key[:ALGOS_BUFFER_LIMITS[algorithm]["PK_LEN"]])
6464

65-
def generate_sign_keys(algorithm: str = ML_DSA_87_NAME) Tuple[bytes, bytes]:
65+
def generate_sign_keys(algorithm: str = ML_DSA_87_NAME) -> Tuple[bytes, bytes]:
6666
"""
6767
Generates a new post-quantum signature keypair.
6868
@@ -77,7 +77,7 @@ def generate_sign_keys(algorithm: str = ML_DSA_87_NAME) Tuple[bytes, bytes]:
7777
private_key = signer.export_secret_key()
7878
return private_key, public_key
7979

80-
def otp_encrypt_with_padding(plaintext: bytes, key: bytes) -> bytes:
80+
def otp_encrypt_with_padding(plaintext: bytes, key: bytes) -> Tuple[bytes, bytes]:
8181
"""
8282
Encrypts plaintext using a one-time pad with random or bucket padding.
8383
@@ -98,7 +98,7 @@ def otp_encrypt_with_padding(plaintext: bytes, key: bytes) -> bytes:
9898
if len(plaintext) < OTP_MAX_BUCKET:
9999
pad_len = OTP_MAX_BUCKET - len(plaintext)
100100
else:
101-
pad_len = secrets.randbelow(OTP_MAX_RANDOM_PAD)
101+
pad_len = secrets.randbelow(OTP_MAX_RANDOM_PAD + 1)
102102

103103
padding = secrets.token_bytes(pad_len)
104104

@@ -107,7 +107,7 @@ def otp_encrypt_with_padding(plaintext: bytes, key: bytes) -> bytes:
107107
padded_plaintext = plaintext_length_bytes + plaintext + padding
108108

109109
if len(padded_plaintext) > len(key):
110-
raise ValueError("Plaintext is larger than key!")
110+
raise ValueError("Padded plaintext is larger than key!")
111111

112112
return one_time_pad(padded_plaintext, key)
113113

@@ -122,7 +122,7 @@ def otp_decrypt_with_padding(ciphertext: bytes, key: bytes) -> bytes:
122122
Returns:
123123
Original plaintext bytes without padding.
124124
"""
125-
plaintext_with_padding = one_time_pad(ciphertext, key)
125+
plaintext_with_padding, _ = one_time_pad(ciphertext, key)
126126

127127
plaintext_length = int.from_bytes(plaintext_with_padding[:OTP_SIZE_LENGTH], "big")
128128

@@ -147,7 +147,9 @@ def one_time_pad(plaintext: bytes, key: bytes) -> bytes:
147147
for index, plain_byte in enumerate(plaintext):
148148
key_byte = key[index]
149149
otpd_plaintext += bytes([plain_byte ^ key_byte])
150-
return otpd_plaintext
150+
151+
key = key[len(otpd_plaintext):]
152+
return otpd_plaintext, key
151153

152154
def generate_kem_keys(algorithm: str) -> Tuple[bytes, bytes]:
153155
"""

logic/message.py

Lines changed: 88 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
from core.requests import http_request
1414
from logic.storage import save_account_data
1515
from logic.pfs import send_new_ephemeral_keys
16-
from core.trad_crypto import sha3_512
16+
from core.trad_crypto import (
17+
sha3_512,
18+
encrypt_xchacha20poly1305,
19+
decrypt_xchacha20poly1305
20+
)
1721
from core.crypto import (
1822
generate_shared_secrets,
1923
decrypt_shared_secrets,
@@ -24,18 +28,21 @@
2428
otp_decrypt_with_padding
2529
)
2630
from core.constants import (
27-
KEYS_HASH_CHAIN_LEN,
31+
MESSAGE_HASH_CHAIN_LEN,
32+
OTP_MAX_BUCKET,
2833
OTP_PAD_SIZE,
2934
OTP_SIZE_LENGTH,
3035
ML_KEM_1024_NAME,
3136
ML_KEM_1024_CT_LEN,
3237
ML_DSA_87_NAME,
3338
ML_DSA_87_SIGN_LEN,
3439
CLASSIC_MCELIECE_8_F_NAME,
35-
CLASSIC_MCELIECE_8_F_CT_LEN
40+
CLASSIC_MCELIECE_8_F_CT_LEN,
41+
XCHACHA20POLY1305_NONCE_LEN
42+
3643
)
3744
from base64 import b64decode, b64encode
38-
import json
45+
import secrets
3946
import logging
4047

4148
logger = logging.getLogger(__name__)
@@ -65,9 +72,10 @@ def generate_and_send_pads(user_data, user_data_lock, contact_id: str, ui_queue)
6572

6673
otp_batch_signature = create_signature(ML_DSA_87_NAME, kyber_ciphertext_blob + mceliece_ciphertext_blob, our_lt_private_key)
6774

75+
hash_chain_seed = secrets.token_bytes(MESSAGE_HASH_CHAIN_LEN)
6876
ciphertext_nonce, ciphertext_blob = encrypt_xchacha20poly1305(
6977
our_strand_key,
70-
b"\x00" + otp_batch_signature + kyber_ciphertext_blob + mceliece_ciphertext_blob
78+
b"\x00" + hash_chain_seed + otp_batch_signature + kyber_ciphertext_blob + mceliece_ciphertext_blob
7179
)
7280

7381

@@ -80,12 +88,14 @@ def generate_and_send_pads(user_data, user_data_lock, contact_id: str, ui_queue)
8088
ui_queue.put({"type": "showerror", "title": "Error", "message": "Failed to send our one-time-pads key batch to the server"})
8189
return False
8290

83-
pads = one_time_pad(kyber_shared_secrets, mceliece_shared_secrets)
91+
pads, _ = one_time_pad(kyber_shared_secrets, mceliece_shared_secrets)
8492

8593
# We update & save only at the end, so if request fails, we do not desync our state.
8694
with user_data_lock:
87-
user_data["contacts"][contact_id]["our_pads"]["pads"] = pads[64:]
88-
user_data["contacts"][contact_id]["our_pads"]["hash_chain"] = pads[:64]
95+
user_data["contacts"][contact_id]["our_strand_key"] = pads[:32]
96+
user_data["contacts"][contact_id]["our_pads"]["pads"] = pads[32:]
97+
98+
user_data["contacts"][contact_id]["our_pads"]["hash_chain"] = hash_chain_seed
8999

90100
save_account_data(user_data, user_data_lock)
91101

@@ -111,9 +121,10 @@ def send_message_processor(user_data, user_data_lock, contact_id: str, message:
111121
contact_kyber_public_key = user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][ML_KEM_1024_NAME]
112122
contact_mceliece_public_key = user_data["contacts"][contact_id]["ephemeral_keys"]["contact_public_keys"][CLASSIC_MCELIECE_8_F_NAME]
113123

114-
our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"]
124+
our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"]
115125

116-
126+
127+
117128
if contact_kyber_public_key is None or contact_mceliece_public_key is None:
118129
logger.debug("This shouldn't happen, contact ephemeral keys are not initialized even once yet???")
119130
ui_queue.put({
@@ -141,53 +152,54 @@ def send_message_processor(user_data, user_data_lock, contact_id: str, message:
141152
our_hash_chain = user_data["contacts"][contact_id]["our_pads"]["hash_chain"]
142153

143154

144-
message_encoded = message.encode("utf-8")
145-
next_hash_chain = sha3_512(our_hash_chain + message_encoded)
146-
message_encoded = next_hash_chain + message_encoded
147-
148-
message_otp_padding_length = max(0, OTP_PADDING_LIMIT - OTP_PADDING_LENGTH - len(message_encoded))
149-
150-
if (len(message_encoded) + OTP_PADDING_LENGTH + message_otp_padding_length) > len(our_pads):
151-
logger.info("Your message size (%d) is larger than our pads size (%s), therefore we are generating new pads for you", len(message_encoded) + OTP_PADDING_LENGTH + message_otp_padding_length, len(our_pads))
152-
153-
if not generate_and_send_pads(user_data, user_data_lock, contact_id, ui_queue):
154-
return False
155-
156-
with user_data_lock:
157-
our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"]
158-
our_hash_chain = user_data["contacts"][contact_id]["our_pads"]["hash_chain"]
159-
160-
# We remove old hashchain from message and calculate new next hash in the chain
161-
message_encoded = message_encoded[64:]
162-
next_hash_chain = sha3_512(our_hash_chain + message_encoded)
163-
message_encoded = next_hash_chain + message_encoded
164-
165155

166-
message_otp_pad = our_pads[:len(message_encoded) + OTP_PADDING_LENGTH + message_otp_padding_length]
156+
while True:
157+
message_encoded = message.encode("utf-8")
158+
try:
159+
# We one-time-pad encrypt the message with padding
160+
#
161+
# NOTE: The padding only protects short-messages which are easy to infer what is said based purely on message length
162+
# With messages larger than padding_limit, we assume the message entropy give enough security to make an adversary assumption
163+
# of message context (almost) useless.
164+
#
165+
message_encrypted, new_pads = otp_encrypt_with_padding(message_encoded, our_pads)
166+
logger.debug("Our old pad size is %d and new size after the message is %d", len(our_pads), len(new_pads))
167+
break
168+
except ValueError as e:
169+
logger.debug("Failed to encrypt message to contact (%s) with error: %s", contact_id, str(e))
170+
logger.info("Your message size (%d) when padded, is larger than our pads size (%s), therefore we are generating new pads for you", len(message), len(our_pads))
171+
172+
if not generate_and_send_pads(user_data, user_data_lock, contact_id, ui_queue):
173+
return False
167174

168-
logger.debug("Our pad size is %d and new size after the message is %d", len(our_pads), len(our_pads) - len(message_otp_pad))
175+
with user_data_lock:
176+
our_pads = user_data["contacts"][contact_id]["our_pads"]["pads"]
177+
our_hash_chain = user_data["contacts"][contact_id]["our_pads"]["hash_chain"]
178+
169179

170-
# We one-time-pad encrypt the message with padding
171-
#
172-
# NOTE: The padding only protects short-messages which are easy to infer what is said based purely on message length
173-
# With messages larger than padding_limit, we assume the message entropy give enough security to make an adversary assumption
174-
# of message context (almost) useless.
175-
#
176-
message_encrypted = otp_encrypt_with_padding(message_encoded, message_otp_pad, padding_limit = message_otp_padding_length)
177-
message_encrypted = b64encode(message_encrypted).decode()
178180

179181
# Unlike in other functions, we truncate pads here and compute the next hash chain regardless of request being successful or not
180182
# because a malicious server could make our requests fail to force us to re-use the same pad for our next message
181183
# which would break all of our security
184+
185+
next_hash_chain = sha3_512(our_hash_chain + message_encrypted)
186+
182187
with user_data_lock:
183-
user_data["contacts"][contact_id]["our_pads"]["pads"] = user_data["contacts"][contact_id]["our_pads"]["pads"][len(message_encoded) + OTP_PADDING_LENGTH + message_otp_padding_length:]
188+
user_data["contacts"][contact_id]["our_pads"]["pads"] = user_data["contacts"][contact_id]["our_pads"]["pads"][len(message_encrypted):]
184189
user_data["contacts"][contact_id]["our_pads"]["hash_chain"] = next_hash_chain
185190

191+
our_strand_key = user_data["contacts"][contact_id]["our_strand_key"]
192+
186193
save_account_data(user_data, user_data_lock)
194+
195+
ciphertext_nonce, ciphertext_blob = encrypt_xchacha20poly1305(
196+
our_strand_key,
197+
b"\x01" + next_hash_chain + message_encrypted
198+
)
187199

188200
try:
189-
http_request(f"{server_url}/messages/send_message", "POST", payload = {
190-
"message_encrypted": message_encrypted,
201+
http_request(f"{server_url}/messages/send", "POST", payload = {
202+
"ciphertext_blob": b64encode(ciphertext_nonce + ciphertext_blob).decode(),
191203
"recipient": contact_id
192204
},
193205
auth_token=auth_token
@@ -243,17 +255,18 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic
243255
logger.error("Failed to decrypt `ciphertext_blob` from contact (%s) with error: %s", contact_id, str(e))
244256
return
245257

246-
# b"\x00" + otp_batch_signature + kyber_ciphertext_blob + mceliece_ciphertext_blob
247258

248259
if msgs_plaintext[0] == 0:
249260
logger.debug("Received a new OTP pads batch from contact (%s).", contact_id)
250261

251-
if len(msgs_plaintext) != ( (ML_KEM_1024_CT_LEN + CLASSIC_MCELIECE_8_F_CT_LEN) * (OTP_PAD_SIZE // 32)) + ML_DSA_87_SIGN_LEN + 1:
252-
logger.error("Contact (%s) gave us a message request with malformed strand plaintext length (%d)", contact_id, len(msgss_plaintext))
262+
if len(msgs_plaintext) != ( (ML_KEM_1024_CT_LEN + CLASSIC_MCELIECE_8_F_CT_LEN) * (OTP_PAD_SIZE // 32)) + ML_DSA_87_SIGN_LEN + MESSAGE_HASH_CHAIN_LEN + 1:
263+
logger.error("Contact (%s) gave us a otp batch message request with malformed strand plaintext length (%d)", contact_id, len(msgs_plaintext))
253264
return
254265

255-
otp_hashchain_signature = msgs_plaintext[:ML_DSA_87_SIGN_LEN]
256-
otp_hashchain_ciphertext = msgs_plaintext[ML_DSA_87_SIGN_LEN:]
266+
otp_hashchain_signature = msgs_plaintext[1 + MESSAGE_HASH_CHAIN_LEN : MESSAGE_HASH_CHAIN_LEN + ML_DSA_87_SIGN_LEN + 1]
267+
otp_hashchain_ciphertext = msgs_plaintext[ML_DSA_87_SIGN_LEN + MESSAGE_HASH_CHAIN_LEN + 1:]
268+
269+
contact_hash_chain = msgs_plaintext[1 : MESSAGE_HASH_CHAIN_LEN + 1]
257270

258271
try:
259272
valid_signature = verify_signature(ML_DSA_87_NAME, otp_hashchain_ciphertext, otp_hashchain_signature, contact_public_key)
@@ -280,10 +293,10 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic
280293
logger.error("Failed to decrypt McEliece's shared_secrets from contact (%s), received error: %s", contact_id, str(e))
281294
return
282295

283-
contact_pads = one_time_pad(contact_kyber_pads, contact_mceliece_pads)
296+
contact_pads, _ = one_time_pad(contact_kyber_pads, contact_mceliece_pads)
284297
contact_strand_key = contact_pads[:32]
285-
contact_hash_chain = contact_pads[32:32 + KEY_HASH_CHAIN_LEN]
286-
contact_pads = contact_pads[32 + KEY_HASH_CHAIN_LEN:]
298+
contact_pads = contact_pads[32:]
299+
287300

288301
with user_data_lock:
289302
user_data["contacts"][contact_id]["contact_pads"]["pads"] = contact_pads
@@ -297,7 +310,7 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic
297310

298311
rotation_counter = user_data["contacts"][contact_id]["ephemeral_keys"]["our_keys"][CLASSIC_MCELIECE_8_F_NAME]["rotation_counter"]
299312

300-
313+
301314
logger.debug("Incremented McEliece's rotation_counter by 1 (now is %d) for contact (%s)", rotation_counter, contact_id)
302315

303316
logger.info("Saved contact (%s) new batch of One-Time-Pads, new strand key, and new hash chain seed", contact_id)
@@ -311,13 +324,30 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic
311324

312325

313326

314-
elif message["msg_type"] == "new_message":
315-
message_encrypted = b64decode(message["message_encrypted"], validate=True)
327+
elif msgs_plaintext[0] == 1:
328+
logger.debug("Received a new message from contact (%s).", contact_id)
329+
330+
if len(msgs_plaintext) < OTP_MAX_BUCKET + MESSAGE_HASH_CHAIN_LEN + 1:
331+
logger.error("Contact (%s) gave us a message request with malformed strand plaintext length (%d)", contact_id, len(msgs_plaintext))
332+
return
333+
334+
335+
hash_chain = msgs_plaintext[1:MESSAGE_HASH_CHAIN_LEN + 1]
336+
message_encrypted = msgs_plaintext[MESSAGE_HASH_CHAIN_LEN + 1:]
316337

338+
317339
with user_data_lock:
318340
contact_pads = user_data["contacts"][contact_id]["contact_pads"]["pads"]
319341
contact_hash_chain = user_data["contacts"][contact_id]["contact_pads"]["hash_chain"]
320342

343+
344+
next_hash_chain = sha3_512(contact_hash_chain + message_encrypted)
345+
346+
if next_hash_chain != hash_chain:
347+
logger.warning("Message hash chain did not match, this could be a possible replay attack, or a failed tampering attempt. Skipping this message...")
348+
return
349+
350+
321351
if (not contact_pads) or (len(message_encrypted) > len(contact_pads)):
322352
# TODO: Maybe reset our local pads as well?
323353
# I feel like we should do something more when we hit this case, but I am not sure.
@@ -328,15 +358,6 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic
328358
# immediately truncate the pads
329359
contact_pads = contact_pads[len(message_encrypted):]
330360

331-
hash_chain = message_decrypted[:64]
332-
message_decrypted = message_decrypted[64:]
333-
334-
next_hash_chain = sha3_512(contact_hash_chain + message_decrypted)
335-
336-
if next_hash_chain != hash_chain:
337-
logger.warning("Message hash chain did not match, this could be a possible replay attack, or a failed tampering attempt. Skipping this message...")
338-
return
339-
340361

341362
# and save the new pads and the hash chain
342363
with user_data_lock:
@@ -358,3 +379,6 @@ def messages_data_handler(user_data: dict, user_data_lock, user_data_copied: dic
358379
"contact_id": contact_id,
359380
"message": message_decoded
360381
})
382+
383+
else:
384+
logger.error("Received unknown message type (%d)", msgs_plaintext[0])

0 commit comments

Comments
 (0)