Skip to content

Commit d98a7ad

Browse files
Refactor event_gate_lambda and writer modules for improved type hinting, logging, and error handling; enhance readability and maintainability
1 parent 971c38b commit d98a7ad

File tree

4 files changed: 251 additions and 181 deletions

src/event_gate_lambda.py

Lines changed: 100 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,9 @@
1818
import logging
1919
import os
2020
import sys
21+
from typing import Any, Dict
22+
2123
import urllib3
22-
from typing import Any
2324

2425
import boto3
2526
import jwt
@@ -28,52 +29,57 @@
2829
from jsonschema import validate
2930
from jsonschema.exceptions import ValidationError
3031

31-
# Resolve project root (parent directory of this file's directory)
32-
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
33-
_CONF_DIR = os.path.join(_PROJECT_ROOT, "conf")
34-
3532
from . import writer_eventbridge, writer_kafka, writer_postgres
3633

3734
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
3835

3936
logger = logging.getLogger(__name__)
4037
log_level = os.environ.get("LOG_LEVEL", "INFO")
4138
logger.setLevel(log_level)
42-
logger.addHandler(logging.StreamHandler())
39+
if not logger.handlers:
40+
logger.addHandler(logging.StreamHandler())
4341
logger.debug("Initialized LOGGER")
4442

45-
with open(os.path.join(_CONF_DIR, "api.yaml"), "r") as file:
43+
# Resolve project root (parent directory of this file's directory)
44+
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
45+
_CONF_DIR = os.path.join(_PROJECT_ROOT, "conf")
46+
47+
with open(os.path.join(_CONF_DIR, "api.yaml"), "r", encoding="utf-8") as file:
4648
API = file.read()
4749
logger.debug("Loaded API definition")
4850

49-
TOPICS = {}
50-
with open(os.path.join(_CONF_DIR, "topic_runs.json"), "r") as file:
51+
TOPICS: Dict[str, Dict[str, Any]] = {}
52+
with open(os.path.join(_CONF_DIR, "topic_runs.json"), "r", encoding="utf-8") as file:
5153
TOPICS["public.cps.za.runs"] = json.load(file)
52-
with open(os.path.join(_CONF_DIR, "topic_dlchange.json"), "r") as file:
54+
with open(os.path.join(_CONF_DIR, "topic_dlchange.json"), "r", encoding="utf-8") as file:
5355
TOPICS["public.cps.za.dlchange"] = json.load(file)
54-
with open(os.path.join(_CONF_DIR, "topic_test.json"), "r") as file:
56+
with open(os.path.join(_CONF_DIR, "topic_test.json"), "r", encoding="utf-8") as file:
5557
TOPICS["public.cps.za.test"] = json.load(file)
5658
logger.debug("Loaded TOPICS")
5759

58-
with open(os.path.join(_CONF_DIR, "config.json"), "r") as file:
60+
with open(os.path.join(_CONF_DIR, "config.json"), "r", encoding="utf-8") as file:
5961
CONFIG = json.load(file)
6062
logger.debug("Loaded main CONFIG")
6163

62-
aws_s3 = boto3.Session().resource("s3", verify=False)
64+
aws_s3 = boto3.Session().resource("s3", verify=False) # nosec Boto verify disabled intentionally
6365
logger.debug("Initialized AWS S3 Client")
6466

6567
if CONFIG["access_config"].startswith("s3://"):
6668
name_parts = CONFIG["access_config"].split("/")
67-
bucket_name = name_parts[2]
68-
bucket_object = "/".join(name_parts[3:])
69-
ACCESS = json.loads(aws_s3.Bucket(bucket_name).Object(bucket_object).get()["Body"].read().decode("utf-8"))
69+
BUCKET_NAME = name_parts[2]
70+
BUCKET_OBJECT_KEY = "/".join(name_parts[3:])
71+
ACCESS = json.loads(
72+
aws_s3.Bucket(BUCKET_NAME).Object(BUCKET_OBJECT_KEY).get()["Body"].read().decode("utf-8")
73+
)
7074
else:
71-
with open(CONFIG["access_config"], "r") as file:
75+
with open(CONFIG["access_config"], "r", encoding="utf-8") as file:
7276
ACCESS = json.load(file)
7377
logger.debug("Loaded ACCESS definitions")
7478

7579
TOKEN_PROVIDER_URL = CONFIG["token_provider_url"]
76-
token_public_key_encoded = requests.get(CONFIG["token_public_key_url"], verify=False).json()["key"]
80+
# Add timeout to avoid hanging requests
81+
response_json = requests.get(CONFIG["token_public_key_url"], verify=False, timeout=5).json() # nosec external
82+
token_public_key_encoded = response_json["key"]
7783
TOKEN_PUBLIC_KEY: Any = serialization.load_der_public_key(base64.b64decode(token_public_key_encoded))
7884
logger.debug("Loaded TOKEN_PUBLIC_KEY")
7985

@@ -82,7 +88,16 @@
8288
writer_postgres.init(logger)
8389

8490

85-
def _error_response(status, err_type, message):
91+
def _error_response(status: int, err_type: str, message: str) -> Dict[str, Any]:
92+
"""Build a standardized JSON error response body.
93+
94+
Args:
95+
status: HTTP status code.
96+
err_type: A short error classifier (e.g. 'auth', 'validation').
97+
message: Human readable error description.
98+
Returns:
99+
A dictionary compatible with API Gateway Lambda Proxy integration.
100+
"""
86101
return {
87102
"statusCode": status,
88103
"headers": {"Content-Type": "application/json"},
@@ -92,55 +107,69 @@ def _error_response(status, err_type, message):
92107
}
93108

94109

95-
def get_api():
110+
def get_api() -> Dict[str, Any]:
111+
"""Return the OpenAPI specification text."""
96112
return {"statusCode": 200, "body": API}
97113

98114

99-
def get_token():
115+
def get_token() -> Dict[str, Any]:
116+
"""Return 303 redirect to token provider endpoint."""
100117
logger.debug("Handling GET Token")
101118
return {"statusCode": 303, "headers": {"Location": TOKEN_PROVIDER_URL}}
102119

103120

104-
def get_topics():
121+
def get_topics() -> Dict[str, Any]:
122+
"""Return list of available topic names."""
105123
logger.debug("Handling GET Topics")
106124
return {
107125
"statusCode": 200,
108126
"headers": {"Content-Type": "application/json"},
109-
"body": json.dumps([topicName for topicName in TOPICS]),
127+
"body": json.dumps(list(TOPICS)),
110128
}
111129

112130

113-
def get_topic_schema(topicName):
114-
logger.debug(f"Handling GET TopicSchema({topicName})")
115-
if topicName not in TOPICS:
116-
return _error_response(404, "topic", f"Topic '{topicName}' not found")
131+
def get_topic_schema(topic_name: str) -> Dict[str, Any]:
132+
"""Return the JSON schema for a specific topic.
133+
134+
Args:
135+
topic_name: The topic whose schema is requested.
136+
"""
137+
logger.debug("Handling GET TopicSchema(%s)", topic_name)
138+
if topic_name not in TOPICS:
139+
return _error_response(404, "topic", f"Topic '{topic_name}' not found")
140+
141+
return {"statusCode": 200, "headers": {"Content-Type": "application/json"}, "body": json.dumps(TOPICS[topic_name])}
117142

118-
return {"statusCode": 200, "headers": {"Content-Type": "application/json"}, "body": json.dumps(TOPICS[topicName])}
119143

144+
def post_topic_message(topic_name: str, topic_message: Dict[str, Any], token_encoded: str) -> Dict[str, Any]:
145+
"""Validate auth and schema; dispatch message to all writers.
120146
121-
def post_topic_message(topicName, topicMessage, tokenEncoded):
122-
logger.debug(f"Handling POST {topicName}")
147+
Args:
148+
topic_name: Target topic name.
149+
topic_message: JSON message payload.
150+
token_encoded: Encoded bearer JWT token string.
151+
"""
152+
logger.debug("Handling POST %s", topic_name)
123153
try:
124-
token = jwt.decode(tokenEncoded, TOKEN_PUBLIC_KEY, algorithms=["RS256"])
125-
except Exception:
154+
token = jwt.decode(token_encoded, TOKEN_PUBLIC_KEY, algorithms=["RS256"]) # type: ignore[arg-type]
155+
except jwt.PyJWTError: # type: ignore[attr-defined]
126156
return _error_response(401, "auth", "Invalid or missing token")
127157

128-
if topicName not in TOPICS:
129-
return _error_response(404, "topic", f"Topic '{topicName}' not found")
158+
if topic_name not in TOPICS:
159+
return _error_response(404, "topic", f"Topic '{topic_name}' not found")
130160

131-
user = token["sub"]
132-
if topicName not in ACCESS or user not in ACCESS[topicName]:
161+
user = token.get("sub")
162+
if topic_name not in ACCESS or user not in ACCESS[topic_name]: # type: ignore[index]
133163
return _error_response(403, "auth", "User not authorized for topic")
134164

135165
try:
136-
validate(instance=topicMessage, schema=TOPICS[topicName])
137-
except ValidationError as e:
138-
return _error_response(400, "validation", e.message)
166+
validate(instance=topic_message, schema=TOPICS[topic_name])
167+
except ValidationError as exc:
168+
return _error_response(400, "validation", exc.message)
139169

140-
# Run all writers independently (avoid short-circuit so failures in one don't skip others)
141-
kafka_ok, kafka_err = writer_kafka.write(topicName, topicMessage)
142-
eventbridge_ok, eventbridge_err = writer_eventbridge.write(topicName, topicMessage)
143-
postgres_ok, postgres_err = writer_postgres.write(topicName, topicMessage)
170+
kafka_ok, kafka_err = writer_kafka.write(topic_name, topic_message)
171+
eventbridge_ok, eventbridge_err = writer_eventbridge.write(topic_name, topic_message)
172+
postgres_ok, postgres_err = writer_postgres.write(topic_name, topic_message)
144173

145174
errors = []
146175
if not kafka_ok:
@@ -164,37 +193,46 @@ def post_topic_message(topicName, topicMessage, tokenEncoded):
164193
}
165194

166195

167-
def extract_token(eventHeaders):
168-
# Initial implementation used bearer header directly
169-
if "bearer" in eventHeaders:
170-
return eventHeaders["bearer"]
196+
def extract_token(event_headers: Dict[str, str]) -> str:
197+
"""Extract bearer token from headers.
171198
172-
if "Authorization" in eventHeaders and eventHeaders["Authorization"].startswith("Bearer "):
173-
return eventHeaders["Authorization"][len("Bearer ") :]
199+
Supports lowercase custom 'bearer' header or standard 'Authorization: Bearer <token>'.
200+
Returns empty string if not present (caller handles auth error response).
201+
"""
202+
if "bearer" in event_headers:
203+
return event_headers["bearer"]
204+
auth_header = event_headers.get("Authorization", "")
205+
if auth_header.startswith("Bearer "):
206+
return auth_header[len("Bearer ") :]
207+
return ""
174208

175-
return "" # Will result in 401
176209

210+
def lambda_handler(event: Dict[str, Any], context: Any): # pylint: disable=unused-argument,too-many-return-statements
211+
"""AWS Lambda entry point.
177212
178-
def lambda_handler(event, context):
213+
Dispatches based on API Gateway proxy 'resource' and 'httpMethod'.
214+
"""
179215
try:
180-
if event["resource"].lower() == "/api":
216+
resource = event.get("resource", "").lower()
217+
if resource == "/api":
181218
return get_api()
182-
if event["resource"].lower() == "/token":
219+
if resource == "/token":
183220
return get_token()
184-
if event["resource"].lower() == "/topics":
221+
if resource == "/topics":
185222
return get_topics()
186-
if event["resource"].lower() == "/topics/{topic_name}":
187-
if event["httpMethod"] == "GET":
223+
if resource == "/topics/{topic_name}":
224+
method = event.get("httpMethod")
225+
if method == "GET":
188226
return get_topic_schema(event["pathParameters"]["topic_name"].lower())
189-
if event["httpMethod"] == "POST":
227+
if method == "POST":
190228
return post_topic_message(
191229
event["pathParameters"]["topic_name"].lower(),
192230
json.loads(event["body"]),
193-
extract_token(event["headers"]),
231+
extract_token(event.get("headers", {})),
194232
)
195-
if event["resource"].lower() == "/terminate":
196-
sys.exit("TERMINATING")
233+
if resource == "/terminate":
234+
sys.exit("TERMINATING") # pragma: no cover - deliberate termination path
197235
return _error_response(404, "route", "Resource not found")
198-
except Exception as e:
199-
logger.error(f"Unexpected exception: {e}")
236+
except Exception as exc: # pylint: disable=broad-exception-caught
237+
logger.error("Unexpected exception: %s", exc)
200238
return _error_response(500, "internal", "Unexpected server error")

src/writer_eventbridge.py

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,75 @@
1-
#
2-
# Copyright 2025 ABSA Group Limited
3-
#
4-
# Licensed under the Apache License, Version 2.0 (the "License");
5-
# you may not use this file except in compliance with the License.
6-
# You may obtain a copy of the License at
7-
#
8-
# http://www.apache.org/licenses/LICENSE-2.0
9-
#
10-
# Unless required by applicable law or agreed to in writing, software
11-
# distributed under the License is distributed on an "AS IS" BASIS,
12-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
# See the License for the specific language governing permissions and
14-
# limitations under the License.
15-
#
1+
"""EventBridge writer module.
2+
3+
Provides initialization and write functionality for publishing events to AWS EventBridge.
4+
"""
5+
166
import json
177
import logging
18-
from typing import Optional, Tuple
8+
from typing import Any, Dict, Optional, Tuple
199

2010
import boto3
11+
from botocore.exceptions import BotoCoreError, ClientError
2112

22-
# Module globals for typing
23-
_logger: logging.Logger = logging.getLogger(__name__)
24-
EVENT_BUS_ARN: str = ""
25-
aws_eventbridge = None # will hold boto3 client
13+
STATE: Dict[str, Any] = {"logger": logging.getLogger(__name__), "event_bus_arn": "", "client": None}
2614

2715

28-
def init(logger, CONFIG):
29-
global _logger
30-
global EVENT_BUS_ARN
31-
global aws_eventbridge
16+
def init(logger: logging.Logger, config: Dict[str, Any]) -> None:
17+
"""Initialize the EventBridge writer.
3218
33-
_logger = logger
19+
Args:
20+
logger: Shared application logger.
21+
config: Configuration dictionary (expects optional 'event_bus_arn').
22+
"""
23+
STATE["logger"] = logger
24+
STATE["client"] = boto3.client("events")
25+
STATE["event_bus_arn"] = config.get("event_bus_arn", "")
26+
STATE["logger"].debug("Initialized EVENTBRIDGE writer")
3427

35-
aws_eventbridge = boto3.client("events")
36-
EVENT_BUS_ARN = CONFIG["event_bus_arn"] if "event_bus_arn" in CONFIG else ""
37-
_logger.debug("Initialized EVENTBRIDGE writer")
3828

29+
def write(topic_name: str, message: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
30+
"""Publish a message to EventBridge.
3931
40-
def write(topicName, message) -> Tuple[bool, Optional[str]]:
41-
if not EVENT_BUS_ARN:
42-
_logger.debug("No EventBus Arn - skipping")
32+
Args:
33+
topic_name: Source topic name used as event Source.
34+
message: JSON-serializable payload.
35+
Returns:
36+
Tuple of success flag and optional error message.
37+
"""
38+
logger = STATE["logger"]
39+
event_bus_arn = STATE["event_bus_arn"]
40+
client = STATE["client"]
41+
42+
if not event_bus_arn:
43+
logger.debug("No EventBus Arn - skipping")
4344
return True, None
44-
if aws_eventbridge is None: # defensive
45-
_logger.debug("EventBridge client not initialized - skipping")
45+
if client is None: # defensive
46+
logger.debug("EventBridge client not initialized - skipping")
4647
return True, None
4748

4849
try:
49-
_logger.debug(f"Sending to eventBridge {topicName}")
50-
response = aws_eventbridge.put_events(
50+
logger.debug("Sending to eventBridge %s", topic_name)
51+
response = client.put_events(
5152
Entries=[
5253
{
53-
"Source": topicName,
54+
"Source": topic_name,
5455
"DetailType": "JSON",
5556
"Detail": json.dumps(message),
56-
"EventBusName": EVENT_BUS_ARN,
57+
"EventBusName": event_bus_arn,
5758
}
5859
]
5960
)
60-
if response["FailedEntryCount"] > 0:
61+
if response.get("FailedEntryCount", 0) > 0:
6162
msg = str(response)
62-
_logger.error(msg)
63+
logger.error(msg)
6364
return False, msg
64-
except Exception as e:
65-
err_msg = f"The EventBridge writer failed with unknown error: {str(e)}"
66-
_logger.error(err_msg)
65+
except (BotoCoreError, ClientError) as err:
66+
err_msg = f"The EventBridge writer failed: {err}" # specific AWS error
67+
logger.error(err_msg)
68+
return False, err_msg
69+
except Exception as err: # pragma: no cover - unexpected failure path
70+
err_msg = f"The EventBridge writer failed with unknown error: {err}" \
71+
if not isinstance(err, (BotoCoreError, ClientError)) else str(err)
72+
logger.error(err_msg)
6773
return False, err_msg
6874

6975
return True, None

0 commit comments

Comments
 (0)