diff --git a/circuitpython-workspaces/flight-software/pyproject.toml b/circuitpython-workspaces/flight-software/pyproject.toml index a12dccc9..7ebc4662 100644 --- a/circuitpython-workspaces/flight-software/pyproject.toml +++ b/circuitpython-workspaces/flight-software/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "adafruit-circuitpython-ticks==1.1.1", "adafruit-circuitpython-veml7700==2.1.4", "adafruit-circuitpython-hashlib==1.4.19", + "circuitpython-hmac @ git+https://github.com/jimbobbennett/CircuitPython_HMAC.git", "proves-circuitpython-sx126 @ git+https://github.com/proveskit/micropySX126X@1.0.0", "proves-circuitpython-sx1280 @ git+https://github.com/proveskit/CircuitPython_SX1280@1.0.4", ] diff --git a/circuitpython-workspaces/flight-software/src/pysquared/cdh.py b/circuitpython-workspaces/flight-software/src/pysquared/cdh.py index 099fc2b2..d937f2e9 100644 --- a/circuitpython-workspaces/flight-software/src/pysquared/cdh.py +++ b/circuitpython-workspaces/flight-software/src/pysquared/cdh.py @@ -21,10 +21,19 @@ import traceback import microcontroller +from circuitpython_hmac import HMAC from .config.config import Config +from .config.jokes_config import JokesConfig from .hardware.radio.packetizer.packet_manager import PacketManager +from .hmac_auth import HMACAuthenticator from .logger import Logger +from .nvm.counter import Counter16 + +try: + from typing import Callable, Optional +except Exception: + pass class CommandDataHandler: @@ -33,6 +42,7 @@ class CommandDataHandler: command_reset: str = "reset" command_change_radio_modulation: str = "change_radio_modulation" command_send_joke: str = "send_joke" + command_get_counter: str = "get_counter" oscar_password: str = "Hello World!" # Default password for OSCAR commands @@ -41,7 +51,10 @@ def __init__( logger: Logger, config: Config, packet_manager: PacketManager, + jokes_config: JokesConfig, + last_command_counter: Optional[Counter16] = None, send_delay: float = 0.2, + hmac_class: Callable = HMAC, ) -> None: """Initializes the CommandDataHandler. @@ -49,12 +62,21 @@ def __init__( logger: The logger to use. config: The configuration to use. packet_manager: The packet manager to use for sending and receiving data. + last_command_counter: NVM counter tracking the last valid command counter (16-bit). send_delay: The delay between sending an acknowledgement and the response. """ self._log: Logger = logger self._config: Config = config + self._jokes_config: JokesConfig = jokes_config self._packet_manager: PacketManager = packet_manager self._send_delay: float = send_delay + self._hmac_authenticator: HMACAuthenticator = HMACAuthenticator( + config.hmac_secret, hmac_class=hmac_class + ) + if last_command_counter is not None: + self._last_command_counter: Counter16 = last_command_counter + else: + self._last_command_counter = Counter16(0) def listen_for_commands(self, timeout: int) -> None: """Listens for commands from the radio and handles them. @@ -66,14 +88,17 @@ def listen_for_commands(self, timeout: int) -> None: json_bytes = self._packet_manager.listen(timeout) if json_bytes is None: + self._log.debug("Nothing Found in the packet") return + self._log.debug("Found stuff in the packet") try: json_str = json_bytes.decode("utf-8") + self._log.debug("json string in the message is", stingis=json_str) msg: dict[str, str] = json.loads(json_str) - # Check for OSCAR password first + # Check for OSCAR password first (legacy authentication) if msg.get("password") == self.oscar_password: self._log.debug("OSCAR command received", msg=msg) cmd = msg.get("command") @@ -89,21 +114,108 @@ def listen_for_commands(self, timeout: int) -> None: if isinstance(raw_args, list): args: list[str] = raw_args - # Delay to give the ground station time to switch to listening mode - time.sleep(self._send_delay) - self._packet_manager.send_acknowledgement() - self.oscar_command(cmd, args) return - # If message has password field, check it - if msg.get("password") != self._config.super_secret_code: + # If message has command field, get the command + cmd = msg.get("command") + + if cmd is not None and cmd == self.command_get_counter: + self.send_counter() + return + + # HMAC-based authentication (required for non-OSCAR commands) + hmac_value = str(msg.get("hmac")) + counter_raw = msg.get("counter") + + self._log.debug( + "Current command, counter and hash", + command=cmd, + counter=counter_raw, + hasht=hmac_value, + ) + + # Require HMAC authentication + if hmac_value is None or counter_raw is None: self._log.debug( - "Invalid password in message", + "Missing HMAC or counter in message", msg=msg, ) return + # Use HMAC authentication + # Convert counter to int + self._log.debug("The counter is", counter=counter_raw) + + try: + counter: int = int(counter_raw) + except (ValueError, TypeError): + self._log.debug( + "Invalid counter in message", + counter=counter_raw, + ) + return + + # Validate counter is within 16-bit range + if counter < 0 or counter > 0xFFFF: + self._log.debug( + "Counter out of range", + counter=counter, + ) + return + + self._log.debug("counter validated") + + # Extract message without HMAC for verification + msg_without_hmac = {k: v for k, v in msg.items() if k != "hmac"} + message_str = json.dumps(msg_without_hmac, separators=(",", ":")) + + self._log.debug("messagestring is", msrg=message_str) + self._log.debug( + "hmac valid details", hmac=hmac_value, typeis=type(hmac_value) + ) + + # Verify HMAC + if not self._hmac_authenticator.verify_hmac( + message_str, counter, hmac_value + ): + self._log.debug( + "Invalid HMAC in message", + msg=msg, + ) + return + print("OUT") + self._log.debug("passed the authenticate compeint") + # Prevent replay attacks with wraparound handling + last_valid = self._last_command_counter.get() + self._log.debug("last valid is", lv=last_valid) + + # Check if counter is valid considering 16-bit wraparound + # Accept if counter is greater, or if wraparound occurred + # (counter is much smaller, indicating it wrapped around) + counter_diff = (counter - last_valid) & 0xFFFF + + # Valid if counter is within forward window (1 to 32768) + # This allows for wraparound while preventing replay attacks + if counter_diff == 0 or counter_diff > 0x8000: + self._log.debug( + "Replay attack detected - invalid counter", + counter=counter, + last_valid=last_valid, + diff=counter_diff, + ) + return + + # Update last valid counter in NVM + self._last_command_counter.set(counter) + + self._log.debug( + "Comparing names", + name1=msg.get("name"), + nameconfig=self._config.cubesat_name, + ) + + # Verify satellite name if msg.get("name") != self._config.cubesat_name: self._log.debug( "Satellite name mismatch in message", @@ -111,8 +223,8 @@ def listen_for_commands(self, timeout: int) -> None: ) return - # If message has command field, execute the command - cmd = msg.get("command") + self._log.debug("Names are the same") + if cmd is None: self._log.warning("No command found in message", msg=msg) self._packet_manager.send( @@ -120,6 +232,8 @@ def listen_for_commands(self, timeout: int) -> None: ) return + self._log.debug("COmmand is not none") + args: list[str] = [] raw_args = msg.get("args") if isinstance(raw_args, list): @@ -131,6 +245,8 @@ def listen_for_commands(self, timeout: int) -> None: time.sleep(self._send_delay) self._packet_manager.send_acknowledgement() + self._log.debug("Sent Acknowledgement", cmd=cmd, args=args) + if cmd == self.command_reset: self.reset() elif cmd == self.command_change_radio_modulation: @@ -154,10 +270,23 @@ def listen_for_commands(self, timeout: int) -> None: def send_joke(self) -> None: """Sends a random joke from the config.""" - joke = random.choice(self._config.jokes) + joke = random.choice(self._jokes_config.jokes) self._log.info("Sending joke", joke=joke) self._packet_manager.send(joke.encode("utf-8")) + def send_counter(self): + """Send the counter down so the ground station knows how to authenticate""" + + time.sleep(self._send_delay) + self._packet_manager.send_acknowledgement() + + # Additional delay to ensure ACK packet transmission completes before sending counter + time.sleep(self._send_delay) + + counter = str(self._last_command_counter.get()) + self._log.info("Sending Counter", counter=counter) + self._packet_manager.send(counter.encode("utf-8")) + def change_radio_modulation(self, args: list[str]) -> None: """Changes the radio modulation. @@ -191,6 +320,8 @@ def change_radio_modulation(self, args: list[str]) -> None: def reset(self) -> None: """Resets the hardware.""" + # Delay to give the ground station time to switch to listening mode + self._log.info("Resetting satellite") self._packet_manager.send(data="Resetting satellite".encode("utf-8")) microcontroller.on_next_reset(microcontroller.RunMode.NORMAL) @@ -203,6 +334,10 @@ def oscar_command(self, command: str, args: list[str]) -> None: command: The OSCAR command to execute. args: A list of arguments for the command. """ + time.sleep(self._send_delay) + + self._packet_manager.send_acknowledgement() + if command == "ping": self._log.info("OSCAR ping command received. Sending pong response.") self._packet_manager.send( diff --git a/circuitpython-workspaces/flight-software/src/pysquared/config/config.py b/circuitpython-workspaces/flight-software/src/pysquared/config/config.py index 40289b8e..45183a83 100644 --- a/circuitpython-workspaces/flight-software/src/pysquared/config/config.py +++ b/circuitpython-workspaces/flight-software/src/pysquared/config/config.py @@ -47,8 +47,9 @@ class Config: critical_battery_voltage (float): Critical battery voltage. reboot_time (int): Time before reboot in seconds. turbo_clock (bool): Turbo clock enabled flag. - super_secret_code (str): Secret code for special operations. + super_secret_code (str): Secret code for special operations (deprecated). repeat_code (str): Code for repeated operations. + hmac_secret (str): Shared secret for HMAC command authentication. longest_allowable_sleep_time (int): Maximum allowable sleep time. CONFIG_SCHEMA (dict): Validation schema for configuration keys. @@ -114,6 +115,7 @@ def __init__(self, config_path: str) -> None: self.turbo_clock: bool = json_data["turbo_clock"] self.super_secret_code: str = json_data["super_secret_code"] self.repeat_code: str = json_data["repeat_code"] + self.hmac_secret: str = json_data.get("hmac_secret", "default_hmac_secret") self.longest_allowable_sleep_time: int = json_data[ "longest_allowable_sleep_time" ] @@ -122,6 +124,7 @@ def __init__(self, config_path: str) -> None: "cubesat_name": {"type": str, "min_length": 1, "max_length": 10}, "super_secret_code": {"type": bytes, "min": 1, "max": 24}, "repeat_code": {"type": bytes, "min": 1, "max": 4}, + "hmac_secret": {"type": bytes, "min": 16, "max": 64}, "normal_charge_current": {"type": float, "min": 0.0, "max": 2000.0}, "normal_battery_voltage": {"type": float, "min": 6.0, "max": 8.4}, "degraded_battery_voltage": {"type": float, "min": 5.4, "max": 8.0}, diff --git a/circuitpython-workspaces/flight-software/src/pysquared/config/jokes_config.py b/circuitpython-workspaces/flight-software/src/pysquared/config/jokes_config.py new file mode 100644 index 00000000..45cc1481 --- /dev/null +++ b/circuitpython-workspaces/flight-software/src/pysquared/config/jokes_config.py @@ -0,0 +1,154 @@ +""" +This module provides the JokesConfig class, which handles loading, +validating, and updating jokes from a jokes.json file. + +Classes: + JokesConfig: Loads jokes from a JSON file, validates joke entries, + allows updating jokes (temporary or permanent), and saves changes. +""" + +import json + + +class JokesConfig: + """ + Jokes configuration handler. + + Loads jokes from a JSON file as a list of strings. Provides methods + to validate jokes, update them temporarily (RAM-only) or permanently + (saved back to file), and retrieve jokes. + + Attributes: + jokes_file (str): Path to the jokes JSON file. + jokes (list[str]): List of joke strings. + + Methods: + validate_joke(joke: str): + Validates a single joke string. + add_joke(joke: str, temporary: bool = True): + Adds a new joke. + update_joke(index: int, joke: str, temporary: bool = True): + Updates an existing joke by index. + remove_joke(index: int, temporary: bool = True): + Removes a joke by index. + save(): + Saves current jokes list to file. + """ + + def __init__(self, jokes_file: str) -> None: + """ + Initialize JokesConfig by loading jokes from a JSON file. + + Args: + jokes_file (str): Path to the jokes.json file. + + Raises: + FileNotFoundError: If the jokes file does not exist. + json.JSONDecodeError: If the file content is not valid JSON. + ValueError: If the JSON content is not a list of strings. + """ + self.jokes_file = jokes_file + + with open(self.jokes_file, "r") as f: + data = json.load(f) + + if not isinstance(data, list) or not all( + isinstance(joke, str) for joke in data + ): + raise ValueError("jokes.json must be a list of strings") + + self.jokes = data + + def validate_joke(self, joke: str) -> None: + """ + Validate a joke string. + + Args: + joke (str): The joke to validate. + + Raises: + TypeError: If joke is not a string. + ValueError: If joke is empty or too long. + """ + if not isinstance(joke, str): + raise TypeError("Joke must be a string") + if len(joke.strip()) == 0: + raise ValueError("Joke cannot be empty") + if len(joke) > 500: + raise ValueError("Joke is too long (max 500 characters)") + + def add_joke(self, joke: str, temporary: bool = True) -> None: + """ + Add a new joke to the list. + + Args: + joke (str): Joke string to add. + temporary (bool): If True, add only in RAM. If False, save to file. + """ + self.validate_joke(joke) + self.jokes.append(joke) + + if not temporary: + self.save() + + def update_joke(self, index: int, joke: str, temporary: bool = True) -> None: + """ + Update an existing joke by its index. + + Args: + index (int): Index of the joke to update. + joke (str): New joke string. + temporary (bool): If True, update only in RAM. If False, save to file. + + Raises: + IndexError: If index is out of range. + """ + if index < 0 or index >= len(self.jokes): + raise IndexError("Joke index out of range") + self.validate_joke(joke) + self.jokes[index] = joke + + if not temporary: + self.save() + + def remove_joke(self, index: int, temporary: bool = True) -> None: + """ + Remove a joke by index. + + Args: + index (int): Index of the joke to remove. + temporary (bool): If True, remove only in RAM. If False, save to file. + + Raises: + IndexError: If index is out of range. + """ + if index < 0 or index >= len(self.jokes): + raise IndexError("Joke index out of range") + del self.jokes[index] + + if not temporary: + self.save() + + def save(self) -> None: + """ + Save the current jokes list back to the JSON file. + """ + with open(self.jokes_file, "w") as f: + json.dump(self.jokes, f) + + def get_joke(self, index: int) -> str: + """ + Retrieve a joke by its index. + + Args: + index (int): Index of the joke to retrieve. + + Returns: + str: The joke string at the specified index. + + Raises: + IndexError: If index is out of range. + """ + if index < 0 or index >= len(self.jokes): + raise IndexError("Joke index out of range") + return self.jokes[index] diff --git a/circuitpython-workspaces/flight-software/src/pysquared/hardware/radio/packetizer/packet_manager.py b/circuitpython-workspaces/flight-software/src/pysquared/hardware/radio/packetizer/packet_manager.py index 49bee9e8..5a4842a0 100644 --- a/circuitpython-workspaces/flight-software/src/pysquared/hardware/radio/packetizer/packet_manager.py +++ b/circuitpython-workspaces/flight-software/src/pysquared/hardware/radio/packetizer/packet_manager.py @@ -167,11 +167,17 @@ def listen(self, timeout: Optional[int] = None) -> bytes | None: packet_identifier, _, total_packets, _ = self._get_header(packet) # Log received packets + payload = self._get_payload(packet) + try: + payload_str = payload.decode("utf-8") + except ValueError: + payload_str = payload + self._logger.debug( "Received packet", packet_length=len(packet), header=self._get_header(packet), - payload=self._get_payload(packet), + payload=payload_str, ) if received_packets: @@ -199,6 +205,7 @@ def listen(self, timeout: Optional[int] = None) -> bytes | None: def send_acknowledgement(self) -> None: """Sends an acknowledgment to the radio.""" self.send(b"ACK") + print("sending acknowledgment packet") self._logger.debug("Sent acknowledgment packet") def get_last_rssi(self) -> int: diff --git a/circuitpython-workspaces/flight-software/src/pysquared/hmac_auth.py b/circuitpython-workspaces/flight-software/src/pysquared/hmac_auth.py new file mode 100644 index 00000000..1881d410 --- /dev/null +++ b/circuitpython-workspaces/flight-software/src/pysquared/hmac_auth.py @@ -0,0 +1,100 @@ +"""This module provides HMAC-based authentication for command messages. + +This module implements HMAC (Hash-based Message Authentication Code) for +authenticating commands sent to the satellite. It provides protection against +unauthorized commands and replay attacks through the use of a shared secret +and packet counter. + +**Usage:** +```python +from pysquared.hmac_auth import HMACAuthenticator + +# On the ground station +authenticator = HMACAuthenticator("shared_secret_key") +message = '{"command": "send_joke", "name": "MySat"}' +counter = 42 +hmac_value = authenticator.generate_hmac(message, counter) + +# On the satellite +authenticator = HMACAuthenticator("shared_secret_key") +is_valid = authenticator.verify_hmac(message, counter, hmac_value) +``` +""" + +import adafruit_hashlib as hashlib # interesting, this lib imports cpython stuff if it's available.... hmmm +from circuitpython_hmac import HMAC + +try: + from typing import Callable +except Exception: + pass + + +class HMACAuthenticator: + """Provides HMAC authentication for command messages.""" + + def __init__(self, secret_key: str, hmac_class: Callable = HMAC) -> None: + """Initializes the HMACAuthenticator. + + Args: + secret_key: The shared secret key for HMAC generation and verification. + """ + self._secret_key: bytes = secret_key.encode("utf-8") + self._hmac_class = hmac_class + + def generate_hmac(self, message: str, counter: int) -> str: + """Generates an HMAC for a message with a counter. + + Args: + message: The message to authenticate. + counter: The packet counter for replay attack prevention. + + Returns: + The HMAC as a hexadecimal string. + """ + # Combine message and counter + data = f"{message}|{counter}".encode("utf-8") + + # Generate HMAC using SHA-256 + # Note: In CircuitPython, this uses the circuitpython_hmac library + # In testing/CPython, this uses the standard hmac library + h = self._hmac_class(self._secret_key, data, hashlib.sha256) + return h.hexdigest() + + @staticmethod + def compare_digest(expected_hmac: str, received_hmac: str): + """Compares str sequences in constant time. + Returns True if expected_hmac == received_hmac, False otherwise. + """ + + assert isinstance(expected_hmac, str) + assert isinstance(received_hmac, str) + + print("execpted", expected_hmac) + print("received", received_hmac) + + return expected_hmac == received_hmac + + def verify_hmac(self, message: str, counter: int, received_hmac: str) -> bool: + """Verifies an HMAC for a message with a counter. + + Args: + message: The message to verify. + counter: The packet counter for replay attack prevention. + received_hmac: The HMAC to verify. + + Returns: + True if the HMAC is valid, False otherwise. + """ + print( + "generate an HMAC with message:", + message, + "counter:", + counter, + "secret key", + self._secret_key, + ) + expected_hmac = self.generate_hmac(message, counter) + res = HMACAuthenticator.compare_digest(expected_hmac, received_hmac) + print(res) + return res diff --git a/circuitpython-workspaces/flight-software/src/pysquared/nvm/counter.py b/circuitpython-workspaces/flight-software/src/pysquared/nvm/counter.py index 8d4eddf7..81708802 100644 --- a/circuitpython-workspaces/flight-software/src/pysquared/nvm/counter.py +++ b/circuitpython-workspaces/flight-software/src/pysquared/nvm/counter.py @@ -55,3 +55,75 @@ def get_name(self) -> str: get_name returns the name of the counter """ return f"{self.__class__.__name__}_index_{self._index}" + + +class Counter16: + """ + Counter class for managing 16-bit counters stored in non-volatile memory. + + Uses two consecutive bytes in NVM to store a 16-bit counter value. + This provides a larger range (0-65535) before wraparound occurs. + + Attributes: + _index (int): The starting index of the counter in the NVM datastore. + _datastore (microcontroller.nvm.ByteArray): The NVM datastore. + """ + + def __init__( + self, + index: int, + ) -> None: + """ + Initializes a Counter16 instance. + + Args: + index (int): The starting index of the counter in the datastore. + Uses two consecutive bytes (index and index+1). + + Raises: + ValueError: If NVM is not available. + """ + self._index = index + + if microcontroller.nvm is None: + raise ValueError("nvm is not available") + + self._datastore = microcontroller.nvm + + def get(self) -> int: + """ + Returns the value of the counter. + + Returns: + int: The current value of the counter (0-65535). + """ + # Read two bytes: high byte at _index, low byte at _index+1 + high_byte = self._datastore[self._index] + low_byte = self._datastore[self._index + 1] + return (high_byte << 8) | low_byte + + def set(self, value: int) -> None: + """ + Sets the counter to a specific value. + + Args: + value: The value to set (0-65535). + """ + value = value & 0xFFFF # Ensure 16-bit value + high_byte = (value >> 8) & 0xFF + low_byte = value & 0xFF + self._datastore[self._index] = high_byte + self._datastore[self._index + 1] = low_byte + + def increment(self) -> None: + """ + Increases the counter by one, with 16-bit rollover. + """ + value: int = (self.get() + 1) & 0xFFFF # 16-bit counter with rollover + self.set(value) + + def get_name(self) -> str: + """ + get_name returns the name of the counter + """ + return f"{self.__class__.__name__}_index_{self._index}" diff --git a/circuitpython-workspaces/ground-station/src/ground_station/ground_station.py b/circuitpython-workspaces/ground-station/src/ground_station/ground_station.py index 20f1964b..85856329 100644 --- a/circuitpython-workspaces/ground-station/src/ground_station/ground_station.py +++ b/circuitpython-workspaces/ground-station/src/ground_station/ground_station.py @@ -9,6 +9,7 @@ from pysquared.cdh import CommandDataHandler from pysquared.config.config import Config from pysquared.hardware.radio.packetizer.packet_manager import PacketManager +from pysquared.hmac_auth import HMACAuthenticator from pysquared.logger import Logger @@ -21,12 +22,15 @@ def __init__( config: Config, packet_manager: PacketManager, cdh: CommandDataHandler, + starting_counter: int = 0, ): self._log = logger self._log.colorized = True self._config = config self._packet_manager = packet_manager self._cdh = cdh + self._hmac_authenticator = HMACAuthenticator(config.hmac_secret) + self._command_counter = starting_counter # Counter for replay attack prevention def listen(self): """Listen for incoming packets from the satellite.""" @@ -59,6 +63,7 @@ def send_receive(self): | 2: Change radio modulation | | 3: Send joke | | 4: OSCAR commands | + | 5: Request counter | =============================== """ ) @@ -75,7 +80,7 @@ def handle_input(self, cmd_selection): Args: cmd_selection: The command selection input by the user. """ - if cmd_selection not in ["1", "2", "3", "4"]: + if cmd_selection not in ["1", "2", "3", "4", "5"]: self._log.warning("Invalid command selection. Please try again.") return @@ -84,11 +89,20 @@ def handle_input(self, cmd_selection): self.handle_oscar_commands() return + if cmd_selection == "5": + self.handle_counter_request() + return + message: dict[str, object] = { "name": self._config.cubesat_name, - "password": self._config.super_secret_code, } + if self._command_counter == 0: + self._log.info( + "Command Counter not set, please request counter before sending commands" + ) + return + if cmd_selection == "1": message["command"] = self._cdh.command_reset elif cmd_selection == "2": @@ -98,15 +112,29 @@ def handle_input(self, cmd_selection): elif cmd_selection == "3": message["command"] = self._cdh.command_send_joke + # Increment counter for replay attack prevention + self._command_counter += 1 + message["counter"] = self._command_counter + + # Generate HMAC for the message + message_str = json.dumps(message, separators=(",", ":")) + print("gen hmac with GROUND STATION", message_str, self._command_counter) + hmac_value = self._hmac_authenticator.generate_hmac( + message_str, self._command_counter + ) + message["hmac"] = hmac_value + while True: # Turn on the radio so that it captures any received packets to buffer self._packet_manager.listen(1) # Send the message self._log.info( - "Sending command", + "\n________\nSending command NOW\n_________\n", cmd=message["command"], args=message.get("args", []), + counter=self._command_counter, + hmac=message["hmac"], ) self._packet_manager.send(json.dumps(message).encode("utf-8")) @@ -207,8 +235,59 @@ def handle_oscar_commands(self): except KeyboardInterrupt: self._log.debug("Keyboard interrupt received, exiting OSCAR mode.") + def handle_counter_request(self): + """ + Handle Counter Request by asking the satellite what its current counter is + """ + message: dict[str, object] = {"command": "get_counter"} + + try: + while True: + # Turn on the radio so that it captures any received packets to buffer + self._packet_manager.listen(1) + + # Send the OSCAR message + self._log.info( + "Sending counter request", + cmd=message["command"], + ) + self._packet_manager.send(json.dumps(message).encode("utf-8")) + + # Listen for ACK response + b = self._packet_manager.listen(1) + if b is None: + self._log.info("No response received, retrying...") + continue + + if b != b"ACK": + self._log.info( + "No ACK response received, retrying...", + response=b.decode("utf-8"), + ) + continue + + self._log.info("Received ACK") + + # Now listen for the actual response + b = self._packet_manager.listen(1) + if b is None: + self._log.info("No response received, retrying...") + continue + + self._log.info("Received counter response", response=b.decode("utf-8")) + current_counter = b.decode("utf-8") + self._command_counter = int(current_counter) + self._log.info("current counter set to", counter=current_counter) + + break + + except KeyboardInterrupt: + self._log.debug("Keyboard interrupt received, exiting OSCAR mode.") + def run(self): """Run the ground station interface.""" + # Prompt for starting counter value + while True: print( """ @@ -221,13 +300,14 @@ def run(self): | Please Select Your Mode | | 'A': Listen | | 'B': Send | + | 'C': Manually Set Counter | ============================= """ ) device_selection = input().lower() - if device_selection not in ["a", "b"]: + if device_selection not in ["a", "b", "c"]: self._log.warning("Invalid Selection. Please try again.") continue @@ -235,5 +315,26 @@ def run(self): self.listen() elif device_selection == "b": self.send_receive() - - time.sleep(1) + elif device_selection == "c": + while True: + try: + cmd_selection = input( + """ + ======================================= + | Type Counter Count you want to set | + ======================================= + > """ + ) + + cmd_selection = int(cmd_selection) # Convert input to in + self._command_counter = cmd_selection + self._log.debug(f"Command counter set to {cmd_selection}") + break # Exit loop after successful input + + except ValueError: + self._log.debug("Invalid input. Please enter an integer.") + except KeyboardInterrupt: + self._log.debug("Keyboard interrupt received, exiting.") + break + + time.sleep(1) diff --git a/cpython-workspaces/flight-software-unit-tests/src/config.json b/cpython-workspaces/flight-software-unit-tests/src/config.json index b3008f78..35107d5d 100644 --- a/cpython-workspaces/flight-software-unit-tests/src/config.json +++ b/cpython-workspaces/flight-software-unit-tests/src/config.json @@ -9,6 +9,7 @@ "detumble_enable_y": true, "detumble_enable_z": true, "heating": false, + "hmac_secret": "test_hmac_secret_key", "last_battery_temp": 20.0, "longest_allowable_sleep_time": 600, "normal_battery_temp": 1, diff --git a/cpython-workspaces/flight-software-unit-tests/src/unit-tests/hardware/radio/packetizer/test_packet_manager.py b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/hardware/radio/packetizer/test_packet_manager.py index a37d9a00..44419107 100644 --- a/cpython-workspaces/flight-software-unit-tests/src/unit-tests/hardware/radio/packetizer/test_packet_manager.py +++ b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/hardware/radio/packetizer/test_packet_manager.py @@ -327,6 +327,8 @@ def test_receive_success(mock_time, mock_logger, mock_radio, mock_message_counte expected_data = b"first second" assert result == expected_data + print(mock_logger.debug.call_args_list) + # Verify proper logging mock_logger.debug.assert_any_call("Listening for data...", timeout=10) mock_logger.debug.assert_any_call( diff --git a/cpython-workspaces/flight-software-unit-tests/src/unit-tests/nvm/test_counter16.py b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/nvm/test_counter16.py new file mode 100644 index 00000000..93d831e3 --- /dev/null +++ b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/nvm/test_counter16.py @@ -0,0 +1,145 @@ +"""Unit tests for the Counter16 class.""" + +from unittest.mock import MagicMock, patch + +import pysquared.nvm.counter as counter +import pytest +from mocks.circuitpython.byte_array import ByteArray + + +@patch("pysquared.nvm.counter.microcontroller") +def test_counter16_init(mock_microcontroller: MagicMock): + """Tests Counter16 initialization. + + Args: + mock_microcontroller: Mocked microcontroller module. + """ + datastore = ByteArray(size=512) + mock_microcontroller.nvm = datastore + + index = 0 + count = counter.Counter16(index) + assert count.get() == 0 + + +@patch("pysquared.nvm.counter.microcontroller") +def test_counter16_independent_counters(mock_microcontroller: MagicMock): + """Tests that two counters maintain independent values. + + Args: + mock_microcontroller: Mocked microcontroller module. + """ + datastore = ByteArray(size=512) + mock_microcontroller.nvm = datastore + + count_1 = counter.Counter16(0) + count_2 = counter.Counter16(2) # Start at index 2 (skipping 0 and 1) + + count_2.increment() + assert count_1.get() == 0 + assert count_2.get() == 1 + + +@patch("pysquared.nvm.counter.microcontroller") +def test_counter16_no_nvm(mock_microcontroller: MagicMock): + """Tests Counter16 initialization failure when NVM is not available. + + Args: + mock_microcontroller: Mocked microcontroller module. + """ + mock_microcontroller.nvm = None + with pytest.raises(ValueError): + counter.Counter16(0) + + +@patch("pysquared.nvm.counter.microcontroller") +def test_counter16_set_get(mock_microcontroller: MagicMock): + """Tests Counter16 set and get methods. + + Args: + mock_microcontroller: Mocked microcontroller module. + """ + datastore = ByteArray(size=512) + mock_microcontroller.nvm = datastore + + count = counter.Counter16(0) + count.set(1234) + assert count.get() == 1234 + + +@patch("pysquared.nvm.counter.microcontroller") +def test_counter16_increment(mock_microcontroller: MagicMock): + """Tests Counter16 increment. + + Args: + mock_microcontroller: Mocked microcontroller module. + """ + datastore = ByteArray(size=512) + mock_microcontroller.nvm = datastore + + count = counter.Counter16(0) + count.set(0) + count.increment() + assert count.get() == 1 + + +@patch("pysquared.nvm.counter.microcontroller") +def test_counter16_rollover(mock_microcontroller: MagicMock): + """Tests Counter16 16-bit rollover. + + Args: + mock_microcontroller: Mocked microcontroller module. + """ + datastore = ByteArray(size=512) + mock_microcontroller.nvm = datastore + + count = counter.Counter16(0) + count.set(0xFFFF) # Max 16-bit value + count.increment() + assert count.get() == 0 # Should roll over to 0 + + +@patch("pysquared.nvm.counter.microcontroller") +def test_counter16_large_value(mock_microcontroller: MagicMock): + """Tests Counter16 with large values. + + Args: + mock_microcontroller: Mocked microcontroller module. + """ + datastore = ByteArray(size=512) + mock_microcontroller.nvm = datastore + + count = counter.Counter16(0) + count.set(65535) # Max 16-bit value + assert count.get() == 65535 + + +@patch("pysquared.nvm.counter.microcontroller") +def test_counter16_multiple_increments(mock_microcontroller: MagicMock): + """Tests Counter16 multiple increments. + + Args: + mock_microcontroller: Mocked microcontroller module. + """ + datastore = ByteArray(size=512) + mock_microcontroller.nvm = datastore + + count = counter.Counter16(0) + count.set(0) + for i in range(100): + count.increment() + assert count.get() == 100 + + +@patch("pysquared.nvm.counter.microcontroller") +def test_counter16_get_name(mock_microcontroller: MagicMock): + """Tests Counter16 get_name method. + + Args: + mock_microcontroller: Mocked microcontroller module. + """ + datastore = ByteArray(size=512) + mock_microcontroller.nvm = datastore + + count = counter.Counter16(5) + assert count.get_name() == "Counter16_index_5" diff --git a/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_cdh.py b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_cdh.py index f8f1dcb6..d45bf563 100644 --- a/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_cdh.py +++ b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_cdh.py @@ -5,14 +5,17 @@ initialization, command parsing, and execution of various commands. """ +import hmac import json -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from pysquared.cdh import CommandDataHandler from pysquared.config.config import Config +from pysquared.config.jokes_config import JokesConfig from pysquared.hardware.radio.packetizer.packet_manager import PacketManager from pysquared.logger import Logger +from pysquared.nvm.counter import Counter16 @pytest.fixture @@ -33,33 +36,62 @@ def mock_config() -> Config: config = MagicMock(spec=Config) config.super_secret_code = "test_password" config.cubesat_name = "test_satellite" - config.jokes = ["Why did the satellite cross the orbit? To get to the other side!"] + config.hmac_secret = "test_hmac_secret" return config @pytest.fixture -def cdh(mock_logger, mock_config, mock_packet_manager) -> CommandDataHandler: +def mock_joke_config() -> JokesConfig: + """Mocks the JokesConfig class""" + joke_config = MagicMock(spec=JokesConfig) + joke_config.jokes = ["Why did the the chicken cross the asteroid belt"] + return joke_config + + +@pytest.fixture +def mock_counter16() -> Counter16: + """Mocks the Counter16 class.""" + counter = MagicMock(spec=Counter16) + counter.get.return_value = 0 # Default to 0 + return counter + + +@pytest.fixture +def cdh( + mock_logger, mock_config, mock_packet_manager, mock_counter16, mock_joke_config +) -> CommandDataHandler: """Provides a CommandDataHandler instance for testing.""" return CommandDataHandler( logger=mock_logger, config=mock_config, packet_manager=mock_packet_manager, + jokes_config=mock_joke_config, + last_command_counter=mock_counter16, + hmac_class=hmac.new, ) -def test_cdh_init(mock_logger, mock_config, mock_packet_manager): +def test_cdh_init( + mock_logger, mock_config, mock_packet_manager, mock_joke_config, mock_counter16 +): """Tests CommandDataHandler initialization. Args: mock_logger: Mocked Logger instance. mock_config: Mocked Config instance. mock_packet_manager: Mocked PacketManager instance. + mock_joke_config: Mocked joke confog + mock_counter16: Mocked Counter16 instance. """ - cdh = CommandDataHandler(mock_logger, mock_config, mock_packet_manager) + cdh = CommandDataHandler( + mock_logger, mock_config, mock_packet_manager, mock_joke_config, mock_counter16 + ) assert cdh._log is mock_logger assert cdh._config is mock_config assert cdh._packet_manager is mock_packet_manager + assert cdh._last_command_counter is mock_counter16 + assert cdh._jokes_config is mock_joke_config def test_listen_for_commands_no_message(cdh, mock_packet_manager): @@ -77,69 +109,77 @@ def test_listen_for_commands_no_message(cdh, mock_packet_manager): # If no message received, function should simply return -def test_listen_for_commands_invalid_password(cdh, mock_packet_manager, mock_logger): - """Tests listen_for_commands with invalid password. +def test_listen_for_commands_missing_hmac(cdh, mock_packet_manager, mock_logger): + """Tests listen_for_commands with missing HMAC/counter. Args: cdh: CommandDataHandler instance. mock_packet_manager: Mocked PacketManager instance. mock_logger: Mocked Logger instance. """ - # Create a message with wrong password + # Create a message without HMAC or counter (old password-based) message = {"password": "wrong_password", "command": "send_joke", "args": []} mock_packet_manager.listen.return_value = json.dumps(message).encode("utf-8") cdh.listen_for_commands(30) mock_packet_manager.listen.assert_called_once_with(30) - mock_logger.debug.assert_any_call("Invalid password in message", msg=message) + mock_logger.debug.assert_any_call("Missing HMAC or counter in message", msg=message) -def test_listen_for_commands_invalid_name(cdh, mock_packet_manager, mock_logger): - """Tests listen_for_commands with missing command field. +def test_listen_for_commands_missing_hmac_name_check( + cdh, mock_packet_manager, mock_logger +): + """Tests listen_for_commands with missing HMAC (satellite name irrelevant without HMAC). Args: cdh: CommandDataHandler instance. mock_packet_manager: Mocked PacketManager instance. mock_logger: Mocked Logger instance. """ - # Create a message with valid password and satellite name but no command + # Create a message without HMAC message = {"password": "test_password", "name": "wrong_name", "args": []} mock_packet_manager.listen.return_value = json.dumps(message).encode("utf-8") cdh.listen_for_commands(30) mock_packet_manager.listen.assert_called_once_with(30) - mock_logger.debug.assert_any_call("Satellite name mismatch in message", msg=message) + # Should reject due to missing HMAC, not name mismatch + mock_logger.debug.assert_any_call("Missing HMAC or counter in message", msg=message) -def test_listen_for_commands_missing_command(cdh, mock_packet_manager, mock_logger): - """Tests listen_for_commands with missing command field. +def test_listen_for_commands_missing_command_no_hmac( + cdh, mock_packet_manager, mock_logger +): + """Tests listen_for_commands with missing HMAC. Args: cdh: CommandDataHandler instance. mock_packet_manager: Mocked PacketManager instance. mock_logger: Mocked Logger instance. """ - # Create a message with valid password but no command + # Create a message without HMAC message = {"password": "test_password", "name": "test_satellite", "args": []} mock_packet_manager.listen.return_value = json.dumps(message).encode("utf-8") cdh.listen_for_commands(30) mock_packet_manager.listen.assert_called_once_with(30) - mock_logger.warning.assert_any_call("No command found in message", msg=message) + # Should reject due to missing HMAC + mock_logger.debug.assert_any_call("Missing HMAC or counter in message", msg=message) -def test_listen_for_commands_nonlist_args(cdh, mock_packet_manager, mock_logger): - """Tests listen_for_commands with missing command field. +def test_listen_for_commands_nonlist_args_no_hmac( + cdh, mock_packet_manager, mock_logger +): + """Tests listen_for_commands with missing HMAC. Args: cdh: CommandDataHandler instance. mock_packet_manager: Mocked PacketManager instance. mock_logger: Mocked Logger instance. """ - # Create a message with valid password but no command + # Create a message without HMAC message = { "password": "test_password", "name": "test_satellite", @@ -151,9 +191,8 @@ def test_listen_for_commands_nonlist_args(cdh, mock_packet_manager, mock_logger) cdh.listen_for_commands(30) mock_packet_manager.listen.assert_called_once_with(30) - mock_logger.debug.assert_any_call( - "Received command message", cmd="send_joke", args=[] - ) + # Should reject due to missing HMAC + mock_logger.debug.assert_any_call("Missing HMAC or counter in message", msg=message) def test_listen_for_commands_invalid_json(cdh, mock_packet_manager, mock_logger): @@ -175,7 +214,7 @@ def test_listen_for_commands_invalid_json(cdh, mock_packet_manager, mock_logger) @patch("random.choice") -def test_send_joke(mock_random_choice, cdh, mock_packet_manager, mock_config): +def test_send_joke(mock_random_choice, cdh, mock_packet_manager, mock_joke_config): """Tests the send_joke method. Args: @@ -184,13 +223,13 @@ def test_send_joke(mock_random_choice, cdh, mock_packet_manager, mock_config): mock_packet_manager: Mocked PacketManager instance. mock_config: Mocked Config instance. """ - mock_random_choice.return_value = mock_config.jokes[0] + mock_random_choice.return_value = mock_joke_config.jokes[0] cdh.send_joke() - mock_random_choice.assert_called_once_with(mock_config.jokes) + mock_random_choice.assert_called_once_with(mock_joke_config.jokes) mock_packet_manager.send.assert_called_once_with( - mock_config.jokes[0].encode("utf-8") + mock_joke_config.jokes[0].encode("utf-8") ) @@ -291,18 +330,24 @@ def test_listen_for_commands_reset( mock_microcontroller: Mocked microcontroller module. cdh: CommandDataHandler instance. mock_packet_manager: Mocked PacketManager instance. + mock_counter16: Mocked Counter16 instance. """ # Set up mocked attributes mock_microcontroller.reset = MagicMock() mock_microcontroller.on_next_reset = MagicMock() - message = { - "password": "test_password", + counter = 1 + message_dict = { "name": "test_satellite", "command": "reset", "args": [], + "counter": counter, } - mock_packet_manager.listen.return_value = json.dumps(message).encode("utf-8") + message_str = json.dumps(message_dict, separators=(",", ":")) + hmac_value = cdh._hmac_authenticator.generate_hmac(message_str, counter) + message_dict["hmac"] = hmac_value + + mock_packet_manager.listen.return_value = json.dumps(message_dict).encode("utf-8") cdh.listen_for_commands(30) @@ -315,7 +360,12 @@ def test_listen_for_commands_reset( @patch("time.sleep") @patch("random.choice") def test_listen_for_commands_send_joke( - mock_random_choice, mock_sleep, cdh, mock_packet_manager, mock_config + mock_random_choice, + mock_sleep, + cdh, + mock_packet_manager, + mock_joke_config, + mock_counter16, ): """Tests listen_for_commands with send_joke command. @@ -324,26 +374,32 @@ def test_listen_for_commands_send_joke( cdh: CommandDataHandler instance. mock_packet_manager: Mocked PacketManager instance. mock_config: Mocked Config instance. + mock_counter16: Mocked Counter16 instance. """ - message = { - "password": "test_password", + counter = 1 + message_dict = { "name": "test_satellite", "command": "send_joke", "args": [], + "counter": counter, } - mock_packet_manager.listen.return_value = json.dumps(message).encode("utf-8") - mock_random_choice.return_value = mock_config.jokes[0] + message_str = json.dumps(message_dict, separators=(",", ":")) + hmac_value = cdh._hmac_authenticator.generate_hmac(message_str, counter) + message_dict["hmac"] = hmac_value + + mock_packet_manager.listen.return_value = json.dumps(message_dict).encode("utf-8") + mock_random_choice.return_value = mock_joke_config.jokes[0] cdh.listen_for_commands(30) mock_packet_manager.send.assert_called_once_with( - mock_config.jokes[0].encode("utf-8") + mock_joke_config.jokes[0].encode("utf-8") ) @patch("time.sleep") def test_listen_for_commands_change_radio_modulation( - mock_sleep, cdh, mock_packet_manager, mock_config + mock_sleep, cdh, mock_packet_manager, mock_config, mock_counter16 ): """Tests listen_for_commands with change_radio_modulation command. @@ -351,14 +407,20 @@ def test_listen_for_commands_change_radio_modulation( cdh: CommandDataHandler instance. mock_packet_manager: Mocked PacketManager instance. mock_config: Mocked Config instance. + mock_counter16: Mocked Counter16 instance. """ - message = { - "password": "test_password", + counter = 1 + message_dict = { "name": "test_satellite", "command": "change_radio_modulation", "args": ["FSK"], + "counter": counter, } - mock_packet_manager.listen.return_value = json.dumps(message).encode("utf-8") + message_str = json.dumps(message_dict, separators=(",", ":")) + hmac_value = cdh._hmac_authenticator.generate_hmac(message_str, counter) + message_dict["hmac"] = hmac_value + + mock_packet_manager.listen.return_value = json.dumps(message_dict).encode("utf-8") cdh.listen_for_commands(30) @@ -369,7 +431,7 @@ def test_listen_for_commands_change_radio_modulation( @patch("time.sleep") def test_listen_for_commands_unknown_command( - mock_sleep, cdh, mock_packet_manager, mock_logger + mock_sleep, cdh, mock_packet_manager, mock_logger, mock_counter16 ): """Tests listen_for_commands with an unknown command. @@ -377,14 +439,20 @@ def test_listen_for_commands_unknown_command( cdh: CommandDataHandler instance. mock_packet_manager: Mocked PacketManager instance. mock_logger: Mocked Logger instance. + mock_counter16: Mocked Counter16 instance. """ - message = { - "password": "test_password", + counter = 1 + message_dict = { "name": "test_satellite", "command": "unknown_command", "args": [], + "counter": counter, } - mock_packet_manager.listen.return_value = json.dumps(message).encode("utf-8") + message_str = json.dumps(message_dict, separators=(",", ":")) + hmac_value = cdh._hmac_authenticator.generate_hmac(message_str, counter) + message_dict["hmac"] = hmac_value + + mock_packet_manager.listen.return_value = json.dumps(message_dict).encode("utf-8") cdh.listen_for_commands(30) @@ -608,3 +676,343 @@ def test_listen_for_commands_oscar_ping_integration( # Verify ping response was sent mock_packet_manager.send.assert_called_once_with("Pong! -82".encode("utf-8")) + + +# HMAC Authentication Tests + + +@patch("time.sleep") +def test_listen_for_commands_valid_hmac( + mock_sleep, cdh, mock_packet_manager, mock_logger, mock_counter16 +): + """Tests listen_for_commands with valid HMAC authentication. + + Args: + mock_sleep: Mocked time.sleep function. + cdh: CommandDataHandler instance. + mock_packet_manager: Mocked PacketManager instance. + mock_logger: Mocked Logger instance. + mock_counter16: Mocked Counter16 instance. + """ + counter = 1 + message_dict = { + "name": "test_satellite", + "command": "send_joke", + "args": [], + "counter": counter, + } + message_str = json.dumps(message_dict, separators=(",", ":")) + + # Generate valid HMAC + hmac_value = cdh._hmac_authenticator.generate_hmac(message_str, counter) + message_dict["hmac"] = hmac_value + + mock_packet_manager.listen.return_value = json.dumps(message_dict).encode("utf-8") + + cdh.listen_for_commands(30) + + # Verify acknowledgement was sent + mock_packet_manager.send_acknowledgement.assert_called_once() + + # Verify command was executed + mock_packet_manager.send.assert_called_once() + + # Verify counter was updated in NVM + mock_counter16.set.assert_called_once_with(counter) + + +@patch("time.sleep") +def test_listen_for_commands_invalid_hmac( + mock_sleep, cdh, mock_packet_manager, mock_logger +): + """Tests listen_for_commands with invalid HMAC. + + Args: + mock_sleep: Mocked time.sleep function. + cdh: CommandDataHandler instance. + mock_packet_manager: Mocked PacketManager instance. + mock_logger: Mocked Logger instance. + """ + counter = 2 + message = { + "name": "test_satellite", + "command": "send_joke", + "args": [], + "counter": counter, + "hmac": "invalid_hmac_value_0000000000000000000000000000000000000000", + } + + mock_packet_manager.listen.return_value = json.dumps(message).encode("utf-8") + + cdh.listen_for_commands(30) + + # Verify the message was rejected + mock_logger.debug.assert_any_call("Invalid HMAC in message", msg=message) + + # Verify acknowledgement was NOT sent + mock_packet_manager.send_acknowledgement.assert_not_called() + + +@patch("time.sleep") +def test_listen_for_commands_replay_attack( + mock_sleep, cdh, mock_packet_manager, mock_logger, mock_counter16 +): + """Tests that replay attacks are prevented. + + Args: + mock_sleep: Mocked time.sleep function. + cdh: CommandDataHandler instance. + mock_packet_manager: Mocked PacketManager instance. + mock_logger: Mocked Logger instance. + mock_counter16: Mocked Counter16 instance. + """ + # First, send a valid command + counter1 = 10 + mock_counter16.get.return_value = 0 # Initial counter value + + message_dict1 = { + "name": "test_satellite", + "command": "send_joke", + "args": [], + "counter": counter1, + } + message_str1 = json.dumps(message_dict1, separators=(",", ":")) + hmac_value1 = cdh._hmac_authenticator.generate_hmac(message_str1, counter1) + message_dict1["hmac"] = hmac_value1 + + mock_packet_manager.listen.return_value = json.dumps(message_dict1).encode("utf-8") + cdh.listen_for_commands(30) + + # Verify counter was updated + mock_counter16.set.assert_called_with(counter1) + + # Now set the counter to the stored value for replay + mock_counter16.get.return_value = counter1 + + # Try to replay the same command (same counter) + mock_packet_manager.listen.return_value = json.dumps(message_dict1).encode("utf-8") + cdh.listen_for_commands(30) + + # Verify the replay was rejected (counter_diff == 0) + mock_logger.debug.assert_any_call( + "Replay attack detected - invalid counter", + counter=counter1, + last_valid=counter1, + diff=0, + ) + + +@patch("time.sleep") +def test_listen_for_commands_old_counter( + mock_sleep, cdh, mock_packet_manager, mock_logger, mock_counter16 +): + """Tests that old counter values are rejected. + + Args: + mock_sleep: Mocked time.sleep function. + cdh: CommandDataHandler instance. + mock_packet_manager: Mocked PacketManager instance. + mock_logger: Mocked Logger instance. + mock_counter16: Mocked Counter16 instance. + """ + # Set last valid counter to 20 + mock_counter16.get.return_value = 20 + + # Try to send a command with counter 15 (older than last valid) + counter = 15 + message_dict = { + "name": "test_satellite", + "command": "send_joke", + "args": [], + "counter": counter, + } + message_str = json.dumps(message_dict, separators=(",", ":")) + hmac_value = cdh._hmac_authenticator.generate_hmac(message_str, counter) + message_dict["hmac"] = hmac_value + + mock_packet_manager.listen.return_value = json.dumps(message_dict).encode("utf-8") + + cdh.listen_for_commands(30) + + # Verify the command was rejected (counter_diff > 0x8000 means backwards) + mock_logger.debug.assert_any_call( + "Replay attack detected - invalid counter", + counter=counter, + last_valid=20, + diff=(15 - 20) & 0xFFFF, # 65531 which is > 0x8000 + ) + + # Verify acknowledgement was NOT sent + mock_packet_manager.send_acknowledgement.assert_not_called() + + +@patch("time.sleep") +def test_listen_for_commands_hmac_with_wrong_satellite_name( + mock_sleep, cdh, mock_packet_manager, mock_logger +): + """Tests HMAC authentication with wrong satellite name. + + Args: + mock_sleep: Mocked time.sleep function. + cdh: CommandDataHandler instance. + mock_packet_manager: Mocked PacketManager instance. + mock_logger: Mocked Logger instance. + """ + counter = 30 + message_dict = { + "name": "wrong_satellite", + "command": "send_joke", + "args": [], + "counter": counter, + } + message_str = json.dumps(message_dict, separators=(",", ":")) + hmac_value = cdh._hmac_authenticator.generate_hmac(message_str, counter) + message_dict["hmac"] = hmac_value + + mock_packet_manager.listen.return_value = json.dumps(message_dict).encode("utf-8") + + cdh.listen_for_commands(30) + + # Verify the message was rejected due to name mismatch + mock_logger.debug.assert_any_call( + "Satellite name mismatch in message", msg=message_dict + ) + + # Verify acknowledgement was NOT sent + mock_packet_manager.send_acknowledgement.assert_not_called() + + +@patch("time.sleep") +@patch("pysquared.cdh.microcontroller") +def test_listen_for_commands_reset_with_hmac( + mock_microcontroller, mock_sleep, cdh, mock_packet_manager, mock_logger +): + """Tests listen_for_commands with reset command using HMAC authentication. + + Args: + mock_microcontroller: Mocked microcontroller module. + mock_sleep: Mocked time.sleep function. + cdh: CommandDataHandler instance. + mock_packet_manager: Mocked PacketManager instance. + mock_logger: Mocked Logger instance. + """ + # Set up mocked attributes + mock_microcontroller.reset = MagicMock() + mock_microcontroller.on_next_reset = MagicMock() + + counter = 40 + message_dict = { + "name": "test_satellite", + "command": "reset", + "args": [], + "counter": counter, + } + message_str = json.dumps(message_dict, separators=(",", ":")) + hmac_value = cdh._hmac_authenticator.generate_hmac(message_str, counter) + message_dict["hmac"] = hmac_value + + mock_packet_manager.listen.return_value = json.dumps(message_dict).encode("utf-8") + + cdh.listen_for_commands(30) + + # Verify acknowledgement was sent + mock_packet_manager.send_acknowledgement.assert_called_once() + + # Verify reset was called + mock_microcontroller.on_next_reset.assert_called_once_with( + mock_microcontroller.RunMode.NORMAL + ) + mock_microcontroller.reset.assert_called_once() + + +@patch("time.sleep") +def test_listen_for_commands_counter_wraparound( + mock_sleep, cdh, mock_packet_manager, mock_logger, mock_counter16 +): + """Tests that counter wraparound is handled correctly. + + Args: + mock_sleep: Mocked time.sleep function. + cdh: CommandDataHandler instance. + mock_packet_manager: Mocked PacketManager instance. + mock_logger: Mocked Logger instance. + mock_counter16: Mocked Counter16 instance. + """ + # Set last valid counter near max value + mock_counter16.get.return_value = 65530 + + # Send command with wrapped counter (small value) + counter = 5 + message_dict = { + "name": "test_satellite", + "command": "send_joke", + "args": [], + "counter": counter, + } + message_str = json.dumps(message_dict, separators=(",", ":")) + hmac_value = cdh._hmac_authenticator.generate_hmac(message_str, counter) + message_dict["hmac"] = hmac_value + + mock_packet_manager.listen.return_value = json.dumps(message_dict).encode("utf-8") + + cdh.listen_for_commands(30) + + # Verify command was accepted (wraparound handled correctly) + mock_packet_manager.send_acknowledgement.assert_called_once() + mock_counter16.set.assert_called_once_with(counter) + + +@patch("time.sleep") +def test_listen_for_commands_counter_out_of_range( + mock_sleep, cdh, mock_packet_manager, mock_logger, mock_counter16 +): + """Tests that out-of-range counter values are rejected. + + Args: + mock_sleep: Mocked time.sleep function. + cdh: CommandDataHandler instance. + mock_packet_manager: Mocked PacketManager instance. + mock_logger: Mocked Logger instance. + mock_counter16: Mocked Counter16 instance. + """ + # Send command with out-of-range counter + counter = 70000 # > 65535 + message_dict = { + "name": "test_satellite", + "command": "send_joke", + "args": [], + "counter": counter, + } + message_str = json.dumps(message_dict, separators=(",", ":")) + hmac_value = cdh._hmac_authenticator.generate_hmac(message_str, counter) + message_dict["hmac"] = hmac_value + + mock_packet_manager.listen.return_value = json.dumps(message_dict).encode("utf-8") + + cdh.listen_for_commands(30) + + # Verify command was rejected + mock_logger.debug.assert_any_call("Counter out of range", counter=counter) + mock_packet_manager.send_acknowledgement.assert_not_called() + + +def test_listen_for_commands_triggers_send_counter( + cdh, mock_packet_manager, mock_logger +): + """Tests that listen_for_commands triggers send_counter when command is 'get_counter'. + + Args: + cdh: CommandDataHandler instance. + mock_packet_manager: Mocked PacketManager instance. + mock_logger: Mocked Logger instance. + mocker: Pytest-mock fixture for patching. + """ + # Prepare the message with command "get_counter" + message = {"command": cdh.command_get_counter} + mock_packet_manager.listen.return_value = json.dumps(message).encode("utf-8") + + # Patch send_counter method to monitor its call + + cdh.listen_for_commands(30) + + mock_logger.info.assert_any_call("Sending Counter", counter=ANY) diff --git a/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_hmac_auth.py b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_hmac_auth.py new file mode 100644 index 00000000..10747727 --- /dev/null +++ b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_hmac_auth.py @@ -0,0 +1,194 @@ +"""Unit tests for the HMACAuthenticator class. + +This module contains unit tests for the `HMACAuthenticator` class, which is +responsible for generating and verifying HMAC signatures for command messages. +""" + +import hmac + +from pysquared.hmac_auth import HMACAuthenticator + + +def test_hmac_authenticator_init(): + """ + This initizalises the HMACauthenticator and ensures the correct safety key is saved + """ + secret_key = "test_secret" + authenticator = HMACAuthenticator(secret_key, hmac_class=hmac.new) + + assert authenticator._secret_key == secret_key.encode("utf-8") + + +def test_generate_hmac(): + """ + this ensures the generate hmac function reyurns a string of the correct size and format + """ + secret_key = "test_secret" + authenticator = HMACAuthenticator(secret_key, hmac_class=hmac.new) + + message = '{"command": "send_joke", "name": "TestSat"}' + counter = 42 + + hmac_value = authenticator.generate_hmac(message, counter) + + # Check the HMAC is valid hex and 64 characters + assert isinstance(hmac_value, str) + assert len(hmac_value) == 64 + assert all(c in "0123456789abcdef" for c in hmac_value) + + +def test_generate_hmac_consistency(): + """Tests that HMAC generation is consistent for the same inputs.""" + secret_key = "test_secret" + authenticator = HMACAuthenticator(secret_key, hmac_class=hmac.new) + + message = '{"command": "reset"}' + counter = 100 + + hmac1 = authenticator.generate_hmac(message, counter) + hmac2 = authenticator.generate_hmac(message, counter) + + assert hmac1 == hmac2 + + +def test_generate_hmac_different_messages(): + """Tests that different messages produce different HMACs.""" + secret_key = "test_secret" + authenticator = HMACAuthenticator(secret_key, hmac_class=hmac.new) + + message1 = '{"command": "send_joke"}' + message2 = '{"command": "reset"}' + counter = 50 + + hmac1 = authenticator.generate_hmac(message1, counter) + hmac2 = authenticator.generate_hmac(message2, counter) + + assert hmac1 != hmac2 + + +def test_generate_hmac_different_counters(): + """Tests that different counters produce different HMACs.""" + secret_key = "test_secret" + authenticator = HMACAuthenticator(secret_key, hmac_class=hmac.new) + + message = '{"command": "send_joke"}' + counter1 = 10 + counter2 = 11 + + hmac1 = authenticator.generate_hmac(message, counter1) + hmac2 = authenticator.generate_hmac(message, counter2) + + assert hmac1 != hmac2 + + +def test_generate_hmac_different_secrets(): + """Tests that different secrets produce different HMACs.""" + message = '{"command": "send_joke"}' + counter = 25 + + authenticator1 = HMACAuthenticator("secret1", hmac_class=hmac.new) + authenticator2 = HMACAuthenticator("secret2", hmac_class=hmac.new) + + hmac1 = authenticator1.generate_hmac(message, counter) + hmac2 = authenticator2.generate_hmac(message, counter) + + assert hmac1 != hmac2 + + +def test_verify_hmac_valid(): + """Tests HMAC verification with valid HMAC.""" + secret_key = "test_secret" + authenticator = HMACAuthenticator(secret_key, hmac_class=hmac.new) + + message = '{"command": "send_joke"}' + counter = 75 + + hmac_value = authenticator.generate_hmac(message, counter) + is_valid = authenticator.verify_hmac(message, counter, hmac_value) + + assert is_valid is True + + +def test_verify_hmac_invalid(): + """Tests HMAC verification with invalid HMAC.""" + secret_key = "test_secret" + authenticator = HMACAuthenticator(secret_key, hmac_class=hmac.new) + + message = '{"command": "send_joke"}' + counter = 75 + + # Use a fake HMAC + fake_hmac = "0" * 64 + is_valid = authenticator.verify_hmac(message, counter, fake_hmac) + + assert is_valid is False + + +def test_verify_hmac_wrong_message(): + """Tests HMAC verification fails when message is modified.""" + secret_key = "test_secret" + authenticator = HMACAuthenticator(secret_key, hmac_class=hmac.new) + + original_message = '{"command": "send_joke"}' + modified_message = '{"command": "reset"}' + counter = 80 + + hmac_value = authenticator.generate_hmac(original_message, counter) + is_valid = authenticator.verify_hmac(modified_message, counter, hmac_value) + + assert is_valid is False + + +def test_verify_hmac_wrong_counter(): + """Tests HMAC verification fails when counter is modified (replay attack).""" + secret_key = "test_secret" + authenticator = HMACAuthenticator(secret_key, hmac_class=hmac.new) + + message = '{"command": "send_joke"}' + original_counter = 90 + modified_counter = 89 + + hmac_value = authenticator.generate_hmac(message, original_counter) + is_valid = authenticator.verify_hmac(message, modified_counter, hmac_value) + + assert is_valid is False + + +def test_verify_hmac_wrong_secret(): + """Tests HMAC verification fails when secret is different.""" + message = '{"command": "send_joke"}' + counter = 95 + + authenticator1 = HMACAuthenticator("secret1", hmac_class=hmac.new) + authenticator2 = HMACAuthenticator("secret2", hmac_class=hmac.new) + + hmac_value = authenticator1.generate_hmac(message, counter) + is_valid = authenticator2.verify_hmac(message, counter, hmac_value) + + assert is_valid is False + + +def test_compare_digest_equal_strings(): + """Returns True when strings are identical.""" + assert HMACAuthenticator.compare_digest("abcdef", "abcdef") is True + + +def test_compare_digest_different_strings_same_length(): + """Returns False when strings differ but are same length.""" + assert HMACAuthenticator.compare_digest("abcdef", "abcdeg") is False + + +def test_compare_digest_different_length_strings(): + """Returns False when strings are of different lengths.""" + assert HMACAuthenticator.compare_digest("abcdef", "abcde") is False + + +def test_compare_digest_empty_strings(): + """Returns True when both are empty strings.""" + assert HMACAuthenticator.compare_digest("", "") is True + + +def test_compare_digest_one_empty_string(): + """Returns False when one is empty and the other is not.""" + assert HMACAuthenticator.compare_digest("", "a") is False + assert HMACAuthenticator.compare_digest("a", "") is False diff --git a/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_hmac_integration.py b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_hmac_integration.py new file mode 100644 index 00000000..229c8de6 --- /dev/null +++ b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_hmac_integration.py @@ -0,0 +1,445 @@ +"""Integration tests for HMAC command authentication between ground station and flight software. + +This module contains integration tests that verify the complete HMAC authentication +flow between the ground station and the satellite's command data handler. +""" + +import hmac +import json +from unittest.mock import MagicMock, patch + +import pytest +from pysquared.cdh import CommandDataHandler +from pysquared.config.config import Config +from pysquared.config.jokes_config import JokesConfig +from pysquared.hardware.radio.packetizer.packet_manager import PacketManager +from pysquared.hmac_auth import HMACAuthenticator +from pysquared.logger import Logger +from pysquared.nvm.counter import Counter16 + + +@pytest.fixture +def mock_logger(): + """Mocks the Logger class.""" + return MagicMock(spec=Logger) + + +@pytest.fixture +def mock_packet_manager(): + """Mocks the PacketManager class.""" + return MagicMock(spec=PacketManager) + + +@pytest.fixture +def mock_config(): + """Mocks the Config class.""" + config = MagicMock(spec=Config) + config.cubesat_name = "TestSat" + config.hmac_secret = "shared_secret_key_123" + config.jokes = ["Why did the satellite go to school? To improve its orbit-ude!"] + return config + + +@pytest.fixture +def mock_counter16(): + """Mocks the Counter16 class.""" + counter = MagicMock(spec=Counter16) + counter.get.return_value = 0 # Start at 0 + return counter + + +@pytest.fixture +def mock_joke_config() -> JokesConfig: + """Mocks the JokesConfig class""" + joke_config = MagicMock(spec=JokesConfig) + joke_config.jokes = ["Why did the the chicken cross the asteroid belt"] + return joke_config + + +@pytest.fixture +def flight_software_cdh( + mock_logger, mock_config, mock_packet_manager, mock_joke_config, mock_counter16 +): + """Provides a CommandDataHandler instance (flight software side).""" + return CommandDataHandler( + logger=mock_logger, + config=mock_config, + packet_manager=mock_packet_manager, + jokes_config=mock_joke_config, + last_command_counter=mock_counter16, + hmac_class=hmac.new, + ) + + +@pytest.fixture +def ground_station_authenticator(mock_config): + """Provides an HMACAuthenticator instance (ground station side).""" + return HMACAuthenticator(mock_config.hmac_secret, hmac_class=hmac.new) + + +def test_hmac_integration_valid_command( + flight_software_cdh, + ground_station_authenticator, + mock_joke_config, + mock_packet_manager, + mock_counter16, +): + """Tests successful HMAC authentication flow from ground station to flight software. + + Args: + flight_software_cdh: Flight software CDH instance. + ground_station_authenticator: Ground station HMAC authenticator. + mock_config: Mocked Config instance. + mock_packet_manager: Mocked PacketManager instance. + mock_counter16: Mocked Counter16 instance. + """ + # Ground station creates a command + gs_counter = 1 + command_message = { + "name": "TestSat", + "command": "send_joke", + "args": [], + "counter": gs_counter, + } + + # Ground station generates HMAC + message_str = json.dumps(command_message, separators=(",", ":")) + hmac_value = ground_station_authenticator.generate_hmac(message_str, gs_counter) + command_message["hmac"] = hmac_value + + # Ground station sends the command (simulated) + command_bytes = json.dumps(command_message).encode("utf-8") + + # Flight software receives the command + mock_packet_manager.listen.return_value = command_bytes + + # Flight software processes the command + with patch("time.sleep"): + flight_software_cdh.listen_for_commands(timeout=30) + + # Verify the command was accepted + mock_packet_manager.send_acknowledgement.assert_called_once() + + # Verify counter was updated in NVM + mock_counter16.set.assert_called_once_with(gs_counter) + + # Verify the joke was sent + mock_packet_manager.send.assert_called_once() + sent_data = mock_packet_manager.send.call_args[0][0] + assert sent_data == mock_joke_config.jokes[0].encode("utf-8") + + +def test_hmac_integration_invalid_hmac( + flight_software_cdh, + ground_station_authenticator, + mock_config, + mock_packet_manager, + mock_logger, +): + """Tests that flight software rejects commands with invalid HMAC. + + Args: + flight_software_cdh: Flight software CDH instance. + ground_station_authenticator: Ground station HMAC authenticator. + mock_config: Mocked Config instance. + mock_packet_manager: Mocked PacketManager instance. + mock_logger: Mocked Logger instance. + """ + # Ground station creates a command + gs_counter = 2 + command_message = { + "name": "TestSat", + "command": "send_joke", + "args": [], + "counter": gs_counter, + } + + # Attacker modifies the command but keeps the original HMAC + message_str = json.dumps(command_message, separators=(",", ":")) + hmac_value = ground_station_authenticator.generate_hmac(message_str, gs_counter) + + # Modify the command (tampering) + command_message["command"] = "reset" # Changed by attacker + command_message["hmac"] = id(hmac_value) # Original HMAC (now invalid) + + # Attacker sends the tampered command + command_bytes = json.dumps(command_message).encode("utf-8") + + # Flight software receives the tampered command + mock_packet_manager.listen.return_value = command_bytes + + # Flight software processes the command + with patch("time.sleep"): + flight_software_cdh.listen_for_commands(timeout=30) + + # Verify the command was rejected + mock_logger.debug.assert_any_call("Invalid HMAC in message", msg=command_message) + + mock_packet_manager.send_acknowledgement.assert_not_called() + + +def test_hmac_integration_replay_attack( + flight_software_cdh, + ground_station_authenticator, + mock_config, + mock_packet_manager, + mock_counter16, + mock_logger, +): + """Tests that flight software prevents replay attacks. + + Args: + flight_software_cdh: Flight software CDH instance. + ground_station_authenticator: Ground station HMAC authenticator. + mock_config: Mocked Config instance. + mock_packet_manager: Mocked PacketManager instance. + mock_counter16: Mocked Counter16 instance. + mock_logger: Mocked Logger instance. + """ + # Ground station sends first command + gs_counter = 10 + command_message = { + "name": "TestSat", + "command": "send_joke", + "args": [], + "counter": gs_counter, + } + + message_str = json.dumps(command_message, separators=(",", ":")) + hmac_value = ground_station_authenticator.generate_hmac(message_str, gs_counter) + command_message["hmac"] = hmac_value + command_bytes = json.dumps(command_message).encode("utf-8") + + # Flight software receives and accepts first command + mock_packet_manager.listen.return_value = command_bytes + with patch("time.sleep"): + flight_software_cdh.listen_for_commands(timeout=30) + + # Verify first command was accepted + assert mock_counter16.set.call_count == 1 + + # Update counter to simulate NVM storage + mock_counter16.get.return_value = gs_counter + + # Attacker tries to replay the same command + mock_packet_manager.listen.return_value = command_bytes + with patch("time.sleep"): + flight_software_cdh.listen_for_commands(timeout=30) + + # Verify replay was rejected + mock_logger.debug.assert_any_call( + "Replay attack detected - invalid counter", + counter=gs_counter, + last_valid=gs_counter, + diff=0, + ) + + +def test_hmac_integration_counter_sequence( + flight_software_cdh, + ground_station_authenticator, + mock_config, + mock_packet_manager, + mock_counter16, +): + """Tests multiple commands with incrementing counters. + + Args: + flight_software_cdh: Flight software CDH instance. + ground_station_authenticator: Ground station HMAC authenticator. + mock_config: Mocked Config instance. + mock_packet_manager: Mocked PacketManager instance. + mock_counter16: Mocked Counter16 instance. + """ + # Simulate sending multiple commands with incrementing counters + for gs_counter in [1, 2, 3, 4, 5]: + command_message = { + "name": "TestSat", + "command": "send_joke", + "args": [], + "counter": gs_counter, + } + + message_str = json.dumps(command_message, separators=(",", ":")) + hmac_value = ground_station_authenticator.generate_hmac(message_str, gs_counter) + command_message["hmac"] = hmac_value + command_bytes = json.dumps(command_message).encode("utf-8") + + # Update the counter to reflect previous successful command + if gs_counter > 1: + mock_counter16.get.return_value = gs_counter - 1 + + mock_packet_manager.listen.return_value = command_bytes + with patch("time.sleep"): + flight_software_cdh.listen_for_commands(timeout=30) + + # Verify counter was updated + assert mock_counter16.set.call_args[0][0] == gs_counter + + # All 5 commands should have been accepted + assert mock_counter16.set.call_count == 5 + + +def test_hmac_integration_counter_wraparound( + flight_software_cdh, + ground_station_authenticator, + mock_config, + mock_packet_manager, + mock_counter16, +): + """Tests counter wraparound handling in integration scenario. + + Args: + flight_software_cdh: Flight software CDH instance. + ground_station_authenticator: Ground station HMAC authenticator. + mock_config: Mocked Config instance. + mock_packet_manager: Mocked PacketManager instance. + mock_counter16: Mocked Counter16 instance. + """ + # Set counter near max value (simulating many commands sent) + mock_counter16.get.return_value = 65530 + + # Ground station sends command with wrapped counter + gs_counter = 10 # Wrapped around from 65535 -> 0 -> 10 + command_message = { + "name": "TestSat", + "command": "send_joke", + "args": [], + "counter": gs_counter, + } + + message_str = json.dumps(command_message, separators=(",", ":")) + hmac_value = ground_station_authenticator.generate_hmac(message_str, gs_counter) + command_message["hmac"] = hmac_value + command_bytes = json.dumps(command_message).encode("utf-8") + + mock_packet_manager.listen.return_value = command_bytes + with patch("time.sleep"): + flight_software_cdh.listen_for_commands(timeout=30) + + # Verify wraparound was handled correctly + mock_packet_manager.send_acknowledgement.assert_called_once() + mock_counter16.set.assert_called_once_with(gs_counter) + + +def test_hmac_integration_different_secrets( + mock_logger, + mock_config, + mock_packet_manager, + mock_joke_config, + mock_counter16, +): + """Tests that flight software rejects commands with different HMAC secret. + + Args: + mock_logger: Mocked Logger instance. + mock_config: Mocked Config instance. + mock_packet_manager: Mocked PacketManager instance. + mock_counter16: Mocked Counter16 instance. + """ + # Flight software with one secret + + flight_software_cdh = CommandDataHandler( + logger=mock_logger, + config=mock_config, + packet_manager=mock_packet_manager, + jokes_config=mock_joke_config, + last_command_counter=mock_counter16, + hmac_class=hmac.new, + ) + + # Ground station with different secret (attacker scenario) + wrong_key = "wrong_secret_key" + wrong_authenticator = HMACAuthenticator(wrong_key, hmac_class=hmac.new) + + gs_counter = 1 + command_message = { + "name": "TestSat", + "command": "send_joke", + "args": [], + "counter": gs_counter, + } + + message_str = json.dumps(command_message, separators=(",", ":")) + hmac_value = wrong_authenticator.generate_hmac(message_str, gs_counter) + + command_message["hmac"] = id(hmac_value) + command_bytes = json.dumps(command_message).encode("utf-8") + + mock_packet_manager.listen.return_value = command_bytes + with patch("time.sleep"): + flight_software_cdh.listen_for_commands(timeout=30) + + assert mock_logger.debug.called, "Logger.debug was never called" + assert flight_software_cdh._config.hmac_secret != wrong_key + + # Verify command was rejected due to wrong HMAC + mock_packet_manager.send_acknowledgement.assert_not_called() + mock_logger.debug.assert_any_call("Invalid HMAC in message", msg=command_message) + + +def test_hmac_integration_large_message( + flight_software_cdh, + ground_station_authenticator, + mock_config, + mock_packet_manager, + mock_counter16, +): + """Tests HMAC authentication with large message requiring packetization. + + This test verifies that HMAC authentication works correctly even when + the message is larger than the packet size (252 bytes) and needs to be + broken into multiple packets by the packet_manager. + + Args: + flight_software_cdh: Flight software CDH instance. + ground_station_authenticator: Ground station HMAC authenticator. + mock_config: Mocked Config instance. + mock_packet_manager: Mocked PacketManager instance. + mock_counter16: Mocked Counter16 instance. + """ + # Create a large command message (approximately 10kB) + # Generate a large argument list to make the message size > 10kB + large_data = "A" * 10000 # 10kB of data + gs_counter = 1 + command_message = { + "name": "TestSat", + "command": "send_joke", + "args": [large_data], # Large argument + "counter": gs_counter, + } + + # Ground station generates HMAC for the complete message + message_str = json.dumps(command_message, separators=(",", ":")) + hmac_value = ground_station_authenticator.generate_hmac(message_str, gs_counter) + command_message["hmac"] = hmac_value + + # The complete message as bytes + command_bytes = json.dumps(command_message).encode("utf-8") + + # Verify the message is indeed large (> 252 bytes, typical packet size) + assert len(command_bytes) > 252, ( + f"Message size {len(command_bytes)} should be > 252" + ) + assert len(command_bytes) > 10000, ( + f"Message size {len(command_bytes)} should be > 10kB" + ) + + # In real scenario, packet_manager would fragment this message + # For this test, we simulate that the complete message is reassembled + # and returned by packet_manager.listen() + mock_packet_manager.listen.return_value = command_bytes + + # Flight software receives and processes the large message + with patch("time.sleep"): + flight_software_cdh.listen_for_commands(timeout=30) + + # Verify HMAC authentication succeeded despite large message size + mock_packet_manager.send_acknowledgement.assert_called_once() + + # Verify counter was updated in NVM + mock_counter16.set.assert_called_once_with(gs_counter) + + # Verify the command was accepted and processed + # (In this case, send_joke would be called with the large args) + assert mock_packet_manager.send_acknowledgement.called diff --git a/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_jokes_config.py b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_jokes_config.py new file mode 100644 index 00000000..b00843d4 --- /dev/null +++ b/cpython-workspaces/flight-software-unit-tests/src/unit-tests/test_jokes_config.py @@ -0,0 +1,184 @@ +"""Integration tests for JokesConfig, ability to get the jokes out of the config. + +This module contains integration tests that verify the complete JokesConfig pipeline +Of getting, initializing, adding, saving and removing jokes +""" + +import json +import os +import tempfile + +import pytest +from pysquared.config.jokes_config import JokesConfig + + +@pytest.fixture() +def jokes_file_fixture(): + """Create a temporary jokes JSON file for testing.""" + jokes = [ + "Why did the chicken cross the road? To get to the other side!", + "Parallel lines have so much in common… it’s a shame they’ll never meet.", + ] + temp_dir = tempfile.mkdtemp() + + path = os.path.join(temp_dir, "jokes.test.json") + with open(path, "w") as f: + json.dump(jokes, f) + + yield path + + os.remove(path) + os.rmdir(temp_dir) + + +def test_load_jokes_success(jokes_file_fixture): + """Tests if you can load jokes + Args: + jokes_file_fixture: JSON file fixture.""" + config = JokesConfig(jokes_file_fixture) + assert isinstance(config.jokes, list) + assert len(config.jokes) == 2 + assert all(isinstance(j, str) for j in config.jokes) + + +def test_invalid_jokes_file_format(): + """Tests what happens when invalid file for jokes""" + + with tempfile.NamedTemporaryFile("w", delete=False) as tmp: + tmp.write('{"not": "a list"}') + tmp_path = tmp.name + + with pytest.raises(ValueError): + JokesConfig(tmp_path) + + os.remove(tmp_path) + + +def test_validate_joke_valid(jokes_file_fixture): + """Tests validate joke function + Args: + jokes_file_fixture: JSON file fixture.""" + config = JokesConfig(jokes_file_fixture) + config.validate_joke("This is a valid joke.") + + +@pytest.mark.parametrize("bad_joke", [None, 123, "", " " * 10, "A" * 501]) +def test_validate_joke_invalid(jokes_file_fixture, bad_joke): + """Tests Validate Bad Joke, this raises + Args: + bad joke: badly formatted json file . + """ + config = JokesConfig(jokes_file_fixture) + + with pytest.raises((TypeError, ValueError)): + config.validate_joke(bad_joke) + + +def test_add_joke_temporary(jokes_file_fixture): + """Testing adding a joke temporarily. This will add a joke to the list + but not the physical config file + + Args: + jokes_file_fixture: JSON file fixture. + """ + config = JokesConfig(jokes_file_fixture) + new_joke = "This is a temporary test joke." + config.add_joke(new_joke) + assert new_joke in config.jokes + + +def test_add_joke_permanent(jokes_file_fixture): + """This tests adding a joke permanently to the json object + Args: + jokes_file_fixture: JSON file fixture. + """ + config = JokesConfig(jokes_file_fixture) + new_joke = "This is a permanent test joke." + config.add_joke(new_joke, temporary=False) + + # Reload config to ensure it was saved + new_config = JokesConfig(jokes_file_fixture) + assert new_joke in new_config.jokes + + +def test_update_joke_temporary(jokes_file_fixture): + """This tests editing a specific joke using the index + Args: + jokes_file_fixture: JSON file fixture.""" + + config = JokesConfig(jokes_file_fixture) + updated_joke = "Updated joke temporarily" + config.update_joke(0, updated_joke) + assert config.jokes[0] == updated_joke + + +def test_update_joke_permanent(jokes_file_fixture): + """This tests editing a specific joke using the index and saving it permanently in the JSON + Args: + jokes_file_fixture: JSON file fixture.""" + + config = JokesConfig(jokes_file_fixture) + updated_joke = "Updated joke permanently" + config.update_joke(0, updated_joke, temporary=False) + + new_config = JokesConfig(jokes_file_fixture) + assert new_config.jokes[0] == updated_joke + + +def test_update_joke_invalid_index(jokes_file_fixture): + """This Test makes sure that indexing the JSON with an invalid index raises an Index Error + Args: + jokes_file_fixture: JSON file fixture. + """ + config = JokesConfig(jokes_file_fixture) + with pytest.raises(IndexError): + config.update_joke(100, "Doesn't matter") + + +def test_remove_joke_temporary(jokes_file_fixture): + """removing a joke from the JSON object + Args: + jokes_file_fixture: JSON file fixture.""" + config = JokesConfig(jokes_file_fixture) + joke_to_remove = config.jokes[0] + config.remove_joke(0) + assert joke_to_remove not in config.jokes + + +def test_remove_joke_permanent(jokes_file_fixture): + """removing a joke from the JSON object permanatly + Args: + jokes_file_fixture: JSON file fixture.""" + config = JokesConfig(jokes_file_fixture) + joke_to_remove = config.jokes[0] + config.remove_joke(0, temporary=False) + + new_config = JokesConfig(jokes_file_fixture) + assert joke_to_remove not in new_config.jokes + + +def test_remove_joke_invalid_index(jokes_file_fixture): + """removing a joke with an invalid index gives an IndexError + Args: + jokes_file_fixture: JSON file fixture.""" + config = JokesConfig(jokes_file_fixture) + with pytest.raises(IndexError): + config.remove_joke(100) + + +def test_get_joke_valid(jokes_file_fixture): + """Gets a joke from the file fixture and ensures it is valid + Args: + jokes_file_fixture: JSON file fixture.""" + config = JokesConfig(jokes_file_fixture) + joke = config.get_joke(0) + assert isinstance(joke, str) + + +def test_get_joke_invalid_index(jokes_file_fixture): + """removing a joke from an incorrect index and ensures it raises an IndexError + Args: + jokes_file_fixture: JSON file fixture.""" + config = JokesConfig(jokes_file_fixture) + with pytest.raises(IndexError): + config.get_joke(100) diff --git a/uv.lock b/uv.lock index 6c48f2c2..3b129170 100644 --- a/uv.lock +++ b/uv.lock @@ -388,6 +388,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" }, ] +[[package]] +name = "circuitpython-hmac" +version = "0.2.2.dev2+g2944cbf9a" +source = { git = "https://github.com/jimbobbennett/CircuitPython_HMAC.git#2944cbf9a863bf47d2dbc6ade4f89226b179c356" } +dependencies = [ + { name = "adafruit-blinka" }, + { name = "adafruit-circuitpython-hashlib" }, +] + [[package]] name = "click" version = "8.2.1" @@ -1071,6 +1080,7 @@ dependencies = [ { name = "adafruit-circuitpython-tca9548a" }, { name = "adafruit-circuitpython-ticks" }, { name = "adafruit-circuitpython-veml7700" }, + { name = "circuitpython-hmac" }, { name = "proves-circuitpython-sx126" }, { name = "proves-circuitpython-sx1280" }, ] @@ -1090,6 +1100,7 @@ requires-dist = [ { name = "adafruit-circuitpython-tca9548a", git = "https://github.com/proveskit/Adafruit_CircuitPython_TCA9548A?rev=1.1.0" }, { name = "adafruit-circuitpython-ticks", specifier = "==1.1.1" }, { name = "adafruit-circuitpython-veml7700", specifier = "==2.1.4" }, + { name = "circuitpython-hmac", git = "https://github.com/jimbobbennett/CircuitPython_HMAC.git" }, { name = "proves-circuitpython-sx126", git = "https://github.com/proveskit/micropySX126X?rev=1.0.0" }, { name = "proves-circuitpython-sx1280", git = "https://github.com/proveskit/CircuitPython_SX1280?rev=1.0.4" }, ]