Skip to content

Commit 1daba54

Browse files
Addressing feedback-2
1 parent 6929f3a commit 1daba54

File tree

1 file changed

+48
-48
lines changed

1 file changed

+48
-48
lines changed

scripts/confidential_compute.py

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import socket
44
from urllib.parse import urlparse
55
from abc import ABC, abstractmethod
6-
from typing import TypedDict, Dict
6+
from typing import TypedDict
7+
78

89
class OperatorConfig(TypedDict):
910
enclave_memory_mb: int
@@ -13,70 +14,69 @@ class OperatorConfig(TypedDict):
1314
core_base_url: str
1415
optout_base_url: str
1516

16-
class ConfidentialCompute(ABC):
1717

18+
class ConfidentialCompute(ABC):
1819
@abstractmethod
19-
def _get_secret(self, secret_identifier):
20+
def _get_secret(self, secret_identifier: str) -> OperatorConfig:
2021
"""
21-
Gets the secret from secret store
22+
Fetches the secret from a secret store.
2223
23-
Raises:
24-
SecretNotFoundException: Points to public documentation
24+
Raises:
25+
SecretNotFoundException: If the secret is not found.
2526
"""
2627
pass
2728

28-
def validate_operator_key(self, secrets: OperatorConfig):
29-
"""
30-
Validates operator key if following new pattern. Ignores otherwise
31-
"""
32-
api_token = secrets.get('api_token', None)
29+
def validate_operator_key(self, secrets: OperatorConfig) -> bool:
30+
""" Validates the operator key format and its environment alignment."""
31+
api_token = secrets.get("api_token")
32+
if not api_token:
33+
raise ValueError("API token is missing from the configuration.")
34+
3335
pattern = r"^(UID2|EUID)-.\-(I|P)-\d+-\*$"
34-
if bool(re.match(pattern, api_token)):
35-
if secrets.get('debug_mode', False) or secrets.get('environment') == 'integ':
36-
if api_token.split('-')[2] != 'I':
37-
raise Exception("Operator key does not match the environment")
38-
else:
39-
if api_token.split('-')[2] != 'P':
40-
raise Exception("Operator key does not match the environment")
36+
if re.match(pattern, api_token):
37+
env = secrets.get("environment", "").lower()
38+
debug_mode = secrets.get("debug_mode", False)
39+
expected_env = "I" if debug_mode or env == "integ" else "P"
40+
if api_token.split("-")[2] != expected_env:
41+
raise ValueError(
42+
f"Operator key does not match the expected environment ({expected_env})."
43+
)
4144
return True
45+
46+
@staticmethod
47+
def __resolve_hostname(url: str) -> str:
48+
""" Resolves the hostname of a URL to an IP address."""
49+
hostname = urlparse(url).netloc
50+
return socket.gethostbyname(hostname)
4251

43-
def validate_connectivity(self, config: OperatorConfig):
44-
"""
45-
Validates core/optout is accessible.
46-
"""
52+
def validate_connectivity(self, config: OperatorConfig) -> None:
53+
""" Validates that the core and opt-out URLs are accessible."""
4754
try:
48-
core_ip = socket.gethostbyname(urlparse(config['core_base_url']).netloc)
49-
requests.get(config['core_base_url'], timeout=5)
50-
optout_ip = socket.gethostbyname(urlparse(config['optout_base_url']).netloc)
51-
requests.get(config['optout_base_url'], timeout=5)
52-
except (requests.ConnectionError, requests.Timeout) as e :
53-
raise Exception("Failed to reach the URL. -- ERROR CODE, enable IPs? {} {}".format(core_ip, optout_ip), e)
55+
core_url = config["core_base_url"]
56+
optout_url = config["optout_base_url"]
57+
core_ip = self.__resolve_hostname(core_url)
58+
requests.get(core_url, timeout=5)
59+
optout_ip = self.__resolve_hostname(optout_url)
60+
requests.get(optout_url, timeout=5)
61+
62+
except (requests.ConnectionError, requests.Timeout) as e:
63+
raise Exception(
64+
f"Failed to reach required URLs. Consider enabling {core_ip}, {optout_ip} in the egress firewall."
65+
)
5466
except Exception as e:
55-
raise Exception("Failed to reach the URL. ")
56-
"""
57-
s3 does not have static IP, and the range returned for s3 is huge to validate.
58-
r = requests.get('https://ip-ranges.amazonaws.com/ip-ranges.json')
59-
ips = list(map(lambda x: x['ip_prefix'], filter(lambda x: x['region']=='us-east-1' and x['service'] == 'S3', r.json()['prefixes'])))
60-
"""
61-
return
62-
67+
raise Exception("Failed to reach the URLs.") from e
68+
6369
@abstractmethod
64-
def _setup_auxiliaries(self):
65-
"""
66-
Sets up auxilary processes required to confidential compute
67-
"""
70+
def _setup_auxiliaries(self) -> None:
71+
""" Sets up auxiliary processes required for confidential computing. """
6872
pass
6973

7074
@abstractmethod
71-
def _validate_auxiliaries(self):
72-
"""
73-
Validates auxilary services are running
74-
"""
75+
def _validate_auxiliaries(self) -> None:
76+
""" Validates auxiliary services are running."""
7577
pass
7678

7779
@abstractmethod
78-
def run_compute(self):
79-
"""
80-
Runs compute.
81-
"""
80+
def run_compute(self) -> None:
81+
""" Runs confidential computing."""
8282
pass

0 commit comments

Comments
 (0)