Skip to content

Commit 81a7327

Browse files
authored
Merge pull request smithy-lang#438 from smithy-lang/sign_event
Event Signing
2 parents 414a067 + fe2cb01 commit 81a7327

File tree

9 files changed

+246
-10
lines changed

9 files changed

+246
-10
lines changed

codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,10 @@ await sleep(retry_token.retry_delay)
570570

571571
writer.pushState(new SignRequestSection());
572572
if (context.applicationProtocol().isHttpProtocol() && supportsAuth) {
573+
writer.addStdlibImport("re");
574+
writer.addStdlibImport("typing", "Any");
575+
writer.addImport("smithy_core.interfaces.identity", "Identity");
576+
writer.addImport("smithy_core.types", "PropertyKey");
573577
writer.write("""
574578
# Step 7i: sign the request
575579
if auth_option and signer:
@@ -587,6 +591,23 @@ await sleep(retry_token.retry_delay)
587591
)
588592
)
589593
logger.debug("Signed HTTP request: %s", context.transport_request)
594+
595+
# TODO - Move this to separate resolution/population function
596+
fields = context.transport_request.fields
597+
auth_value = fields["Authorization"].as_string() # type: ignore
598+
signature = re.split("Signature=", auth_value)[-1] # type: ignore
599+
context.properties["signature"] = signature.encode('utf-8')
600+
601+
identity_key: PropertyKey[Identity | None] = PropertyKey(
602+
key="identity",
603+
value_type=Identity | None # type: ignore
604+
)
605+
sp_key: PropertyKey[dict[str, Any]] = PropertyKey(
606+
key="signer_properties",
607+
value_type=dict[str, Any] # type: ignore
608+
)
609+
context.properties[identity_key] = identity
610+
context.properties[sp_key] = auth_option.signer_properties
590611
""");
591612
}
592613
writer.popState();

codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import java.util.List;
88
import java.util.Set;
99
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait;
10+
import software.amazon.smithy.model.knowledge.EventStreamIndex;
1011
import software.amazon.smithy.model.knowledge.HttpBinding;
1112
import software.amazon.smithy.model.node.ArrayNode;
1213
import software.amazon.smithy.model.node.ObjectNode;
@@ -156,6 +157,21 @@ protected void serializeDocumentBody(
156157
writer.popState();
157158
}
158159

160+
@Override
161+
protected void writeDefaultHeaders(GenerationContext context, PythonWriter writer, OperationShape operation) {
162+
var eventStreamIndex = EventStreamIndex.of(context.model());
163+
if (eventStreamIndex.getInputInfo(operation).isPresent()) {
164+
writer.addImport("smithy_http", "Field");
165+
writer.write(
166+
"Field(name=\"Content-Type\", values=[$S]),",
167+
"application/vnd.amazon.eventstream");
168+
writer.write(
169+
"Field(name=\"X-Amz-Content-SHA256\", values=[$S]),",
170+
"STREAMING-AWS4-HMAC-SHA256-EVENTS");
171+
}
172+
}
173+
174+
159175
@Override
160176
protected void serializePayloadBody(
161177
GenerationContext context,
@@ -397,12 +413,24 @@ public void wrapInputStream(GenerationContext context, PythonWriter writer) {
397413
writer.addImport("smithy_core.aio.types", "AsyncBytesReader");
398414
writer.addImport("smithy_core.types", "TimestampFormat");
399415
writer.addImport("aws_event_stream.aio", "AWSEventPublisher");
416+
writer.addImport("aws_sdk_signers", "AsyncEventSigner");
400417
writer.write(
401418
"""
419+
# TODO - Move this out of the RestJSON generator
420+
ctx = request_context
421+
signer_properties = ctx.properties.get("signer_properties") # type: ignore
422+
identity = ctx.properties.get("identity") # type: ignore
423+
signature = ctx.properties.get("signature") # type: ignore
424+
signer = AsyncEventSigner(
425+
signing_properties=signer_properties, # type: ignore
426+
identity=identity, # type: ignore
427+
initial_signature=signature, # type: ignore
428+
)
402429
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
403430
publisher = AWSEventPublisher[Any](
404431
payload_codec=codec,
405432
async_writer=request_context.transport_request.body, # type: ignore
433+
signer=signer, # type: ignore
406434
)
407435
""");
408436
}

packages/aws-event-stream/src/aws_event_stream/_private/serializers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717
from smithy_core.shapes import ShapeType
1818

19-
from ..events import EventMessage, HEADER_VALUE, Short, Byte, Long
19+
from ..events import EventHeaderEncoder, EventMessage, HEADER_VALUE, Short, Byte, Long
2020
from ..exceptions import InvalidHeaderValue
2121
from . import (
2222
INITIAL_REQUEST_EVENT_TYPE,
@@ -43,6 +43,7 @@ def __init__(
4343
self._initial_message_event_type = INITIAL_REQUEST_EVENT_TYPE
4444
else:
4545
self._initial_message_event_type = INITIAL_RESPONSE_EVENT_TYPE
46+
self.event_header_encoder_cls = EventHeaderEncoder
4647

4748
def get_result(self) -> EventMessage | None:
4849
return self._result

packages/aws-event-stream/src/aws_event_stream/aio/__init__.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,31 @@
1414

1515
from .._private.serializers import EventSerializer as _EventSerializer
1616
from .._private.deserializers import EventDeserializer as _EventDeserializer
17-
from ..events import Event, EventMessage
17+
from ..events import Event, EventHeaderEncoder, EventMessage
1818
from ..exceptions import EventError
1919

20+
from typing import Protocol
21+
2022
logger = logging.getLogger(__name__)
2123

2224

23-
type Signer = Callable[[EventMessage], EventMessage]
24-
"""A function that takes an event message and signs it, and returns it signed."""
25+
class EventSigner(Protocol):
26+
"""A signer to manage credentials and EventMessages for an Event Stream lifecyle."""
27+
28+
def sign_event(
29+
self,
30+
*,
31+
event_message: EventMessage,
32+
event_encoder_cls: type[EventHeaderEncoder],
33+
) -> EventMessage: ...
2534

2635

2736
class AWSEventPublisher[E: SerializeableShape](EventPublisher[E]):
2837
def __init__(
2938
self,
3039
payload_codec: Codec,
3140
async_writer: AsyncWriter,
32-
signer: Signer | None = None,
41+
signer: EventSigner | None = None,
3342
is_client_mode: bool = True,
3443
):
3544
self._writer = async_writer
@@ -50,8 +59,13 @@ async def send(self, event: E) -> None:
5059
"Expected an event message to be serialized, but was None."
5160
)
5261
if self._signer is not None:
53-
result = self._signer(result)
62+
encoder = self._serializer.event_header_encoder_cls
63+
result = await self._signer.sign_event( # type: ignore
64+
event_message=result,
65+
event_encoder_cls=encoder,
66+
)
5467

68+
assert isinstance(result, EventMessage)
5569
encoded_result = result.encode()
5670
try:
5771
logger.debug("Publishing serialized event: %s", result)

packages/aws-event-stream/src/aws_event_stream/events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def get_result(self) -> bytes:
387387
raise InvalidHeadersLength(len(result))
388388
return result
389389

390-
def encode_headers(self, headers: HEADERS_DICT):
390+
def encode_headers(self, headers: HEADERS_DICT) -> None:
391391
"""Encode a map of headers.
392392
393393
:param headers: A mapping of headers to encode.

packages/aws-sdk-signers/src/aws_sdk_signers/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,20 @@
99
from ._http import URI, AWSRequest, Field, Fields
1010
from ._identity import AWSCredentialIdentity
1111
from ._io import AsyncBytesReader
12-
from .signers import AsyncSigV4Signer, SigV4Signer, SigV4SigningProperties
12+
from .signers import (
13+
AsyncSigV4Signer,
14+
AsyncEventSigner,
15+
SigV4Signer,
16+
SigV4SigningProperties,
17+
)
1318

1419
__license__ = "Apache-2.0"
1520
__version__ = importlib.metadata.version("aws-sdk-signers")
1621

1722
__all__ = (
1823
"AsyncBytesReader",
1924
"AsyncSigV4Signer",
25+
"AsyncEventSigner",
2026
"AWSCredentialIdentity",
2127
"AWSRequest",
2228
"Field",
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import datetime
7+
import uuid
8+
from collections.abc import Mapping
9+
from typing import Protocol
10+
11+
12+
type HEADER_VALUE = bool | int | bytes | str | datetime.datetime | uuid.UUID
13+
"""A union of valid value types for event headers."""
14+
15+
16+
type HEADERS_DICT = Mapping[str, HEADER_VALUE]
17+
"""A dictionary of event headers."""
18+
19+
20+
class EventMessage(Protocol):
21+
"""A signable message that may be sent over an event stream."""
22+
23+
headers: HEADERS_DICT
24+
"""The headers present in the event message."""
25+
26+
payload: bytes
27+
"""The serialized bytes of the message payload."""
28+
29+
def encode(self) -> bytes:
30+
"""Encode heads and payload into bytes for transit."""
31+
...
32+
33+
34+
class EventHeaderEncoder(Protocol):
35+
"""A utility class that encodes event headers into bytes."""
36+
37+
def clear(self) -> None:
38+
"""Clear all previously encoded headers."""
39+
...
40+
41+
def get_result(self) -> bytes:
42+
"""Get all the encoded header bytes."""
43+
...
44+
45+
def encode_headers(self, headers: HEADERS_DICT) -> None:
46+
"""Encode a map of headers."""
47+
...

packages/aws-sdk-signers/src/aws_sdk_signers/signers.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import asyncio
45
import datetime
56
import hmac
67
import io
78
import warnings
89
from asyncio import iscoroutinefunction
10+
from binascii import hexlify
911
from collections.abc import AsyncIterable, Iterable
1012
from copy import deepcopy
1113
from hashlib import sha256
12-
from typing import Required, TypedDict
14+
from typing import Required, TypedDict, TYPE_CHECKING
1315
from urllib.parse import parse_qsl, quote
1416

1517
from .interfaces.io import AsyncSeekable, Seekable
@@ -19,6 +21,9 @@
1921
from ._io import AsyncBytesReader
2022
from .exceptions import AWSSDKWarning, MissingExpectedParameterException
2123

24+
if TYPE_CHECKING:
25+
from .interfaces.events import EventMessage, EventHeaderEncoder
26+
2227
HEADERS_EXCLUDED_FROM_SIGNING: tuple[str, ...] = (
2328
"accept",
2429
"accept-encoding",
@@ -739,6 +744,12 @@ async def _format_canonical_payload(
739744
request: AWSRequest,
740745
signing_properties: SigV4SigningProperties,
741746
) -> str:
747+
if (
748+
"X-Amz-Content-SHA256" in request.fields
749+
and len(request.fields["X-Amz-Content-SHA256"].values) == 1
750+
):
751+
return request.fields["X-Amz-Content-SHA256"].values[0]
752+
742753
payload_hash = await self._compute_payload_hash(
743754
request=request, signing_properties=signing_properties
744755
)
@@ -789,6 +800,113 @@ async def _compute_payload_hash(
789800
return checksum.hexdigest()
790801

791802

803+
class AsyncEventSigner:
804+
def __init__(
805+
self,
806+
*,
807+
signing_properties: SigV4SigningProperties,
808+
identity: AWSCredentialIdentity,
809+
initial_signature: bytes,
810+
):
811+
self._signing_properties = signing_properties
812+
self._identity = identity
813+
self._prior_signature = initial_signature
814+
self._signing_lock = asyncio.Lock()
815+
816+
async def sign_event(
817+
self,
818+
*,
819+
event_message: "EventMessage",
820+
event_encoder_cls: type["EventHeaderEncoder"],
821+
) -> "EventMessage":
822+
async with self._signing_lock:
823+
# Copy and prepopulate any missing values in the
824+
# signing properties.
825+
new_signing_properties = SigV4SigningProperties( # type: ignore
826+
**self._signing_properties
827+
)
828+
# TODO: If date is in properties, parse a datetime from it.
829+
date_obj = datetime.datetime.now(datetime.UTC)
830+
if "date" not in new_signing_properties:
831+
new_signing_properties["date"] = date_obj.strftime(
832+
SIGV4_TIMESTAMP_FORMAT
833+
)
834+
835+
timestamp = new_signing_properties["date"]
836+
headers: dict[str, str | bytes | datetime.datetime] = {":date": date_obj}
837+
encoder = event_encoder_cls()
838+
encoder.encode_headers(headers)
839+
encoded_headers = encoder.get_result()
840+
841+
payload = event_message.encode()
842+
843+
string_to_sign = await self._event_string_to_sign(
844+
timestamp=timestamp,
845+
scope=self._scope(new_signing_properties),
846+
encoded_headers=encoded_headers,
847+
payload=payload,
848+
prior_signature=self._prior_signature,
849+
)
850+
event_signature = await self._sign_event(
851+
timestamp=timestamp,
852+
string_to_sign=string_to_sign,
853+
signing_properties=new_signing_properties,
854+
)
855+
headers[":chunk-signature"] = event_signature
856+
857+
event_message.headers = headers
858+
event_message.payload = payload
859+
860+
# set new prior signature before releasing the lock
861+
self._prior_signature = hexlify(event_signature)
862+
863+
return event_message
864+
865+
async def _event_string_to_sign(
866+
self,
867+
*,
868+
timestamp: str,
869+
scope: str,
870+
encoded_headers: bytes,
871+
payload: bytes,
872+
prior_signature: bytes,
873+
) -> str:
874+
return (
875+
"AWS4-HMAC-SHA256-PAYLOAD\n"
876+
f"{timestamp}\n"
877+
f"{scope}\n"
878+
f"{prior_signature.decode('utf-8')}\n"
879+
f"{sha256(encoded_headers).hexdigest()}\n"
880+
f"{sha256(payload).hexdigest()}"
881+
)
882+
883+
async def _sign_event(
884+
self,
885+
*,
886+
timestamp: str,
887+
string_to_sign: str,
888+
signing_properties: SigV4SigningProperties,
889+
) -> bytes:
890+
key = self._identity.secret_access_key.encode("utf-8")
891+
today = timestamp[:8].encode("utf-8")
892+
k_date = self._hash(b"AWS4" + key, today)
893+
k_region = self._hash(k_date, signing_properties["region"].encode("utf-8"))
894+
k_service = self._hash(k_region, signing_properties["service"].encode("utf-8"))
895+
k_signing = self._hash(k_service, b"aws4_request")
896+
return self._hash(k_signing, string_to_sign.encode("utf-8"))
897+
898+
def _hash(self, key: bytes, msg: bytes) -> bytes:
899+
return hmac.new(key, msg, sha256).digest()
900+
901+
def _scope(self, signing_properties: SigV4SigningProperties) -> str:
902+
assert "date" in signing_properties
903+
formatted_date = signing_properties["date"][0:8]
904+
region = signing_properties["region"]
905+
service = signing_properties["service"]
906+
# Scope format: <YYYYMMDD>/<AWS Region>/<AWS Service>/aws4_request
907+
return f"{formatted_date}/{region}/{service}/aws4_request"
908+
909+
792910
def _remove_dot_segments(path: str, remove_consecutive_slashes: bool = True) -> str:
793911
"""Removes dot segments from a path per :rfc:`3986#section-5.2.4`.
794912

0 commit comments

Comments
 (0)