Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions scripts/aws/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
import requests
import signal
import argparse
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, NoCredentialsError
from typing import Dict
import sys
import time
import yaml

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from confidential_compute import ConfidentialCompute, ConfidentialComputeConfig, SecretNotFoundException, ConfidentialComputeStartupException
from confidential_compute import ConfidentialCompute, ConfidentialComputeConfig, MissingInstanceProfile, ConfigNotFound, InvalidConfigValue, ConfidentialComputeStartupException

class AWSConfidentialComputeConfig(ConfidentialComputeConfig):
enclave_memory_mb: int
Expand Down Expand Up @@ -61,7 +61,7 @@ def __get_aws_token(self) -> str:
)
return response.text
except requests.RequestException as e:
raise RuntimeError(f"Failed to fetch aws token: {e}")
raise RuntimeError(f"Failed to fetch AWS token: {e}")

def __get_current_region(self) -> str:
"""Fetches the current AWS region from EC2 instance metadata."""
Expand Down Expand Up @@ -97,16 +97,15 @@ def add_defaults(configs: Dict[str, any]) -> AWSConfidentialComputeConfig:

region = self.__get_current_region()
print(f"Running in {region}")
try:
client = boto3.client("secretsmanager", region_name=region)
except Exception as e:
raise RuntimeError("Please use IAM instance profile for your instance and make sure that has permission to access Secret Manager", e)
client = boto3.client("secretsmanager", region_name=region)
try:
secret = add_defaults(json.loads(client.get_secret_value(SecretId=secret_identifier)["SecretString"]))
self.__validate_aws_specific_config(secret)
return secret
except NoCredentialsError as _:
raise MissingInstanceProfile(self.__class__.__name__)
except ClientError as _:
raise SecretNotFoundException(f"{secret_identifier} in {region}")
raise ConfigNotFound(self.__class__.__name__, f"Secret Manager {secret_identifier} in {region}")

@staticmethod
def __get_max_capacity():
Expand Down Expand Up @@ -255,5 +254,5 @@ def __kill_auxiliaries(self) -> None:
except ConfidentialComputeStartupException as e:
print("Failed starting up Confidential Compute. Please checks the logs for errors and retry \n", e)
except Exception as e:
print("Unknown failure while starting up Confidential Compute. Please contact UID support team with this log \n ", e)
print("Unexpected failure while starting up Confidential Compute. Please contact UID support team with this log \n ", e)

86 changes: 45 additions & 41 deletions scripts/confidential_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,40 @@ class ConfidentialComputeConfig(TypedDict):
environment: str
skip_validations: NotRequired[bool]
debug_mode: NotRequired[bool]

class ConfidentialComputeStartupException(Exception):
def __init__(self, error_name, provider, extra_message=None):
urls = {
"EC2": "https://unifiedid.com/docs/guides/operator-guide-aws-marketplace#uid2-operator-error-codes",
"Azure": "https://unifiedid.com/docs/guides/operator-guide-azure-enclave#uid2-operator-error-codes",
"GCP": "https://unifiedid.com/docs/guides/operator-private-gcp-confidential-space#uid2-operator-error-codes",
}
url = urls.get(provider)
super().__init__(f"{error_name}\n" + (extra_message if extra_message else "") + f"\nVisit {url} for more details")

class MissingInstanceProfile(ConfidentialComputeStartupException):
def __init__(self, cls):
super().__init__(error_name=f"E01: {self.__class__.__name__}", provider=cls)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if you have tried @abuabraham-ttd - the URL may be anchored to the specific error code like the glossary page, e.g. https://unifiedid.com/docs/ref-info/glossary-uid#gl-client-key

maybe like:
https://unifiedid.com/docs/guides/operator-guide-aws-marketplace#uid2-operator-error-codes#gl-e01

so people can just see the right error code straight away


class ConfigNotFound(ConfidentialComputeStartupException):
def __init__(self, cls, message = None):
super().__init__(error_name=f"E02: {self.__class__.__name__}", provider=cls, extra_message=message)

class MissingConfig(ConfidentialComputeStartupException):
def __init__(self, cls, missing_keys):
super().__init__(error_name=f"E03: {self.__class__.__name__}", provider=cls, extra_message=', '.join(missing_keys))

class InvalidConfigValue(ConfidentialComputeStartupException):
def __init__(self, cls, config_key = None):
super().__init__(error_name=f"E04: {self.__class__.__name__} " , provider=cls, extra_message=config_key)

class InvalidOperatorKey(ConfidentialComputeStartupException):
def __init__(self, cls):
super().__init__(error_name=f"E05: {self.__class__.__name__}", provider=cls)

class UID2ServicesUnreachable(ConfidentialComputeStartupException):
def __init__(self, cls, ip=None):
super().__init__(error_name=f"E06: {self.__class__.__name__}", provider=cls, extra_message=ip)

class ConfidentialCompute(ABC):

Expand All @@ -25,36 +59,26 @@ def validate_configuration(self):
def validate_operator_key():
""" Validates the operator key format and its environment alignment."""
operator_key = self.configs.get("api_token")
if not operator_key:
raise ValueError("API token is missing from the configuration.")
pattern = r"^(UID2|EUID)-.\-(I|P|L)-\d+-.*$"
if re.match(pattern, operator_key):
env = self.configs.get("environment", "").lower()
debug_mode = self.configs.get("debug_mode", False)
expected_env = "I" if debug_mode or env == "integ" else "P"

if operator_key.split("-")[2] != expected_env:
raise ValueError(
f"Operator key does not match the expected environment ({expected_env})."
)
raise InvalidOperatorKey(self.__class__.__name__)
print("Validated operator key matches environment")
else:
print("Skipping operator key validation")

def validate_url(url_key, environment):
"""URL should include environment except in prod"""
if environment != "prod" and environment not in self.configs[url_key]:
raise ValueError(
f"{url_key} must match the environment. Ensure the URL includes '{environment}'."
)
raise InvalidConfigValue(self.__class__.__name__, url_key)
parsed_url = urlparse(self.configs[url_key])
if parsed_url.scheme != 'https' and parsed_url.path:
raise ValueError(
f"{url_key} is invalid. Ensure {self.configs[url_key]} follows HTTPS, and doesn't have any path specified."
)
raise InvalidConfigValue(self.__class__.__name__, url_key)
print(f"Validated {self.configs[url_key]} matches other config parameters")


def validate_connectivity() -> None:
""" Validates that the core URL is accessible."""
try:
Expand All @@ -63,32 +87,29 @@ def validate_connectivity() -> None:
requests.get(core_url, timeout=5)
print(f"Validated connectivity to {core_url}")
except (requests.ConnectionError, requests.Timeout) as e:
raise RuntimeError(
f"Failed to reach required URLs. Consider enabling {core_ip} in the egress firewall."
)
raise UID2ServicesUnreachable(self.__class__.__name__, core_ip)
except Exception as e:
raise Exception("Failed to reach the URLs.") from e
raise UID2ServicesUnreachable(self.__class__.__name__)

type_hints = get_type_hints(ConfidentialComputeConfig, include_extras=True)
required_keys = [field for field, hint in type_hints.items() if "NotRequired" not in str(hint)]
missing_keys = [key for key in required_keys if key not in self.configs]
if missing_keys:
raise MissingConfigError(missing_keys)
raise MissingConfig(self.__class__.__name__, missing_keys)

environment = self.configs["environment"]

if environment not in ["integ", "prod"]:
raise ValueError("Environment must be either prod/integ. It is currently set to", environment)
raise InvalidConfigValue(self.__class__.__name__, "environment")

if self.configs.get("debug_mode") and environment == "prod":
raise ValueError("Debug mode cannot be enabled in the production environment.")
raise InvalidConfigValue(self.__class__.__name__, "debug_mode")

validate_url("core_base_url", environment)
validate_url("optout_base_url", environment)
validate_operator_key()
validate_connectivity()
print("Completed static validation of confidential compute config values")


@abstractmethod
def _get_secret(self, secret_identifier: str) -> ConfidentialComputeConfig:
"""
Expand Down Expand Up @@ -124,21 +145,4 @@ def run_command(command, seperate_process=False):
subprocess.run(command,check=True)
except Exception as e:
print(f"Failed to run command: {str(e)}")
raise RuntimeError (f"Failed to start {' '.join(command)} ")

class ConfidentialComputeStartupException(Exception):
def __init__(self, message):
super().__init__(message)

class MissingConfigError(ConfidentialComputeStartupException):
"""Custom exception to handle missing config keys."""
def __init__(self, missing_keys):
self.missing_keys = missing_keys
self.message = f"\n Missing configuration keys: {', '.join(missing_keys)} \n"
super().__init__(self.message)

class SecretNotFoundException(ConfidentialComputeStartupException):
"""Custom exception if secret manager is not found"""
def __init__(self, name):
self.message = f"Secret manager not found - {name}. Please check if secret exist and the Instance Profile has permission to read it"
super().__init__(self.message)
raise RuntimeError (f"Failed to start {' '.join(command)} ")
Loading