Review summary (text before the first "diff --git" header is outside every hunk).

  * Adds _validate_source_ip() to parser/functions.py. It returns the value
    unchanged when it equals "test-invoke-source-ip", when IPvAnyNetwork
    accepts it as-is, or when IPvAnyNetwork accepts the text before the first
    ":" (the "10.1.15.242:39870" case). Otherwise it raises ValueError.
  * Widens sourceIp to Union[IPvAnyNetwork, str] on APIGatewayEventIdentity
    and RequestContextV2Http, each with a mode="before" field validator that
    delegates to the new helper.
  * Adds with-port and non-IP tests for both models, and moves the
    slice_dictionary import in tests/unit/test_shared_functions.py.

NOTE(review): only the text before the first ":" is checked, so the port part
is never validated, and str.split(":")[0] cannot raise IndexError.
NOTE(review): a bracketed IPv6 address with a port presumably fails the
fallback, because the text before its first ":" is "[" - confirm whether that
form needs support.
NOTE(review): line structure and indentation were rebuilt from a
whitespace-collapsed copy of this patch. The depth of the context line
"data = json.loads(data)" is an assumption - confirm against the target tree
before applying.

diff --git a/aws_lambda_powertools/utilities/parser/functions.py b/aws_lambda_powertools/utilities/parser/functions.py
index 351e214da93..c133df96868 100644
--- a/aws_lambda_powertools/utilities/parser/functions.py
+++ b/aws_lambda_powertools/utilities/parser/functions.py
@@ -4,7 +4,7 @@
 import logging
 from typing import TYPE_CHECKING, Any
 
-from pydantic import TypeAdapter
+from pydantic import IPvAnyNetwork, TypeAdapter
 
 from aws_lambda_powertools.shared.cache_dict import LRUDict
 
@@ -82,3 +82,28 @@ def _parse_and_validate_event(data: dict[str, Any] | Any, adapter: TypeAdapter):
             data = json.loads(data)
 
     return adapter.validate_python(data)
+
+
+def _validate_source_ip(value):
+    """
+    Handle sourceIp that may come with port (e.g., "10.1.15.242:39870")
+    in certain network configurations like Cloudflare + CloudFront + API Gateway.
+    Validates the IP part while preserving the original format.
+    See: https://github.com/aws-powertools/powertools-lambda-python/issues/7288
+    """
+
+    if value == "test-invoke-source-ip":
+        return value
+
+    try:
+        # The value is always an instance of str before Pydantic validation occurs.
+        # So the first thing to do is try to convert it.
+        IPvAnyNetwork(value)
+    except ValueError:
+        try:
+            ip_part = value.split(":")[0]
+            IPvAnyNetwork(ip_part)
+        except (ValueError, IndexError) as e:
+            raise ValueError(f"Invalid IP address in sourceIp: {ip_part}") from e
+
+    return value
diff --git a/aws_lambda_powertools/utilities/parser/models/apigw.py b/aws_lambda_powertools/utilities/parser/models/apigw.py
index 55d2b5c7c93..ea01a5e8a6b 100644
--- a/aws_lambda_powertools/utilities/parser/models/apigw.py
+++ b/aws_lambda_powertools/utilities/parser/models/apigw.py
@@ -1,9 +1,11 @@
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Type, Union
 
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, field_validator, model_validator
 from pydantic.networks import IPvAnyNetwork
 
+from aws_lambda_powertools.utilities.parser.functions import _validate_source_ip
+
 
 class ApiGatewayUserCertValidity(BaseModel):
     notBefore: str
@@ -31,12 +33,17 @@ class APIGatewayEventIdentity(BaseModel):
     principalOrgId: Optional[str] = None
     # see #1562, temp workaround until API Gateway fixes it the Test button payload
     # removing it will not be considered a regression in the future
-    sourceIp: Union[IPvAnyNetwork, Literal["test-invoke-source-ip"]]
+    sourceIp: Union[IPvAnyNetwork, str]
     user: Optional[str] = None
     userAgent: Optional[str] = None
     userArn: Optional[str] = None
     clientCert: Optional[ApiGatewayUserCert] = None
 
+    @field_validator("sourceIp", mode="before")
+    @classmethod
+    def _validate_source_ip(cls, value):
+        return _validate_source_ip(value=value)
+
 
 class APIGatewayEventAuthorizer(BaseModel):
     claims: Optional[Dict[str, Any]] = None
diff --git a/aws_lambda_powertools/utilities/parser/models/apigwv2.py b/aws_lambda_powertools/utilities/parser/models/apigwv2.py
index 540e7c1a30b..9bd66b7a585 100644
--- a/aws_lambda_powertools/utilities/parser/models/apigwv2.py
+++ b/aws_lambda_powertools/utilities/parser/models/apigwv2.py
@@ -1,9 +1,11 @@
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Type, Union
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from pydantic.networks import IPvAnyNetwork
 
+from aws_lambda_powertools.utilities.parser.functions import _validate_source_ip
+
 
 class RequestContextV2AuthorizerIamCognito(BaseModel):
     amr: List[str]
@@ -36,9 +38,14 @@ class RequestContextV2Http(BaseModel):
     method: Literal["DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"]
     path: str
     protocol: str
-    sourceIp: IPvAnyNetwork
+    sourceIp: Union[IPvAnyNetwork, str]
     userAgent: str
 
+    @field_validator("sourceIp", mode="before")
+    @classmethod
+    def _validate_source_ip(cls, value):
+        return _validate_source_ip(value=value)
+
 
 class RequestContextV2(BaseModel):
     accountId: str
diff --git a/tests/unit/parser/_pydantic/test_apigw.py b/tests/unit/parser/_pydantic/test_apigw.py
index 9fdf623bcf9..4222efa10b2 100644
--- a/tests/unit/parser/_pydantic/test_apigw.py
+++ b/tests/unit/parser/_pydantic/test_apigw.py
@@ -105,6 +105,21 @@ def test_apigw_event():
     assert identity.apiKeyId is None
 
 
+def test_apigw_event_and_source_ip_with_port():
+    raw_event = load_event("apiGatewayProxyEvent.json")
+    raw_event["requestContext"]["identity"]["sourceIp"] = "10.10.10.10:1235"
+
+    APIGatewayProxyEventModel(**raw_event)
+
+
+def test_apigw_event_and_source_ip_with_random_string():
+    raw_event = load_event("apiGatewayProxyEvent.json")
+    raw_event["requestContext"]["identity"]["sourceIp"] = "NON_IP_WITH_OR_WITHOUT_PORT_STRING"
+
+    with pytest.raises(ValidationError):
+        APIGatewayProxyEventModel(**raw_event)
+
+
 def test_apigw_event_with_invalid_websocket_request():
     # GIVEN an event with an eventType != MESSAGE and has a messageId
     event = {
diff --git a/tests/unit/parser/_pydantic/test_apigwv2.py b/tests/unit/parser/_pydantic/test_apigwv2.py
index ddb849bb68a..7db8d92ff2a 100644
--- a/tests/unit/parser/_pydantic/test_apigwv2.py
+++ b/tests/unit/parser/_pydantic/test_apigwv2.py
@@ -1,3 +1,6 @@
+import pytest
+from pydantic import ValidationError
+
 from aws_lambda_powertools.utilities.parser import envelopes, parse
 from aws_lambda_powertools.utilities.parser.models import (
     ApiGatewayAuthorizerRequestV2,
@@ -71,6 +74,21 @@ def test_apigw_v2_event_empty_jwt_scopes():
     APIGatewayProxyEventV2Model(**raw_event)
 
 
+def test_apigw_v2_event_and_source_ip_with_port():
+    raw_event = load_event("apiGatewayProxyV2Event.json")
+    raw_event["requestContext"]["http"]["sourceIp"] = "10.10.10.10:1235"
+
+    APIGatewayProxyEventV2Model(**raw_event)
+
+
+def test_apigw_v2_event_and_source_ip_with_random_string():
+    raw_event = load_event("apiGatewayProxyV2Event.json")
+    raw_event["requestContext"]["http"]["sourceIp"] = "NON_IP_WITH_OR_WITHOUT_PORT_STRING"
+
+    with pytest.raises(ValidationError):
+        APIGatewayProxyEventV2Model(**raw_event)
+
+
 def test_api_gateway_proxy_v2_event_lambda_authorizer():
     raw_event = load_event("apiGatewayProxyV2LambdaAuthorizerEvent.json")
     parsed_event: APIGatewayProxyEventV2Model = APIGatewayProxyEventV2Model(**raw_event)
diff --git a/tests/unit/test_shared_functions.py b/tests/unit/test_shared_functions.py
index 1bf7c6e26a0..7f9effdb5e7 100644
--- a/tests/unit/test_shared_functions.py
+++ b/tests/unit/test_shared_functions.py
@@ -11,7 +11,6 @@
 from aws_lambda_powertools.shared import constants
 from aws_lambda_powertools.shared.functions import (
     abs_lambda_path,
-    slice_dictionary,
     extract_event_from_common_models,
     powertools_debug_is_set,
     powertools_dev_is_set,
@@ -19,6 +18,7 @@
     resolve_max_age,
     resolve_truthy_env_var_choice,
     sanitize_xray_segment_name,
+    slice_dictionary,
     strtobool,
 )
 from aws_lambda_powertools.utilities.data_classes.common import DictWrapper