Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion aws_lambda_powertools/utilities/parser/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from typing import TYPE_CHECKING, Any

from pydantic import TypeAdapter
from pydantic import IPvAnyNetwork, TypeAdapter

from aws_lambda_powertools.shared.cache_dict import LRUDict

Expand Down Expand Up @@ -82,3 +82,28 @@ def _parse_and_validate_event(data: dict[str, Any] | Any, adapter: TypeAdapter):
data = json.loads(data)

return adapter.validate_python(data)


def _validate_source_ip(value):
"""
Handle sourceIp that may come with port (e.g., "10.1.15.242:39870")
in certain network configurations like Cloudflare + CloudFront + API Gateway.
Validates the IP part while preserving the original format.
See: https://github.com/aws-powertools/powertools-lambda-python/issues/7288
"""

if value == "test-invoke-source-ip":
return value

try:
# The value is always an instance of str before Pydantic validation occurs.
# So the first thing to do is try to convert it.
IPvAnyNetwork(value)
except ValueError:
try:
ip_part = value.split(":")[0]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has "IPv6-with-port" been taken into account?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understand, the IPvAnyAddress model we're using supports both v4 and v6 - source

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I think this can break ipv6 yes in some situations. I need to run some tests before confirming this.

Thanks @iBug

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I think the way we split the port might actually not work with IPv6 which contains multiple :. To make it work we should've taken everything but the last item of the split list instead of just the first. cc @leandrodamascena

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I think this can break ipv6 yes in some situations.

That's what I meant. I don't think it would parse when value.split(":")[0] returns something like [2001. A better approach might be

ip_part = value.rsplit(":", 1)[0]

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And probably with .strip("[]") as well, if IPvAnyNetwork doesn't expected bracketed IPv6 addresses.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reoepned this issue: #7288

IPvAnyNetwork(ip_part)
except (ValueError, IndexError) as e:
raise ValueError(f"Invalid IP address in sourceIp: {ip_part}") from e

return value
11 changes: 9 additions & 2 deletions aws_lambda_powertools/utilities/parser/models/apigw.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Type, Union

from pydantic import BaseModel, model_validator
from pydantic import BaseModel, field_validator, model_validator
from pydantic.networks import IPvAnyNetwork

from aws_lambda_powertools.utilities.parser.functions import _validate_source_ip


class ApiGatewayUserCertValidity(BaseModel):
notBefore: str
Expand Down Expand Up @@ -31,12 +33,17 @@ class APIGatewayEventIdentity(BaseModel):
principalOrgId: Optional[str] = None
# see #1562, temp workaround until API Gateway fixes it the Test button payload
# removing it will not be considered a regression in the future
sourceIp: Union[IPvAnyNetwork, Literal["test-invoke-source-ip"]]
sourceIp: Union[IPvAnyNetwork, str]
user: Optional[str] = None
userAgent: Optional[str] = None
userArn: Optional[str] = None
clientCert: Optional[ApiGatewayUserCert] = None

@field_validator("sourceIp", mode="before")
@classmethod
def _validate_source_ip(cls, value):
return _validate_source_ip(value=value)


class APIGatewayEventAuthorizer(BaseModel):
claims: Optional[Dict[str, Any]] = None
Expand Down
11 changes: 9 additions & 2 deletions aws_lambda_powertools/utilities/parser/models/apigwv2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Type, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from pydantic.networks import IPvAnyNetwork

from aws_lambda_powertools.utilities.parser.functions import _validate_source_ip


class RequestContextV2AuthorizerIamCognito(BaseModel):
amr: List[str]
Expand Down Expand Up @@ -36,9 +38,14 @@ class RequestContextV2Http(BaseModel):
method: Literal["DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"]
path: str
protocol: str
sourceIp: IPvAnyNetwork
sourceIp: Union[IPvAnyNetwork, str]
userAgent: str

@field_validator("sourceIp", mode="before")
@classmethod
def _validate_source_ip(cls, value):
return _validate_source_ip(value=value)


class RequestContextV2(BaseModel):
accountId: str
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/parser/_pydantic/test_apigw.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,21 @@ def test_apigw_event():
assert identity.apiKeyId is None


def test_apigw_event_and_source_ip_with_port():
raw_event = load_event("apiGatewayProxyEvent.json")
raw_event["requestContext"]["identity"]["sourceIp"] = "10.10.10.10:1235"

APIGatewayProxyEventModel(**raw_event)


def test_apigw_event_and_source_ip_with_random_string():
raw_event = load_event("apiGatewayProxyEvent.json")
raw_event["requestContext"]["identity"]["sourceIp"] = "NON_IP_WITH_OR_WITHOUT_PORT_STRING"

with pytest.raises(ValidationError):
APIGatewayProxyEventModel(**raw_event)


def test_apigw_event_with_invalid_websocket_request():
# GIVEN an event with an eventType != MESSAGE and has a messageId
event = {
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/parser/_pydantic/test_apigwv2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest
from pydantic import ValidationError

from aws_lambda_powertools.utilities.parser import envelopes, parse
from aws_lambda_powertools.utilities.parser.models import (
ApiGatewayAuthorizerRequestV2,
Expand Down Expand Up @@ -71,6 +74,21 @@ def test_apigw_v2_event_empty_jwt_scopes():
APIGatewayProxyEventV2Model(**raw_event)


def test_apigw_v2_event_and_source_ip_with_port():
raw_event = load_event("apiGatewayProxyV2Event.json")
raw_event["requestContext"]["http"]["sourceIp"] = "10.10.10.10:1235"

APIGatewayProxyEventV2Model(**raw_event)


def test_apigw_v2_event_and_source_ip_with_random_string():
raw_event = load_event("apiGatewayProxyV2Event.json")
raw_event["requestContext"]["http"]["sourceIp"] = "NON_IP_WITH_OR_WITHOUT_PORT_STRING"

with pytest.raises(ValidationError):
APIGatewayProxyEventV2Model(**raw_event)


def test_api_gateway_proxy_v2_event_lambda_authorizer():
raw_event = load_event("apiGatewayProxyV2LambdaAuthorizerEvent.json")
parsed_event: APIGatewayProxyEventV2Model = APIGatewayProxyEventV2Model(**raw_event)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_shared_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.shared.functions import (
abs_lambda_path,
slice_dictionary,
extract_event_from_common_models,
powertools_debug_is_set,
powertools_dev_is_set,
resolve_env_var_choice,
resolve_max_age,
resolve_truthy_env_var_choice,
sanitize_xray_segment_name,
slice_dictionary,
strtobool,
)
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
Expand Down