Skip to content

Commit 05673d6

Browse files
committed
Add EventSigner implementation
1 parent e1ecd23 commit 05673d6

File tree

7 files changed

+186
-7
lines changed

7 files changed

+186
-7
lines changed

codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,8 @@ await sleep(retry_token.retry_delay)
570570

571571
writer.pushState(new SignRequestSection());
572572
if (context.applicationProtocol().isHttpProtocol() && supportsAuth) {
573+
writer.addStdlibImport("binascii", "hexlify");
574+
writer.addStdlibImport("re");
573575
writer.write("""
574576
# Step 7i: sign the request
575577
if auth_option and signer:
@@ -587,6 +589,14 @@ await sleep(retry_token.retry_delay)
587589
)
588590
)
589591
logger.debug("Signed HTTP request: %s", context.transport_request)
592+
593+
# TODO - Move this to separate resolution/population function
594+
fields = context._transport_request.fields
595+
auth_value = fields["Authorization"].as_string() # type: ignore
596+
signature = re.split("Signature=", auth_value)[-1] # type: ignore
597+
context._properties["signature"] = hexlify(signature.encode('utf-8')) # type: ignore
598+
context._properties["identity"] = identity
599+
context._properties["signer_properties"] = auth_option.signer_properties
590600
""");
591601
}
592602
writer.popState();

codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ public void wrapOutputStream(GenerationContext context, PythonWriter writer) {
424424
transport_response.body # type: ignore
425425
),
426426
deserializer=event_deserializer, # type: ignore
427+
signer=signer, # type: ignore
427428
)
428429
""");
429430
}

packages/aws-event-stream/src/aws_event_stream/_private/serializers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Iterator
66
from contextlib import contextmanager
77
from io import BytesIO
8-
from typing import Never
8+
from typing import Never, Protocol
99

1010
from smithy_core.codecs import Codec
1111
from smithy_core.schemas import Schema
@@ -16,7 +16,7 @@
1616
)
1717
from smithy_core.shapes import ShapeType
1818

19-
from ..events import EventMessage, HEADER_VALUE, Short, Byte, Long
19+
from ..events import EventHeaderEncoder, EventMessage, HEADER_VALUE, Short, Byte, Long
2020
from ..exceptions import InvalidHeaderValue
2121
from . import (
2222
INITIAL_REQUEST_EVENT_TYPE,
@@ -43,6 +43,7 @@ def __init__(
4343
self._initial_message_event_type = INITIAL_REQUEST_EVENT_TYPE
4444
else:
4545
self._initial_message_event_type = INITIAL_RESPONSE_EVENT_TYPE
46+
self.event_header_encoder_cls = EventHeaderEncoder
4647

4748
def get_result(self) -> EventMessage | None:
4849
return self._result

packages/aws-event-stream/src/aws_event_stream/aio/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,23 @@
2020
logger = logging.getLogger(__name__)
2121

2222

23-
type Signer = Callable[[EventMessage], EventMessage]
24-
"""A function that takes an event message and signs it, and returns it signed."""
23+
class EventSigner(Protocol):
24+
"""A signer to manage credentials and EventMessages for an Event Stream lifecyle."""
25+
26+
def sign_event(
27+
self,
28+
*,
29+
event_message: EventMessage,
30+
event_encoder_cls: type[EventHeaderEncoder],
31+
) -> EventMessage: ...
2532

2633

2734
class AWSEventPublisher[E: SerializeableShape](EventPublisher[E]):
2835
def __init__(
2936
self,
3037
payload_codec: Codec,
3138
async_writer: AsyncWriter,
32-
signer: Signer | None = None,
39+
signer: EventSigner | None = None,
3340
is_client_mode: bool = True,
3441
):
3542
self._writer = async_writer

packages/aws-sdk-signers/src/aws_sdk_signers/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,20 @@
99
from ._http import URI, AWSRequest, Field, Fields
1010
from ._identity import AWSCredentialIdentity
1111
from ._io import AsyncBytesReader
12-
from .signers import AsyncSigV4Signer, SigV4Signer, SigV4SigningProperties
12+
from .signers import (
13+
AsyncSigV4Signer,
14+
AsyncEventSigner,
15+
SigV4Signer,
16+
SigV4SigningProperties,
17+
)
1318

1419
__license__ = "Apache-2.0"
1520
__version__ = importlib.metadata.version("aws-sdk-signers")
1621

1722
__all__ = (
1823
"AsyncBytesReader",
1924
"AsyncSigV4Signer",
25+
"AsyncEventSigner",
2026
"AWSCredentialIdentity",
2127
"AWSRequest",
2228
"Field",
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import datetime
7+
import uuid
8+
from collections.abc import Mapping
9+
from typing import Protocol
10+
11+
12+
type HEADER_VALUE = bool | int | bytes | str | datetime.datetime | uuid.UUID
13+
"""A union of valid value types for event headers."""
14+
15+
16+
type HEADERS_DICT = Mapping[str, HEADER_VALUE]
17+
"""A dictionary of event headers."""
18+
19+
20+
class EventMessage(Protocol):
21+
"""A signable message that may be sent over an event stream."""
22+
23+
headers: HEADERS_DICT
24+
"""The headers present in the event message."""
25+
26+
payload: bytes
27+
"""The serialized bytes of the message payload."""
28+
29+
def encode(self) -> bytes:
30+
"""Encode heads and payload into bytes for transit."""
31+
...
32+
33+
34+
class EventHeaderEncoder(Protocol):
35+
"""A utility class that encodes event headers into bytes."""
36+
37+
def clear(self) -> None:
38+
"""Clear all previously encoded headers."""
39+
...
40+
41+
def get_result(self) -> bytes:
42+
"""Get all the encoded header bytes."""
43+
...
44+
45+
def encode_headers(self, headers: HEADERS_DICT) -> None:
46+
"""Encode a map of headers."""
47+
...

packages/aws-sdk-signers/src/aws_sdk_signers/signers.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import asyncio
45
import datetime
56
import hmac
67
import io
78
import warnings
89
from asyncio import iscoroutinefunction
10+
from binascii import hexlify
911
from collections.abc import AsyncIterable, Iterable
1012
from copy import deepcopy
1113
from hashlib import sha256
12-
from typing import Required, TypedDict
14+
from typing import Required, TypedDict, TYPE_CHECKING
1315
from urllib.parse import parse_qsl, quote
1416

1517
from .interfaces.io import AsyncSeekable, Seekable
@@ -19,6 +21,9 @@
1921
from ._io import AsyncBytesReader
2022
from .exceptions import AWSSDKWarning, MissingExpectedParameterException
2123

24+
if TYPE_CHECKING:
25+
from .interfaces.events import EventMessage, EventHeaderEncoder
26+
2227
HEADERS_EXCLUDED_FROM_SIGNING: tuple[str, ...] = (
2328
"accept",
2429
"accept-encoding",
@@ -789,6 +794,108 @@ async def _compute_payload_hash(
789794
return checksum.hexdigest()
790795

791796

797+
class AsyncEventSigner:
798+
def __init__(
799+
self,
800+
*,
801+
signing_properties: SigV4SigningProperties,
802+
identity: AWSCredentialIdentity,
803+
initial_signature: bytes,
804+
):
805+
self._signing_properties = signing_properties
806+
self._identity = identity
807+
self._prior_signature = initial_signature
808+
self._signing_lock = asyncio.Lock()
809+
810+
async def sign_event(
811+
self,
812+
*,
813+
event_message: "EventMessage",
814+
event_encoder_cls: type["EventHeaderEncoder"],
815+
) -> "EventMessage":
816+
async with self._signing_lock:
817+
# Copy and prepopulate any missing values in the
818+
# signing properties.
819+
new_signing_properties = SigV4SigningProperties( # type: ignore
820+
**self._signing_properties
821+
)
822+
if "date" not in new_signing_properties:
823+
date_obj = datetime.datetime.now(datetime.UTC)
824+
new_signing_properties["date"] = date_obj.strftime(
825+
SIGV4_TIMESTAMP_FORMAT
826+
)
827+
828+
timestamp = new_signing_properties["date"]
829+
headers: dict[str, str | bytes] = {":date": timestamp}
830+
encoder = event_encoder_cls()
831+
encoder.encode_headers(event_message.headers)
832+
encoded_headers = encoder.get_result()
833+
834+
string_to_sign = await self._event_string_to_sign(
835+
timestamp=timestamp,
836+
scope=self._scope(new_signing_properties),
837+
encoded_headers=encoded_headers,
838+
payload=event_message.payload,
839+
prior_signature=self._prior_signature,
840+
)
841+
event_signature = await self._sign_event(
842+
timestamp=timestamp,
843+
string_to_sign=string_to_sign,
844+
signing_properties=new_signing_properties,
845+
)
846+
headers[":chunk-signature"] = event_signature
847+
event_message.headers.update(headers) # type: ignore
848+
849+
# set new prior signature before releasing the lock
850+
self._prior_signature = event_signature
851+
852+
return event_message
853+
854+
async def _event_string_to_sign(
855+
self,
856+
*,
857+
timestamp: str,
858+
scope: str,
859+
encoded_headers: bytes,
860+
payload: bytes,
861+
prior_signature: bytes,
862+
) -> str:
863+
return (
864+
"AWS-HMAC-SHA256-PAYLOAD\n"
865+
f"{timestamp}\n"
866+
f"{scope}\n"
867+
f"{hexlify(prior_signature).decode('utf-8')}\n"
868+
f"{sha256(encoded_headers).hexdigest()}\n"
869+
f"{sha256(payload).hexdigest()}"
870+
)
871+
872+
async def _sign_event(
873+
self,
874+
*,
875+
timestamp: str,
876+
string_to_sign: str,
877+
signing_properties: SigV4SigningProperties,
878+
) -> bytes:
879+
key = self._identity.secret_access_key.encode("utf-8")
880+
today = timestamp[:8].encode("utf-8")
881+
k_date = self._hash(b"AWS4" + key, today)
882+
k_region = self._hash(k_date, signing_properties["region"].encode("utf-8"))
883+
k_service = self._hash(k_region, signing_properties["service"].encode("utf-8"))
884+
k_signing = self._hash(k_service, b"aws4_request")
885+
return self._hash(k_signing, string_to_sign.encode("utf-8"))
886+
887+
def _hash(self, key: bytes, msg: bytes) -> bytes:
888+
return hmac.new(key, msg, sha256).digest()
889+
890+
def _scope(self, signing_properties: SigV4SigningProperties) -> str:
891+
assert "date" in signing_properties
892+
formatted_date = signing_properties["date"][0:8]
893+
region = signing_properties["region"]
894+
service = signing_properties["service"]
895+
# Scope format: <YYYYMMDD>/<AWS Region>/<AWS Service>/aws4_request
896+
return f"{formatted_date}/{region}/{service}/aws4_request"
897+
898+
792899
def _remove_dot_segments(path: str, remove_consecutive_slashes: bool = True) -> str:
793900
"""Removes dot segments from a path per :rfc:`3986#section-5.2.4`.
794901

0 commit comments

Comments
 (0)