Skip to content

Commit d0af716

Browse files
committed
Add support for bidirectional streaming in smithy clients
1 parent e70bb17 commit d0af716

File tree

19 files changed

+349
-102
lines changed

19 files changed

+349
-102
lines changed

src/aws-cpp-sdk-core/include/aws/core/auth/signer/AWSAuthEventStreamV4Signer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ namespace Aws
5555

5656
bool SignEventMessage(Aws::Utils::Event::Message&, Aws::String& priorSignature) const override;
5757

58+
bool SignEventMessageWithCreds(Aws::Utils::Event::Message& message, Aws::String& priorSignature, const Aws::Auth::AWSCredentials& credentials) const;
59+
5860
bool SignRequest(Aws::Http::HttpRequest& request) const override
5961
{
6062
return SignRequest(request, m_region.c_str(), m_serviceName.c_str(), true);
@@ -70,6 +72,8 @@ namespace Aws
7072
return SignRequest(request, region, m_serviceName.c_str(), signBody);
7173
}
7274

75+
bool SignRequestWithCreds(Http::HttpRequest& request, const Auth::AWSCredentials& credentials, const char* region, const char* serviceName, bool) const;
76+
7377
bool SignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, bool signBody) const override;
7478

7579
/**

src/aws-cpp-sdk-core/include/aws/core/client/AWSClientEventStreamingAsyncTask.h

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313
#include <aws/core/utils/event/EventStreamHandler.h>
1414
#include <aws/core/utils/logging/ErrorMacros.h>
1515

16+
17+
namespace smithy {
18+
namespace client {
19+
class AwsSmithyClientAsyncRequestContext;
20+
}
21+
}
22+
1623
namespace Aws {
1724
namespace Client {
1825

@@ -125,6 +132,92 @@ class AWS_CORE_LOCAL BidirectionalEventStreamingTask final {
125132
std::shared_ptr<Aws::Utils::Threading::Semaphore> m_sem;
126133
};
127134

135+
/**
136+
* Smithy-compatible bi-directional streaming task for modern AWS clients
137+
*/
138+
template <typename OutcomeT, typename ClientT, typename RequestT, typename HandlerT>
139+
class AWS_CORE_LOCAL SmithyBidirectionalStreamingTask final {
140+
public:
141+
using AuthResolvedCallback = std::function<void(std::shared_ptr<smithy::client::AwsSmithyClientAsyncRequestContext>)>;
142+
using EndpointUpdateCallback = std::function<void(Aws::Endpoint::AWSEndpoint&)>;
143+
144+
SmithyBidirectionalStreamingTask(const ClientT* client, const std::shared_ptr<RequestT>& request, const HandlerT& handler,
145+
const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context,
146+
const std::shared_ptr<Utils::Event::EventEncoderStream>& stream,
147+
EndpointUpdateCallback&& endpointCallback,
148+
AuthResolvedCallback&& authCallback)
149+
: m_client(client), m_request(request), m_handler(handler), m_context(context), m_stream(stream),
150+
m_sem(Aws::MakeShared<Aws::Utils::Threading::Semaphore>("SmithyBidirectionalStreamingTask", 0, 1)) {
151+
152+
m_authCallback = std::move(authCallback);
153+
m_endpointCallback = std::move(endpointCallback);
154+
155+
m_request->SetEventStreamHandler(m_request->GetEventStreamHandler());
156+
157+
auto streamPtr = m_stream;
158+
auto sem = m_sem;
159+
m_request->SetRequestSignedHandler([streamPtr, sem](const Aws::Http::HttpRequest& httpRequest) {
160+
streamPtr->SetSignatureSeed(Aws::Client::GetAuthorizationHeader(httpRequest));
161+
sem->ReleaseAll();
162+
});
163+
164+
std::weak_ptr<RequestT> wRequest = request;
165+
// Setup InitialResponse handler to use the new actual request object
166+
if (!request->GetHeadersReceivedEventHandler()) {
167+
request->SetHeadersReceivedEventHandler([wRequest](const Http::HttpRequest*, Http::HttpResponse* response) {
168+
auto request = wRequest.lock();
169+
AWS_CHECK_PTR(ClientT::GetAllocationTag(), request);
170+
AWS_CHECK_PTR(ClientT::GetAllocationTag(), response);
171+
172+
auto& initialResponseHandler = request->GetEventStreamHandler().GetInitialResponseCallbackEx();
173+
if (initialResponseHandler) {
174+
initialResponseHandler({response->GetHeaders()}, Utils::Event::InitialResponseType::ON_RESPONSE);
175+
}
176+
});
177+
}
178+
179+
// Setup ResponseStreamFactory to provide EventStream decoder based on the new actual request object, not the original one.
180+
request->SetResponseStreamFactory([wRequest]() -> Aws::IOStream* {
181+
auto request = wRequest.lock();
182+
if (!request) {
183+
AWS_LOGSTREAM_FATAL(ClientT::GetAllocationTag(),
184+
"Unexpected nullptr bi-directional streaming request on response streaming factory call!");
185+
assert(false);
186+
return nullptr;
187+
}
188+
request->GetEventStreamDecoder().Reset();
189+
return Aws::New<Aws::Utils::Event::EventDecoderStream>("BidirectionalEventStreamingTask", request->GetEventStreamDecoder());
190+
});
191+
}
192+
193+
const std::shared_ptr<Aws::Utils::Threading::Semaphore>& GetSemaphore() const { return m_sem; }
194+
195+
void operator()() {
196+
assert(m_authCallback);
197+
assert(m_endpointCallback);
198+
auto outcome = m_client->MakeRequestDeserialize(m_request.get(), m_request->GetServiceRequestName(),
199+
Aws::Http::HttpMethod::HTTP_POST,
200+
std::move(m_endpointCallback), std::move(m_authCallback));
201+
202+
if (outcome.IsSuccess()) {
203+
m_handler(m_client, *m_request, OutcomeT(NoResult()), m_context);
204+
} else {
205+
if (m_stream) m_stream->Close();
206+
m_handler(m_client, *m_request, OutcomeT(outcome.GetError()), m_context);
207+
}
208+
}
209+
210+
private:
211+
const ClientT* m_client;
212+
std::shared_ptr<RequestT> m_request;
213+
HandlerT m_handler;
214+
std::shared_ptr<const Aws::Client::AsyncCallerContext> m_context;
215+
std::shared_ptr<Utils::Event::EventEncoderStream> m_stream;
216+
AuthResolvedCallback m_authCallback;
217+
EndpointUpdateCallback m_endpointCallback;
218+
std::shared_ptr<Aws::Utils::Threading::Semaphore> m_sem;
219+
};
220+
128221
// A helper template factory to avoid providing all typenames for BidirectionalEventStreamingTask in the generated code
129222
// It looks like a wall of code, you can thank clang-format for this.
130223
template <typename OutcomeT, typename ClientT, typename AWSEndpointT, typename RequestT, typename HandlerT>
@@ -136,5 +229,15 @@ static BidirectionalEventStreamingTask<OutcomeT, ClientT, AWSEndpointT, RequestT
136229
return BidirectionalEventStreamingTask<OutcomeT, ClientT, AWSEndpointT, RequestT, HandlerT>(
137230
pClientThis, std::forward<AWSEndpointT>(endpoint), pRequest, handler, handlerContext, stream, method, signerName);
138231
}
232+
233+
template <typename OutcomeT, typename ClientT, typename RequestT, typename HandlerT>
234+
static SmithyBidirectionalStreamingTask<OutcomeT, ClientT, RequestT, HandlerT> CreateSmithyBidirectionalEventStreamTask(
235+
const ClientT* client, std::shared_ptr<RequestT> request, const HandlerT& handler,
236+
const std::shared_ptr<const Aws::Client::AsyncCallerContext>& context,
237+
const std::shared_ptr<Utils::Event::EventEncoderStream>& stream,
238+
std::function<void(Aws::Endpoint::AWSEndpoint&)>&& endpointCallback,
239+
std::function<void(std::shared_ptr<smithy::client::AwsSmithyClientAsyncRequestContext>)>&& authCallback) {
240+
return SmithyBidirectionalStreamingTask<OutcomeT, ClientT, RequestT, HandlerT>(client, request, handler, context, stream, std::move(endpointCallback), std::move(authCallback));
241+
}
139242
} // namespace Client
140243
} // namespace Aws

src/aws-cpp-sdk-core/include/aws/core/utils/event/EventEncoderStream.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <aws/core/utils/event/EventMessage.h>
1010
#include <aws/core/utils/memory/stl/AWSStreamFwd.h>
1111
#include <aws/core/utils/event/EventStreamEncoder.h>
12+
#include <functional>
1213

1314
namespace Aws
1415
{
@@ -53,6 +54,11 @@ namespace Aws
5354
*/
5455
void SetSigner(Aws::Client::AWSAuthSigner* signer) { m_encoder.SetSigner(signer); }
5556

57+
/**
58+
* Sets a custom signing callback for event signing.
59+
*/
60+
void SetSigningCallback(const EventStreamEncoder::SigningCallback& callback) { m_encoder.SetSigningCallback(callback); }
61+
5662
/**
5763
* Allows a stream writer to communicate the end of the stream to a stream reader.
5864
*

src/aws-cpp-sdk-core/include/aws/core/utils/event/EventStreamEncoder.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include <aws/core/Core_EXPORTS.h>
99
#include <aws/core/utils/memory/stl/AWSVector.h>
1010
#include <aws/event-stream/event_stream.h>
11+
#include <functional>
12+
#include <aws/core/auth/AWSAuthSigner.h>
1113

1214
namespace Aws
1315
{
@@ -28,13 +30,16 @@ namespace Aws
2830
class AWS_CORE_API EventStreamEncoder
2931
{
3032
public:
31-
EventStreamEncoder(Aws::Client::AWSAuthSigner* signer = nullptr);
33+
using SigningCallback = std::function<bool(Aws::Utils::Event::Message&, Aws::String&)>;
3234

35+
EventStreamEncoder(Aws::Client::AWSAuthSigner* signer = nullptr);
3336

3437
void SetSignatureSeed(const Aws::String& seed) { m_signatureSeed = seed; }
3538

3639
void SetSigner(Aws::Client::AWSAuthSigner* signer) { m_signer = signer; }
3740

41+
void SetSigningCallback(const SigningCallback& callback) { m_signingCallback = callback; }
42+
3843
/**
3944
* Encodes the input message in the event-stream binary format and signs the resulting bits.
4045
* The signing is done via the signer member.
@@ -58,6 +63,11 @@ namespace Aws
5863

5964
Aws::Client::AWSAuthSigner* m_signer;
6065
Aws::String m_signatureSeed;
66+
67+
SigningCallback m_signingCallback = [this](Aws::Utils::Event::Message& signedMessage, Aws::String& signatureSeed) -> bool {
68+
assert(m_signer);
69+
return m_signer->SignEventMessage(signedMessage, signatureSeed);
70+
};
6171
};
6272
}
6373
}

src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClient.h

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
#include <smithy/identity/signer/built-in/SignerProperties.h>
2626
#include <smithy/client/AwsLegacyClient.h>
2727

28+
namespace Aws {
29+
namespace Client {
30+
template <typename OutcomeT, typename ClientT, typename RequestT, typename HandlerT>
31+
class SmithyBidirectionalStreamingTask;
32+
}
33+
}
34+
2835
namespace smithy {
2936
namespace client
3037
{
@@ -132,6 +139,9 @@ namespace client
132139
virtual ~AwsSmithyClientT() = default;
133140

134141
protected:
142+
template <typename OutcomeT, typename ClientT, typename RequestT, typename HandlerT>
143+
friend class Aws::Client::SmithyBidirectionalStreamingTask;
144+
135145
void initClient() {
136146
if (m_endpointProvider && m_authSchemeResolver) {
137147
m_endpointProvider->InitBuiltInParameters(m_clientConfiguration);
@@ -211,6 +221,11 @@ namespace client
211221
return AwsClientRequestSigning<AuthSchemesVariantT>::SignRequest(httpRequest, ctx, m_authSchemes);
212222
}
213223

224+
SigningEventOutcome SignEventMessage(Aws::Utils::Event::Message& message, Aws::String &seed, const std::shared_ptr<AwsSmithyClientAsyncRequestContext>& ctx) const
225+
{
226+
return AwsClientRequestSigning<AuthSchemesVariantT>::SignEventMessage(message, seed, ctx, m_authSchemes);
227+
}
228+
214229
bool AdjustClockSkew(HttpResponseOutcome& outcome, const AuthSchemeOption& authSchemeOption) const override
215230
{
216231
return AwsClientRequestSigning<AuthSchemesVariantT>::AdjustClockSkew(outcome, authSchemeOption, m_authSchemes);
@@ -227,12 +242,13 @@ namespace client
227242
ResponseT MakeRequestDeserialize(Aws::AmazonWebServiceRequest const * const request,
228243
const char* requestName,
229244
Aws::Http::HttpMethod method,
230-
EndpointUpdateCallback&& endpointCallback) const
245+
EndpointUpdateCallback&& endpointCallback,
246+
AuthResolvedCallback&& authCallback = nullptr) const
231247
{
232-
auto httpResponseOutcome = MakeRequestSync(request, requestName, method, std::move(endpointCallback));
233-
return m_serializer->Deserialize(std::move(httpResponseOutcome), GetServiceClientName(), requestName);
248+
auto httpResponseOutcome = MakeRequestSync(request, requestName, method, std::move(endpointCallback), std::move(authCallback));
249+
return m_serializer->Deserialize(std::move(httpResponseOutcome), GetServiceClientName(), requestName);
234250
}
235-
251+
236252
Aws::String GeneratePresignedUrl(
237253
EndpointUpdateCallback&& endpointCallback,
238254
Aws::Http::HttpMethod method,

src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientAsyncRequestContext.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace smithy
2424
using AwsCoreError = Aws::Client::AWSError<Aws::Client::CoreErrors>;
2525
using HttpResponseOutcome = Aws::Utils::Outcome<std::shared_ptr<Aws::Http::HttpResponse>, AwsCoreError>;
2626
using ResponseHandlerFunc = std::function<void(HttpResponseOutcome&&)>;
27+
using AuthResolvedCallback = std::function<void(std::shared_ptr<AwsSmithyClientAsyncRequestContext>)>;
2728

2829
struct RequestInfo
2930
{
@@ -69,6 +70,7 @@ namespace smithy
6970
Aws::Vector<void*> m_monitoringContexts;
7071

7172
ResponseHandlerFunc m_responseHandler;
73+
AuthResolvedCallback m_authResolvedCallback;
7274
std::shared_ptr<Aws::Utils::Threading::Executor> m_pExecutor;
7375
std::shared_ptr<interceptor::InterceptorContext> m_interceptorContext;
7476
std::shared_ptr<smithy::AwsIdentity> m_awsIdentity;

src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <aws/core/http/HttpClient.h>
2525
#include <aws/core/client/AWSErrorMarshaller.h>
2626
#include <aws/core/AmazonWebServiceResult.h>
27+
#include <smithy/identity/identity/AwsIdentity.h>
2728
#include <utility>
2829

2930
namespace Aws
@@ -81,7 +82,9 @@ namespace client
8182
using ClientError = AWSCoreError;
8283
using SigningError = AWSCoreError;
8384
using SigningOutcome = Aws::Utils::FutureOutcome<std::shared_ptr<Aws::Http::HttpRequest>, SigningError>;
85+
using SigningEventOutcome = Aws::Utils::Outcome<Aws::Utils::Event::Message, SigningError>;
8486
using EndpointUpdateCallback = std::function<void(Aws::Endpoint::AWSEndpoint&)>;
87+
using AuthResolvedCallback = std::function<void(std::shared_ptr<AwsSmithyClientAsyncRequestContext>)>;
8588
using HttpResponseOutcome = Aws::Utils::Outcome<std::shared_ptr<Aws::Http::HttpResponse>, AWSCoreError>;
8689
using ResponseHandlerFunc = std::function<void(HttpResponseOutcome&&)>;
8790
using SelectAuthSchemeOptionOutcome = Aws::Utils::Outcome<AuthSchemeOption, AWSCoreError>;
@@ -144,12 +147,14 @@ namespace client
144147
Aws::Http::HttpMethod method,
145148
EndpointUpdateCallback&& endpointCallback,
146149
ResponseHandlerFunc&& responseHandler,
150+
AuthResolvedCallback&& authCallback,
147151
std::shared_ptr<Aws::Utils::Threading::Executor> pExecutor) const;
148152

149153
HttpResponseOutcome MakeRequestSync(Aws::AmazonWebServiceRequest const * const request,
150154
const char* requestName,
151155
Aws::Http::HttpMethod method,
152-
EndpointUpdateCallback&& endpointCallback) const;
156+
EndpointUpdateCallback&& endpointCallback,
157+
AuthResolvedCallback&& authCallback) const;
153158

154159
StreamOutcome MakeRequestWithUnparsedResponse(Aws::AmazonWebServiceRequest const * const request,
155160
const char* requestName,
@@ -159,6 +164,8 @@ namespace client
159164
void AppendToUserAgent(const Aws::String& valueToAppend);
160165

161166
protected:
167+
template <typename OutcomeT, typename ClientT, typename RequestT, typename HandlerT>
168+
friend class SmithyBidirectionalEventStreamingTask;
162169

163170
//for backwards compatibility
164171
const std::shared_ptr<Aws::Client::AWSErrorMarshaller>& GetErrorMarshaller() const

0 commit comments

Comments
 (0)