Skip to content

Commit 072adf4

Browse files
authored
Add http and restjson1 client protocols (smithy-lang#448)
1 parent e1ecd23 commit 072adf4

File tree

6 files changed

+238
-5
lines changed

6 files changed

+238
-5
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Final
2+
3+
from smithy_aws_core.traits import RestJson1Trait
4+
from smithy_http.aio.protocols import HttpBindingClientProtocol
5+
from smithy_core.codecs import Codec
6+
from smithy_core.shapes import ShapeID
7+
from smithy_json import JSONCodec
8+
9+
10+
class RestJsonClientProtocol(HttpBindingClientProtocol):
11+
"""An implementation of the aws.protocols#restJson1 protocol."""
12+
13+
_id: ShapeID = RestJson1Trait.id
14+
_codec: JSONCodec = JSONCodec()
15+
_contentType: Final = "application/json"
16+
17+
@property
18+
def id(self) -> ShapeID:
19+
return self._id
20+
21+
@property
22+
def payload_codec(self) -> Codec:
23+
return self._codec
24+
25+
@property
26+
def content_type(self) -> str:
27+
return self._contentType
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# This ruff check warns against using the assert statement, which can be stripped out
5+
# when running Python with certain (common) optimization settings. Assert is used here
6+
# for trait values. Since these are always generated, we can be fairly confident that
7+
# they're correct regardless, so it's okay if the checks are stripped out.
8+
# ruff: noqa: S101
9+
10+
from dataclasses import dataclass, field
11+
from typing import Mapping, Sequence
12+
13+
from smithy_core.shapes import ShapeID
14+
from smithy_core.traits import Trait, DocumentValue, DynamicTrait
15+
16+
17+
@dataclass(init=False, frozen=True)
18+
class RestJson1Trait(Trait, id=ShapeID("aws.protocols#restJson1")):
19+
http: Sequence[str] = field(
20+
repr=False, hash=False, compare=False, default_factory=tuple
21+
)
22+
event_stream_http: Sequence[str] = field(
23+
repr=False, hash=False, compare=False, default_factory=tuple
24+
)
25+
26+
def __init__(self, value: DocumentValue | DynamicTrait = None):
27+
super().__init__(value)
28+
assert isinstance(self.document_value, Mapping)
29+
30+
http_versions = self.document_value["http"]
31+
assert isinstance(http_versions, Sequence)
32+
for val in http_versions:
33+
assert isinstance(val, str)
34+
object.__setattr__(self, "http", tuple(http_versions))
35+
event_stream_http_versions = self.document_value.get("eventStreamHttp")
36+
if not event_stream_http_versions:
37+
object.__setattr__(self, "event_stream_http", self.http)
38+
else:
39+
assert isinstance(event_stream_http_versions, Sequence)
40+
for val in event_stream_http_versions:
41+
assert isinstance(val, str)
42+
object.__setattr__(
43+
self, "event_stream_http", tuple(event_stream_http_versions)
44+
)
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import os
2+
from inspect import iscoroutinefunction
3+
from io import BytesIO
4+
5+
from smithy_core.aio.interfaces import ClientProtocol
6+
from smithy_core.codecs import Codec
7+
from smithy_core.deserializers import DeserializeableShape
8+
from smithy_core.documents import TypeRegistry
9+
from smithy_core.exceptions import ExpectationNotMetException
10+
from smithy_core.interfaces import Endpoint, TypedProperties, URI
11+
from smithy_core.schemas import APIOperation
12+
from smithy_core.serializers import SerializeableShape
13+
from smithy_core.traits import HTTPTrait, EndpointTrait
14+
from smithy_http.aio.interfaces import HTTPRequest, HTTPResponse
15+
from smithy_http.deserializers import HTTPResponseDeserializer
16+
from smithy_http.serializers import HTTPRequestSerializer
17+
18+
19+
class HttpClientProtocol(ClientProtocol[HTTPRequest, HTTPResponse]):
20+
"""An HTTP-based protocol."""
21+
22+
def set_service_endpoint(
23+
self,
24+
*,
25+
request: HTTPRequest,
26+
endpoint: Endpoint,
27+
) -> HTTPRequest:
28+
uri = endpoint.uri
29+
uri_builder = request.destination
30+
31+
if uri.scheme:
32+
uri_builder.scheme = uri.scheme
33+
if uri.host:
34+
uri_builder.host = uri.host
35+
if uri.port and uri.port > -1:
36+
uri_builder.port = uri.port
37+
if uri.path:
38+
uri_builder.path = os.path.join(uri.path, uri_builder.path or "")
39+
# TODO: merge headers from the endpoint properties bag
40+
return request
41+
42+
43+
class HttpBindingClientProtocol(HttpClientProtocol):
44+
"""An HTTP-based protocol that uses HTTP binding traits."""
45+
46+
@property
47+
def payload_codec(self) -> Codec:
48+
"""The codec used for the serde of input and output payloads."""
49+
...
50+
51+
@property
52+
def content_type(self) -> str:
53+
"""The media type of the http payload."""
54+
...
55+
56+
def serialize_request[
57+
OperationInput: "SerializeableShape",
58+
OperationOutput: "DeserializeableShape",
59+
](
60+
self,
61+
*,
62+
operation: APIOperation[OperationInput, OperationOutput],
63+
input: OperationInput,
64+
endpoint: URI,
65+
context: TypedProperties,
66+
) -> HTTPRequest:
67+
# TODO(optimization): request binding cache like done in SJ
68+
serializer = HTTPRequestSerializer(
69+
payload_codec=self.payload_codec,
70+
http_trait=operation.schema.expect_trait(HTTPTrait),
71+
endpoint_trait=operation.schema.get_trait(EndpointTrait),
72+
)
73+
74+
input.serialize(serializer=serializer)
75+
request = serializer.result
76+
77+
if request is None:
78+
raise ExpectationNotMetException(
79+
"Expected request to be serialized, but was None"
80+
)
81+
82+
return request
83+
84+
async def deserialize_response[
85+
OperationInput: "SerializeableShape",
86+
OperationOutput: "DeserializeableShape",
87+
](
88+
self,
89+
*,
90+
operation: APIOperation[OperationInput, OperationOutput],
91+
request: HTTPRequest,
92+
response: HTTPResponse,
93+
error_registry: TypeRegistry,
94+
context: TypedProperties,
95+
) -> OperationOutput:
96+
if not (200 <= response.status <= 299):
97+
# TODO: implement error serde from type registry
98+
raise NotImplementedError
99+
100+
body = response.body
101+
102+
# if body is not streaming and is async, we have to buffer it
103+
if not operation.output_stream_member:
104+
if (
105+
read := getattr(body, "read", None)
106+
) is not None and iscoroutinefunction(read):
107+
body = BytesIO(await read())
108+
109+
# TODO(optimization): response binding cache like done in SJ
110+
deserializer = HTTPResponseDeserializer(
111+
payload_codec=self.payload_codec,
112+
http_trait=operation.schema.expect_trait(HTTPTrait),
113+
response=response,
114+
body=body, # type: ignore
115+
)
116+
117+
return operation.output.deserialize(deserializer)

packages/smithy-http/src/smithy_http/serializers.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
TimestampFormatTrait,
3030
EndpointTrait,
3131
HTTPErrorTrait,
32+
MediaTypeTrait,
33+
StreamingTrait,
3234
)
3335
from smithy_core.shapes import ShapeType
3436
from smithy_core.utils import serialize_float
@@ -83,22 +85,33 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
8385
if self._endpoint_trait is not None:
8486
host_prefix = self._endpoint_trait.host_prefix
8587

88+
content_type = self._payload_codec.media_type
89+
8690
if (payload_member := self._get_payload_member(schema)) is not None:
8791
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
92+
content_type = (
93+
"application/octet-stream"
94+
if payload_member.shape_type is ShapeType.BLOB
95+
else "text/plain"
96+
)
8897
payload_serializer = RawPayloadSerializer()
8998
binding_serializer = HTTPRequestBindingSerializer(
9099
payload_serializer, self._http_trait.path, host_prefix
91100
)
92101
yield binding_serializer
93102
payload = payload_serializer.payload
94103
else:
104+
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
105+
content_type = media_type.value
95106
payload = BytesIO()
96107
payload_serializer = self._payload_codec.create_serializer(payload)
97108
binding_serializer = HTTPRequestBindingSerializer(
98109
payload_serializer, self._http_trait.path, host_prefix
99110
)
100111
yield binding_serializer
101112
else:
113+
if self._get_eventstreaming_member(schema) is not None:
114+
content_type = "application/vnd.amazon.eventstream"
102115
payload = BytesIO()
103116
payload_serializer = self._payload_codec.create_serializer(payload)
104117
with payload_serializer.begin_struct(schema) as body_serializer:
@@ -112,6 +125,10 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
112125
) is not None and not iscoroutinefunction(seek):
113126
seek(0)
114127

128+
# TODO: conditional on empty-ness and based on the protocol
129+
headers = binding_serializer.header_serializer.headers
130+
headers.append(("content-type", content_type))
131+
115132
self.result = _HTTPRequest(
116133
method=self._http_trait.method,
117134
destination=URI(
@@ -122,7 +139,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
122139
prefix=self._http_trait.query or "",
123140
),
124141
),
125-
fields=tuples_to_fields(binding_serializer.header_serializer.headers),
142+
fields=tuples_to_fields(headers),
126143
body=payload,
127144
)
128145

@@ -132,6 +149,15 @@ def _get_payload_member(self, schema: Schema) -> Schema | None:
132149
return member
133150
return None
134151

152+
def _get_eventstreaming_member(self, schema: Schema) -> Schema | None:
153+
for member in schema.members.values():
154+
if (
155+
member.get_trait(StreamingTrait) is not None
156+
and member.shape_type is ShapeType.UNION
157+
):
158+
return member
159+
return None
160+
135161

136162
class HTTPRequestBindingSerializer(InterceptingSerializer):
137163
"""Delegates HTTP request bindings to binding-location-specific serializers."""

packages/smithy-http/tests/unit/test_serializers.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from smithy_http.deserializers import HTTPResponseDeserializer
4545
from smithy_json import JSONCodec
4646
from smithy_http.aio import HTTPResponse as _HTTPResponse
47-
from smithy_http import tuples_to_fields, Fields
47+
from smithy_http import tuples_to_fields, Field, Fields
4848
from smithy_http.serializers import HTTPRequestSerializer, HTTPResponseSerializer
4949

5050
# TODO: empty header prefix, query map
@@ -1572,11 +1572,16 @@ def payload_cases() -> list[HTTPMessageTestCase]:
15721572
),
15731573
HTTPMessageTestCase(
15741574
HTTPStringPayload(payload="foo"),
1575-
HTTPMessage(body=b"foo"),
1575+
HTTPMessage(
1576+
fields=tuples_to_fields([("content-type", "text/plain")]), body=b"foo"
1577+
),
15761578
),
15771579
HTTPMessageTestCase(
15781580
HTTPBlobPayload(payload=b"\xde\xad\xbe\xef"),
1579-
HTTPMessage(body=b"\xde\xad\xbe\xef"),
1581+
HTTPMessage(
1582+
fields=tuples_to_fields([("content-type", "application/octet-stream")]),
1583+
body=b"\xde\xad\xbe\xef",
1584+
),
15801585
),
15811586
HTTPMessageTestCase(
15821587
HTTPStructuredPayload(payload=HTTPStringPayload(payload="foo")),
@@ -1589,7 +1594,10 @@ def async_streaming_payload_cases() -> list[HTTPMessageTestCase]:
15891594
return [
15901595
HTTPMessageTestCase(
15911596
HTTPStreamingPayload(payload=AsyncBytesReader(b"\xde\xad\xbe\xef")),
1592-
HTTPMessage(body=AsyncBytesReader(b"\xde\xad\xbe\xef")),
1597+
HTTPMessage(
1598+
fields=tuples_to_fields([("content-type", "application/octet-stream")]),
1599+
body=AsyncBytesReader(b"\xde\xad\xbe\xef"),
1600+
),
15931601
),
15941602
]
15951603

@@ -1604,6 +1612,8 @@ def async_streaming_payload_cases() -> list[HTTPMessageTestCase]:
16041612
+ async_streaming_payload_cases()
16051613
)
16061614

1615+
CONTENT_TYPE_FIELD = Field(name="content-type", values=["application/json"])
1616+
16071617

16081618
@pytest.mark.parametrize("case", REQUEST_SER_CASES)
16091619
async def test_serialize_http_request(case: HTTPMessageTestCase) -> None:
@@ -1623,6 +1633,10 @@ async def test_serialize_http_request(case: HTTPMessageTestCase) -> None:
16231633
actual_query = actual.destination.query or ""
16241634
expected_query = case.request.destination.query or ""
16251635
assert actual_query == expected_query
1636+
# set the content-type field here, otherwise cases would have to duplicate it everywhere,
1637+
# but if the field is already set in the case, don't override it
1638+
if expected.fields.get(CONTENT_TYPE_FIELD.name) is None:
1639+
expected.fields.set_field(CONTENT_TYPE_FIELD)
16261640
assert actual.fields == expected.fields
16271641

16281642
if case.request.body:
@@ -1647,6 +1661,9 @@ async def test_serialize_http_response(case: HTTPMessageTestCase) -> None:
16471661
expected = case.request
16481662

16491663
assert actual is not None
1664+
# Remove content-type from expected, we're re-using the request cases for brevity
1665+
if expected.fields.get(CONTENT_TYPE_FIELD.name) is not None:
1666+
del expected.fields[CONTENT_TYPE_FIELD.name]
16501667
assert actual.fields == expected.fields
16511668
assert actual.status == expected.status
16521669

0 commit comments

Comments
 (0)