diff --git a/scripts/aws/ec2.py b/scripts/aws/ec2.py index 69fed7f91..ee47f4556 100644 --- a/scripts/aws/ec2.py +++ b/scripts/aws/ec2.py @@ -9,14 +9,14 @@ import requests import signal import argparse -from botocore.exceptions import ClientError +from botocore.exceptions import ClientError, NoCredentialsError from typing import Dict import sys import time import yaml sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from confidential_compute import ConfidentialCompute, ConfidentialComputeConfig, SecretNotFoundException, ConfidentialComputeStartupException +from confidential_compute import ConfidentialCompute, ConfidentialComputeConfig, MissingInstanceProfile, ConfigNotFound, InvalidConfigValue, ConfidentialComputeStartupException class AWSConfidentialComputeConfig(ConfidentialComputeConfig): enclave_memory_mb: int @@ -61,7 +61,7 @@ def __get_aws_token(self) -> str: ) return response.text except requests.RequestException as e: - raise RuntimeError(f"Failed to fetch aws token: {e}") + raise RuntimeError(f"Failed to fetch AWS token: {e}") def __get_current_region(self) -> str: """Fetches the current AWS region from EC2 instance metadata.""" @@ -97,16 +97,15 @@ def add_defaults(configs: Dict[str, any]) -> AWSConfidentialComputeConfig: region = self.__get_current_region() print(f"Running in {region}") - try: - client = boto3.client("secretsmanager", region_name=region) - except Exception as e: - raise RuntimeError("Please use IAM instance profile for your instance and make sure that has permission to access Secret Manager", e) + client = boto3.client("secretsmanager", region_name=region) try: secret = add_defaults(json.loads(client.get_secret_value(SecretId=secret_identifier)["SecretString"])) self.__validate_aws_specific_config(secret) return secret + except NoCredentialsError as _: + raise MissingInstanceProfile(self.__class__.__name__) except ClientError as _: - raise SecretNotFoundException(f"{secret_identifier} in {region}") + raise ConfigNotFound(self.__class__.__name__, f"Secret Manager {secret_identifier} in {region}") @staticmethod def __get_max_capacity(): @@ -255,5 +254,5 @@ def __kill_auxiliaries(self) -> None: except ConfidentialComputeStartupException as e: print("Failed starting up Confidential Compute. Please checks the logs for errors and retry \n", e) except Exception as e: - print("Unknown failure while starting up Confidential Compute. Please contact UID support team with this log \n ", e) + print("Unexpected failure while starting up Confidential Compute. Please contact UID support team with this log \n ", e) \ No newline at end of file diff --git a/scripts/confidential_compute.py b/scripts/confidential_compute.py index 306474e9c..4c80be659 100644 --- a/scripts/confidential_compute.py +++ b/scripts/confidential_compute.py @@ -13,6 +13,40 @@ class ConfidentialComputeConfig(TypedDict): environment: str skip_validations: NotRequired[bool] debug_mode: NotRequired[bool] + +class ConfidentialComputeStartupException(Exception): + def __init__(self, error_name, provider, extra_message=None): + urls = { + "EC2": "https://unifiedid.com/docs/guides/operator-guide-aws-marketplace#uid2-operator-error-codes", + "Azure": "https://unifiedid.com/docs/guides/operator-guide-azure-enclave#uid2-operator-error-codes", + "GCP": "https://unifiedid.com/docs/guides/operator-private-gcp-confidential-space#uid2-operator-error-codes", + } + url = urls.get(provider) + super().__init__(f"{error_name}\n" + (extra_message if extra_message else "") + f"\nVisit {url} for more details") + +class MissingInstanceProfile(ConfidentialComputeStartupException): + def __init__(self, cls): + super().__init__(error_name=f"E01: {self.__class__.__name__}", provider=cls) + +class ConfigNotFound(ConfidentialComputeStartupException): + def __init__(self, cls, message = None): + super().__init__(error_name=f"E02: {self.__class__.__name__}", provider=cls, extra_message=message) + +class MissingConfig(ConfidentialComputeStartupException): + def __init__(self, cls, missing_keys): + super().__init__(error_name=f"E03: {self.__class__.__name__}", provider=cls, extra_message=', '.join(missing_keys)) + +class InvalidConfigValue(ConfidentialComputeStartupException): + def __init__(self, cls, config_key = None): + super().__init__(error_name=f"E04: {self.__class__.__name__} " , provider=cls, extra_message=config_key) + +class InvalidOperatorKey(ConfidentialComputeStartupException): + def __init__(self, cls): + super().__init__(error_name=f"E05: {self.__class__.__name__}", provider=cls) + +class UID2ServicesUnreachable(ConfidentialComputeStartupException): + def __init__(self, cls, ip=None): + super().__init__(error_name=f"E06: {self.__class__.__name__}", provider=cls, extra_message=ip) class ConfidentialCompute(ABC): @@ -25,18 +59,13 @@ def validate_configuration(self): def validate_operator_key(): """ Validates the operator key format and its environment alignment.""" operator_key = self.configs.get("api_token") - if not operator_key: - raise ValueError("API token is missing from the configuration.") pattern = r"^(UID2|EUID)-.\-(I|P|L)-\d+-.*$" if re.match(pattern, operator_key): env = self.configs.get("environment", "").lower() debug_mode = self.configs.get("debug_mode", False) expected_env = "I" if debug_mode or env == "integ" else "P" - if operator_key.split("-")[2] != expected_env: - raise ValueError( - f"Operator key does not match the expected environment ({expected_env})." - ) + raise InvalidOperatorKey(self.__class__.__name__) print("Validated operator key matches environment") else: print("Skipping operator key validation") @@ -44,17 +73,12 @@ def validate_operator_key(): def validate_url(url_key, environment): """URL should include environment except in prod""" if environment != "prod" and environment not in self.configs[url_key]: - raise ValueError( - f"{url_key} must match the environment. Ensure the URL includes '{environment}'." - ) + raise InvalidConfigValue(self.__class__.__name__, url_key) parsed_url = urlparse(self.configs[url_key]) if parsed_url.scheme != 'https' and parsed_url.path: - raise ValueError( - f"{url_key} is invalid. Ensure {self.configs[url_key]} follows HTTPS, and doesn't have any path specified." - ) + raise InvalidConfigValue(self.__class__.__name__, url_key) print(f"Validated {self.configs[url_key]} matches other config parameters") - def validate_connectivity() -> None: """ Validates that the core URL is accessible.""" try: @@ -63,24 +87,22 @@ def validate_connectivity() -> None: requests.get(core_url, timeout=5) print(f"Validated connectivity to {core_url}") except (requests.ConnectionError, requests.Timeout) as e: - raise RuntimeError( - f"Failed to reach required URLs. Consider enabling {core_ip} in the egress firewall." - ) + raise UID2ServicesUnreachable(self.__class__.__name__, core_ip) except Exception as e: - raise Exception("Failed to reach the URLs.") from e + raise UID2ServicesUnreachable(self.__class__.__name__) + type_hints = get_type_hints(ConfidentialComputeConfig, include_extras=True) required_keys = [field for field, hint in type_hints.items() if "NotRequired" not in str(hint)] missing_keys = [key for key in required_keys if key not in self.configs] if missing_keys: - raise MissingConfigError(missing_keys) - + raise MissingConfig(self.__class__.__name__, missing_keys) + environment = self.configs["environment"] - if environment not in ["integ", "prod"]: - raise ValueError("Environment must be either prod/integ. It is currently set to", environment) + raise InvalidConfigValue(self.__class__.__name__, "environment") if self.configs.get("debug_mode") and environment == "prod": - raise ValueError("Debug mode cannot be enabled in the production environment.") + raise InvalidConfigValue(self.__class__.__name__, "debug_mode") validate_url("core_base_url", environment) validate_url("optout_base_url", environment) @@ -88,7 +110,6 @@ def validate_connectivity() -> None: validate_connectivity() print("Completed static validation of confidential compute config values") - @abstractmethod def _get_secret(self, secret_identifier: str) -> ConfidentialComputeConfig: """ @@ -124,21 +145,4 @@ def run_command(command, seperate_process=False): subprocess.run(command,check=True) except Exception as e: print(f"Failed to run command: {str(e)}") - raise RuntimeError (f"Failed to start {' '.join(command)} ") - -class ConfidentialComputeStartupException(Exception): - def __init__(self, message): - super().__init__(message) - -class MissingConfigError(ConfidentialComputeStartupException): - """Custom exception to handle missing config keys.""" - def __init__(self, missing_keys): - self.missing_keys = missing_keys - self.message = f"\n Missing configuration keys: {', '.join(missing_keys)} \n" - super().__init__(self.message) - -class SecretNotFoundException(ConfidentialComputeStartupException): - """Custom exception if secret manager is not found""" - def __init__(self, name): - self.message = f"Secret manager not found - {name}. Please check if secret exist and the Instance Profile has permission to read it" - super().__init__(self.message) + raise RuntimeError (f"Failed to start {' '.join(command)} ") \ No newline at end of file