Skip to content

Commit 76e657d

Browse files
committed
finished correcting tests for hmac library
1 parent 99694a9 commit 76e657d

File tree

5 files changed

+107
-92
lines changed

5 files changed

+107
-92
lines changed

circuitpython-workspaces/flight-software/src/pysquared/cdh.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import traceback
2222

2323
import microcontroller
24+
from circuitpython_hmac import HMAC
2425

2526
from .config.config import Config
2627
from .hardware.radio.packetizer.packet_manager import PacketManager
@@ -35,6 +36,7 @@ class CommandDataHandler:
3536
command_reset: str = "reset"
3637
command_change_radio_modulation: str = "change_radio_modulation"
3738
command_send_joke: str = "send_joke"
39+
command_get_counter: str = "get_counter"
3840

3941
oscar_password: str = "Hello World!" # Default password for OSCAR commands
4042

@@ -45,6 +47,7 @@ def __init__(
4547
packet_manager: PacketManager,
4648
last_command_counter: Counter16 = 1,
4749
send_delay: float = 0.2,
50+
hmac_class=HMAC,
4851
) -> None:
4952
"""Initializes the CommandDataHandler.
5053
@@ -60,7 +63,7 @@ def __init__(
6063
self._packet_manager: PacketManager = packet_manager
6164
self._send_delay: float = send_delay
6265
self._hmac_authenticator: HMACAuthenticator = HMACAuthenticator(
63-
config.hmac_secret
66+
config.hmac_secret, hmac_class=hmac_class
6467
)
6568
self._last_command_counter: Counter16 = last_command_counter
6669

@@ -78,6 +81,7 @@ def listen_for_commands(self, timeout: int) -> None:
7881

7982
try:
8083
json_str = json_bytes.decode("utf-8")
84+
print("Got a message!!!")
8185

8286
msg: dict[str, str] = json.loads(json_str)
8387

@@ -104,8 +108,19 @@ def listen_for_commands(self, timeout: int) -> None:
104108
self.oscar_command(cmd, args)
105109
return
106110

111+
# If message has command field, get the command
112+
cmd = msg.get("command")
113+
114+
print("Command is", cmd)
115+
116+
if cmd is not None and cmd == self.command_get_counter:
117+
self.send_counter()
118+
119+
print("got command")
120+
107121
# HMAC-based authentication (required for non-OSCAR commands)
108-
hmac_value = msg.get("hmac")
122+
hmac_value = str(msg.get("hmac"))
123+
print("TYPETYPETYEPE", type(hmac_value))
109124
counter_raw = msg.get("counter")
110125

111126
# Require HMAC authentication
@@ -116,6 +131,7 @@ def listen_for_commands(self, timeout: int) -> None:
116131
)
117132
return
118133

134+
print("counter and hmac not None")
119135
# Use HMAC authentication
120136
# Convert counter to int
121137
try:
@@ -127,6 +143,8 @@ def listen_for_commands(self, timeout: int) -> None:
127143
)
128144
return
129145

146+
print("counter is an int")
147+
130148
# Validate counter is within 16-bit range
131149
if counter < 0 or counter > 0xFFFF:
132150
self._log.debug(
@@ -135,14 +153,24 @@ def listen_for_commands(self, timeout: int) -> None:
135153
)
136154
return
137155

156+
print("counter is validated")
157+
158+
print("the message actually is", msg)
159+
138160
# Extract message without HMAC for verification
139161
msg_without_hmac = {k: v for k, v in msg.items() if k != "hmac"}
140162
message_str = json.dumps(msg_without_hmac, separators=(",", ":"))
141163

164+
print("message after loop", msg)
165+
166+
print("\nmessage is ", message_str)
167+
print("\ncounter is ", counter)
168+
142169
# Verify HMAC
143170
if not self._hmac_authenticator.verify_hmac(
144171
message_str, counter, hmac_value
145172
):
173+
print("Invalid HMAC inmessgae", msg)
146174
self._log.debug(
147175
"Invalid HMAC in message",
148176
msg=msg,
@@ -151,11 +179,13 @@ def listen_for_commands(self, timeout: int) -> None:
151179

152180
# Prevent replay attacks with wraparound handling
153181
last_valid = self._last_command_counter.get()
182+
print(last_valid, "last valid")
154183

155184
# Check if counter is valid considering 16-bit wraparound
156185
# Accept if counter is greater, or if wraparound occurred
157186
# (counter is much smaller, indicating it wrapped around)
158187
counter_diff = (counter - last_valid) & 0xFFFF
188+
print("counter diff", counter_diff)
159189

160190
# Valid if counter is within forward window (1 to 32768)
161191
# This allows for wraparound while preventing replay attacks
@@ -179,8 +209,6 @@ def listen_for_commands(self, timeout: int) -> None:
179209
)
180210
return
181211

182-
# If message has command field, execute the command
183-
cmd = msg.get("command")
184212
if cmd is None:
185213
self._log.warning("No command found in message", msg=msg)
186214
self._packet_manager.send(
@@ -226,6 +254,12 @@ def send_joke(self) -> None:
226254
self._log.info("Sending joke", joke=joke)
227255
self._packet_manager.send(joke.encode("utf-8"))
228256

257+
def send_counter(self):
258+
"""Send the counter down so the ground station knows how to authenticate"""
259+
counter = str(self._last_command_counter)
260+
self._log.info("Sending Counter", counter=counter)
261+
self._packet_manager.send(counter.encode("utf-8"))
262+
229263
def change_radio_modulation(self, args: list[str]) -> None:
230264
"""Changes the radio modulation.
231265

circuitpython-workspaces/flight-software/src/pysquared/hmac_auth.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,14 @@
2828
class HMACAuthenticator:
2929
"""Provides HMAC authentication for command messages."""
3030

31-
def __init__(self, secret_key: str) -> None:
31+
def __init__(self, secret_key: str, hmac_class=HMAC) -> None:
3232
"""Initializes the HMACAuthenticator.
3333
3434
Args:
3535
secret_key: The shared secret key for HMAC generation and verification.
3636
"""
3737
self._secret_key: bytes = secret_key.encode("utf-8")
38+
self._hmac_class = hmac_class
3839

3940
def generate_hmac(self, message: str, counter: int) -> str:
4041
"""Generates an HMAC for a message with a counter.
@@ -47,38 +48,47 @@ def generate_hmac(self, message: str, counter: int) -> str:
4748
The HMAC as a hexadecimal string.
4849
"""
4950
# Combine message and counter
51+
print("generating hmac")
52+
print("hamc class", self._hmac_class)
5053
data = f"{message}|{counter}".encode("utf-8")
5154

5255
# Generate HMAC using SHA-256
5356
# Note: In CircuitPython, this uses the circuitpython_hmac library
5457
# In testing/CPython, this uses the standard hmac library
55-
h = HMAC(self._secret_key, data, hashlib.sha256)
58+
h = self._hmac_class(self._secret_key, data, hashlib.sha256)
5659
return h.hexdigest()
5760

61+
@staticmethod
5862
def compare_digest(expected_hmac: str, received_hmac: str):
5963
"""Compares two byte or str sequences in constant time.
60-
Returns True if expected_hmac == recieved_hmac, False otherwise.
64+
Returns True if expected_hmac == received_hmac, False otherwise.
6165
"""
62-
if not isinstance(expected_hmac, (bytes, bytearray, str)) or not isinstance(
63-
received_hmac, (bytes, bytearray, str)
64-
):
65-
raise TypeError("compare_digest() expects two bytes or two str objects")
66+
print("comparing digest")
67+
6668
# Convert strings to bytes if both are str
6769
if isinstance(expected_hmac, str) and isinstance(received_hmac, str):
6870
expected_hmac = expected_hmac.encode("utf-8")
6971
received_hmac = received_hmac.encode("utf-8")
70-
elif isinstance(expected_hmac, str) or isinstance(received_hmac, str):
71-
raise TypeError("Both inputs must be of the same type")
72+
73+
print("hehhehehee")
7274
# Ensure both are bytes/bytearray at this point
75+
print("expected hmac", expected_hmac)
76+
77+
print("expected hmac", type(expected_hmac))
7378
if len(expected_hmac) != len(received_hmac):
79+
print("lens are da same")
7480
# Continue processing full length to keep timing consistent
7581
result = 0
7682
maxlen = max(len(expected_hmac), len(received_hmac))
7783
for i in range(maxlen):
7884
x = expected_hmac[i] if i < len(expected_hmac) else 0
7985
y = received_hmac[i] if i < len(received_hmac) else 0
8086
result |= x ^ y
87+
print("returning False")
8188
return False
89+
else:
90+
print("not the if")
91+
print("result is")
8292
result = 0
8393
for x, y in zip(expected_hmac, received_hmac):
8494
result |= x ^ y
@@ -95,5 +105,9 @@ def verify_hmac(self, message: str, counter: int, received_hmac: str) -> bool:
95105
Returns:
96106
True if the HMAC is valid, False otherwise.
97107
"""
108+
print("verifying hmac")
98109
expected_hmac = self.generate_hmac(message, counter)
110+
print("expected hmac", expected_hmac)
111+
print("expected hmac type", type(expected_hmac))
112+
99113
return HMACAuthenticator.compare_digest(expected_hmac, received_hmac)

cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_cdh.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
initialization, command parsing, and execution of various commands.
66
"""
77

8+
import hmac
89
import json
910
from unittest.mock import MagicMock, patch
1011

@@ -57,6 +58,7 @@ def cdh(
5758
config=mock_config,
5859
packet_manager=mock_packet_manager,
5960
last_command_counter=mock_counter16,
61+
hmac_class=hmac.new,
6062
)
6163

6264

0 commit comments

Comments
 (0)