15
15
16
16
package software .amazon .awssdk .protocol .runners ;
17
17
18
+
19
+ import static com .github .tomakehurst .wiremock .client .WireMock .aResponse ;
20
+ import static com .github .tomakehurst .wiremock .client .WireMock .any ;
21
+ import static com .github .tomakehurst .wiremock .client .WireMock .stubFor ;
22
+ import static com .github .tomakehurst .wiremock .client .WireMock .urlMatching ;
18
23
import static org .junit .Assert .assertEquals ;
19
24
20
25
import com .fasterxml .jackson .databind .JsonNode ;
26
+ import com .github .tomakehurst .wiremock .client .ResponseDefinitionBuilder ;
27
+ import com .github .tomakehurst .wiremock .client .WireMock ;
21
28
import com .github .tomakehurst .wiremock .verification .LoggedRequest ;
22
29
import java .lang .reflect .InvocationTargetException ;
23
30
import java .net .URI ;
24
31
import java .util .List ;
25
32
import org .junit .Assert ;
26
33
import software .amazon .awssdk .awscore .AwsRequest ;
27
- import software .amazon .awssdk .awscore .AwsRequestOverrideConfiguration ;
28
34
import software .amazon .awssdk .codegen .model .intermediate .IntermediateModel ;
29
- import software .amazon .awssdk .core .SdkPlugin ;
30
- import software .amazon .awssdk .core .SdkServiceClientConfiguration ;
31
35
import software .amazon .awssdk .core .interceptor .Context ;
32
36
import software .amazon .awssdk .core .interceptor .ExecutionAttributes ;
33
37
import software .amazon .awssdk .core .interceptor .ExecutionInterceptor ;
34
38
import software .amazon .awssdk .core .sync .RequestBody ;
35
- import software .amazon .awssdk .http .SdkHttpFullRequest ;
36
39
import software .amazon .awssdk .http .SdkHttpRequest ;
37
40
import software .amazon .awssdk .protocol .model .TestCase ;
38
41
import software .amazon .awssdk .protocol .reflect .ClientReflector ;
@@ -47,52 +50,50 @@ class MarshallingTestRunner {
47
50
48
51
private final IntermediateModel model ;
49
52
private final ClientReflector clientReflector ;
50
- private final RequestRecordingInterceptor recordingInterceptor ;
53
+ private final LocalhostOnlyForWiremockInterceptor localhostOnlyForWiremockInterceptor ;
51
54
52
55
MarshallingTestRunner (IntermediateModel model , ClientReflector clientReflector ) {
53
56
this .model = model ;
54
57
this .clientReflector = clientReflector ;
55
- this .recordingInterceptor = new RequestRecordingInterceptor ();
56
- }
57
-
58
- /**
59
- * @return LoggedRequest that wire mock captured.
60
- */
61
- private static LoggedRequest getLoggedRequest () {
62
- List <LoggedRequest > requests = WireMockUtils .findAllLoggedRequests ();
63
- assertEquals (1 , requests .size ());
64
- return requests .get (0 );
58
+ this .localhostOnlyForWiremockInterceptor = new LocalhostOnlyForWiremockInterceptor ();
65
59
}
66
60
67
61
void runTest (TestCase testCase ) throws Exception {
62
+ resetWireMock ();
68
63
ShapeModelReflector shapeModelReflector = createShapeModelReflector (testCase );
69
64
AwsRequest request = createRequest (testCase , shapeModelReflector );
70
65
71
- try {
72
- if (!model .getShapes ().get (testCase .getWhen ().getOperationName () + "Request" ).isHasStreamingMember ()) {
73
- clientReflector .invokeMethod (testCase , request );
74
- } else {
75
- clientReflector .invokeMethod (testCase ,
76
- request ,
77
- RequestBody .fromString (shapeModelReflector .getStreamingMemberValue ()));
78
- }
79
- Assert .fail ("Expected SDK client to intercept and record request before transmission." );
80
- } catch (InvocationTargetException e ) {
81
- if (e .getTargetException () instanceof StopExecutionException ) {
82
- SdkHttpRequest recordedRequest = recordingInterceptor .getRequest ();
83
- testCase .getThen ().getMarshallingAssertion ().assertMatches (getLoggedRequest ());
84
- } else {
85
- throw e ;
86
- }
66
+ if (!model .getShapes ().get (testCase .getWhen ().getOperationName () + "Request" ).isHasStreamingMember ()) {
67
+ clientReflector .invokeMethod (testCase , request );
68
+ } else {
69
+ clientReflector .invokeMethod (testCase ,
70
+ request ,
71
+ RequestBody .fromString (shapeModelReflector .getStreamingMemberValue ()));
87
72
}
73
+ testCase .getThen ().getMarshallingAssertion ()
74
+ .assertMatches (localhostOnlyForWiremockInterceptor .getLoggedRequestWithOriginalHost ());
75
+ }
76
+
77
+ /**
78
+ * Reset wire mock and re-configure stubbing.
79
+ */
80
+ private void resetWireMock () {
81
+ WireMock .reset ();
82
+ // Stub to return 200 for all requests
83
+ ResponseDefinitionBuilder responseDefBuilder = aResponse ().withStatus (200 );
84
+ // XML Unmarshallers expect at least one level in the XML document.
85
+ if (model .getMetadata ().isXmlProtocol ()) {
86
+ responseDefBuilder .withBody ("<foo></foo>" );
87
+ }
88
+ stubFor (any (urlMatching (".*" )).willReturn (responseDefBuilder ));
88
89
}
89
90
90
91
private AwsRequest createRequest (TestCase testCase , ShapeModelReflector shapeModelReflector ) {
91
92
return ((AwsRequest ) shapeModelReflector .createShapeObject ())
92
93
.toBuilder ()
93
94
.overrideConfiguration (requestConfiguration -> requestConfiguration
94
95
.addPlugin (config -> {
95
- config .overrideConfiguration (c -> c .addExecutionInterceptor (recordingInterceptor ));
96
+ config .overrideConfiguration (c -> c .addExecutionInterceptor (localhostOnlyForWiremockInterceptor ));
96
97
97
98
if (StringUtils .isNotBlank (testCase .getGiven ().getHost ())) {
98
99
config .endpointOverride (URI .create ("https://" + testCase .getGiven ().getHost ()));
@@ -115,26 +116,49 @@ private String getOperationRequestClassName(String operationName) {
115
116
return operationName + "Request" ;
116
117
}
117
118
118
- private static final class RequestRecordingInterceptor implements ExecutionInterceptor {
119
- private SdkHttpRequest request ;
119
+ /**
120
+ * Wiremock requires that requests use "localhost" as the host - any prefixes such as "foo.localhost" will
121
+ * result in a DNS lookup that will fail. This interceptor modifies the request to ensure this and captures
122
+ * the original host.
123
+ */
124
+ private static final class LocalhostOnlyForWiremockInterceptor implements ExecutionInterceptor {
125
+ private String originalHost ;
126
+ private String originalProtocol ;
127
+ private int originalPort ;
120
128
121
129
@ Override
122
- public void beforeTransmission (Context .BeforeTransmission context , ExecutionAttributes executionAttributes ) {
123
- request = context .httpRequest ();
124
-
125
- // Log or record the request here
126
- System . out . println ( "Recording Request:" );
127
- System . out . println ( "HTTP Method: " + request . method ());
128
- System . out . println ( "Endpoint: " + request . getUri ());
129
- System . out . println ( "Headers: " + request . headers ());
130
-
131
- throw new StopExecutionException ();
130
+ public SdkHttpRequest modifyHttpRequest (Context .ModifyHttpRequest context , ExecutionAttributes executionAttributes ) {
131
+ originalHost = context .httpRequest (). host ();
132
+ originalProtocol = context . httpRequest (). protocol ();
133
+ originalPort = context . httpRequest (). port ();
134
+
135
+ return context . httpRequest (). toBuilder ()
136
+ . host ( "localhost" )
137
+ . port ( WireMockUtils . port ())
138
+ . protocol ( "http" )
139
+ . build ();
132
140
}
133
141
134
- public SdkHttpRequest getRequest () {
135
- return request ;
142
+ /**
143
+ * @return LoggedRequest that wire mock captured modified with the original host captured by this
144
+ * interceptor.
145
+ */
146
+ public LoggedRequest getLoggedRequestWithOriginalHost () {
147
+ List <LoggedRequest > requests = WireMockUtils .findAllLoggedRequests ();
148
+ assertEquals (1 , requests .size ());
149
+ LoggedRequest loggedRequest = requests .get (0 );
150
+ return new LoggedRequest (
151
+ loggedRequest .getUrl (),
152
+ originalProtocol + "://" + originalHost + ":" + originalPort ,
153
+ loggedRequest .getMethod (),
154
+ loggedRequest .getClientIp (),
155
+ loggedRequest .getHeaders (),
156
+ loggedRequest .getCookies (),
157
+ loggedRequest .isBrowserProxyRequest (),
158
+ loggedRequest .getLoggedDate (),
159
+ loggedRequest .getBody (),
160
+ loggedRequest .getParts ()
161
+ );
136
162
}
137
163
}
138
-
139
- private static class StopExecutionException extends RuntimeException {}
140
164
}
0 commit comments