diff --git a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/asserts/unmarshalling/UnmarshalledErrorAssertion.java b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/asserts/unmarshalling/UnmarshalledErrorAssertion.java new file mode 100644 index 000000000000..763e774a930f --- /dev/null +++ b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/asserts/unmarshalling/UnmarshalledErrorAssertion.java @@ -0,0 +1,61 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.protocol.asserts.unmarshalling; + + +import static org.junit.Assert.fail; +import static org.unitils.reflectionassert.ReflectionAssert.assertReflectionEquals; + +import com.fasterxml.jackson.databind.JsonNode; +import java.lang.reflect.Field; +import org.junit.Assert; +import software.amazon.awssdk.core.exception.SdkServiceException; +import software.amazon.awssdk.protocol.reflect.ShapeModelReflector; + +public class UnmarshalledErrorAssertion extends UnmarshallingAssertion { + private final JsonNode expectedError; + + public UnmarshalledErrorAssertion(JsonNode expectedError) { + this.expectedError = expectedError; + } + + @Override + protected void doAssert(UnmarshallingTestContext context, Object actual) throws Exception { + if (!(actual instanceof SdkServiceException)) { + fail("Expected unmarshalled object to be an instance of SdkServiceException"); + } + SdkServiceException actualException = (SdkServiceException) actual; + SdkServiceException expectedException = createExpectedResult(context); + for (Field field : expectedException.getClass().getDeclaredFields()) { + assertFieldEquals(field, actualException, expectedException); + } + + if (expectedException.getMessage() != null) { + Assert.assertTrue(actualException.getMessage().startsWith(expectedException.getMessage())); + } + } + + private SdkServiceException createExpectedResult(UnmarshallingTestContext context) { + return (SdkServiceException) new ShapeModelReflector(context.getModel(), context.getErrorName() + "Exception", + this.expectedError).createShapeObject(); + } + + private void assertFieldEquals(Field field, Object actual, Object expectedResult) throws + Exception { + field.setAccessible(true); + assertReflectionEquals(field.get(expectedResult), field.get(actual)); + } +} \ No newline at end of file diff --git a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/asserts/unmarshalling/UnmarshallingTestContext.java b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/asserts/unmarshalling/UnmarshallingTestContext.java index ee5d1e22987e..7ad69dd7828d 100644 --- a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/asserts/unmarshalling/UnmarshallingTestContext.java +++ b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/asserts/unmarshalling/UnmarshallingTestContext.java @@ -26,6 +26,7 @@ public class UnmarshallingTestContext { private IntermediateModel model; private String operationName; private String streamedResponse; + private String errorName; public UnmarshallingTestContext withModel(IntermediateModel model) { this.model = model; @@ -58,4 +59,12 @@ public String getStreamedResponse() { return streamedResponse; } + public UnmarshallingTestContext withErrorName(String errorName) { + this.errorName = errorName; + return this; + } + + public String getErrorName() { + return errorName; + } } diff --git a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/Then.java b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/Then.java index 7d7e934cf030..0f6bc1268bd1 100644 --- a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/Then.java +++ b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/Then.java @@ -20,6 +20,7 @@ import com.fasterxml.jackson.databind.JsonNode; import software.amazon.awssdk.protocol.asserts.marshalling.MarshallingAssertion; import software.amazon.awssdk.protocol.asserts.marshalling.SerializedAs; +import software.amazon.awssdk.protocol.asserts.unmarshalling.UnmarshalledErrorAssertion; import software.amazon.awssdk.protocol.asserts.unmarshalling.UnmarshalledResultAssertion; import software.amazon.awssdk.protocol.asserts.unmarshalling.UnmarshallingAssertion; @@ -27,12 +28,14 @@ public class Then { private final MarshallingAssertion serializedAs; private final UnmarshallingAssertion deserializedAs; + private final UnmarshallingAssertion errorDeserializedAs; @JsonCreator public Then(@JsonProperty("serializedAs") SerializedAs serializedAs, @JsonProperty("deserializedAs") JsonNode deserializedAs) { this.serializedAs = serializedAs; this.deserializedAs = new UnmarshalledResultAssertion(deserializedAs); + this.errorDeserializedAs = new UnmarshalledErrorAssertion(deserializedAs); } /** @@ -49,4 +52,11 @@ public UnmarshallingAssertion getUnmarshallingAssertion() { return deserializedAs; } + /** + * + * @return The assertion object to use for error unmarshalling tests + */ + public UnmarshallingAssertion getErrorUnmarshallingAssertion() { + return errorDeserializedAs; + } } diff --git a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/When.java b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/When.java index 9611a8ba42fa..1db3442a364c 100644 --- a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/When.java +++ b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/When.java @@ -26,6 +26,9 @@ public class When { @JsonProperty(value = "operation") private String operationName; + @JsonProperty(value = "error") + private String errorName; + public WhenAction getAction() { return action; } @@ -41,4 +44,12 @@ public String getOperationName() { public void setOperationName(String operationName) { this.operationName = operationName; } + + public void setErrorName(String errorName) { + this.errorName = errorName; + } + + public String getErrorName() { + return errorName; + } } diff --git a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/WhenAction.java b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/WhenAction.java index 0304b2e2a741..cf6317d4700c 100644 --- a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/WhenAction.java +++ b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/model/WhenAction.java @@ -17,7 +17,8 @@ public enum WhenAction { MARSHALL("marshall"), - UNMARSHALL("unmarshall"); + UNMARSHALL("unmarshall"), + ERROR_UNMARSHALL("errorUnmarshall"); private final String action; @@ -31,6 +32,8 @@ public static WhenAction fromValue(String action) { return MARSHALL; case "unmarshall": return UNMARSHALL; + case "errorUnmarshall": + return ERROR_UNMARSHALL; default: throw new IllegalArgumentException("Unsupported test action " + action); } diff --git a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/reflect/ShapeModelReflector.java b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/reflect/ShapeModelReflector.java index 377f7b4e0863..c39a12ebf0e8 100644 --- a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/reflect/ShapeModelReflector.java +++ b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/reflect/ShapeModelReflector.java @@ -110,6 +110,14 @@ private void initializeFields(ShapeModel structureShape, JsonNode input, Iterator fieldNames = input.fieldNames(); while (fieldNames.hasNext()) { String memberName = fieldNames.next(); + // error structures have special case handling of "message" + if (structureShape.getErrorCode() != null && memberName.equalsIgnoreCase("message")) { + Method setter = shapeObject.getClass().getMethod("message", String.class); + setter.setAccessible(true); + setter.invoke(shapeObject, input.get(memberName).asText()); + continue; + } + MemberModel memberModel = structureShape.getMemberByC2jName(memberName); if (memberModel == null) { throw new IllegalArgumentException("Member " + memberName + " was not found in the " + diff --git a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/runners/ProtocolTestRunner.java b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/runners/ProtocolTestRunner.java index dde66e09e251..6c86766b1796 100644 --- a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/runners/ProtocolTestRunner.java +++ b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/runners/ProtocolTestRunner.java @@ -72,6 +72,7 @@ public void runTest(TestCase testCase) throws Exception { marshallingTestRunner.runTest(testCase); break; case UNMARSHALL: + case ERROR_UNMARSHALL: unmarshallingTestRunner.runTest(testCase); break; default: diff --git a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/runners/UnmarshallingTestRunner.java b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/runners/UnmarshallingTestRunner.java index 6ee3b88cb010..967b4823bd90 100644 --- a/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/runners/UnmarshallingTestRunner.java +++ b/test/protocol-tests-core/src/main/java/software/amazon/awssdk/protocol/runners/UnmarshallingTestRunner.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.databind.JsonNode; import com.github.tomakehurst.wiremock.client.ResponseDefinitionBuilder; import com.github.tomakehurst.wiremock.client.WireMock; +import java.lang.reflect.InvocationTargetException; import java.util.Base64; import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; import software.amazon.awssdk.codegen.model.intermediate.Metadata; @@ -52,6 +53,21 @@ class UnmarshallingTestRunner { void runTest(TestCase testCase) throws Exception { resetWireMock(testCase.getGiven().getResponse()); + + switch (testCase.getWhen().getAction()) { + case UNMARSHALL: + runUnmarshallTest(testCase); + break; + case ERROR_UNMARSHALL: + runErrorUnmarshallTest(testCase); + break; + default: + throw new IllegalArgumentException("UnmarshallingTestRunner unable to run test case for action " + + testCase.getWhen().getAction()); + } + } + + private void runUnmarshallTest(TestCase testCase) throws Exception { String operationName = testCase.getWhen().getOperationName(); ShapeModelReflector shapeModelReflector = createShapeModelReflector(testCase); if (!hasStreamingMember(operationName)) { @@ -60,12 +76,32 @@ void runTest(TestCase testCase) throws Exception { } else { CapturingResponseTransformer responseHandler = new CapturingResponseTransformer(); Object actualResult = clientReflector - .invokeStreamingMethod(testCase, shapeModelReflector.createShapeObject(), responseHandler); + .invokeStreamingMethod(testCase, shapeModelReflector.createShapeObject(), responseHandler); testCase.getThen().getUnmarshallingAssertion() .assertMatches(createContext(operationName, responseHandler.captured), actualResult); } } + private void runErrorUnmarshallTest(TestCase testCase) throws Exception { + String operationName = testCase.getWhen().getOperationName(); + ShapeModelReflector shapeModelReflector = createShapeModelReflector(testCase); + try { + clientReflector.invokeMethod(testCase, shapeModelReflector.createShapeObject()); + throw new IllegalStateException("Test case expected client to throw error"); + } catch (InvocationTargetException t) { + String errorName = testCase.getWhen().getErrorName(); + testCase.getThen().getErrorUnmarshallingAssertion().assertMatches( + createErrorContext(operationName, errorName), t.getCause()); + } + } + + private UnmarshallingTestContext createErrorContext(String operationName, String errorName) { + return new UnmarshallingTestContext() + .withModel(model) + .withOperationName(operationName) + .withErrorName(errorName); + } + /** * {@link ResponseTransformer} that simply captures all the content as a String so we * can compare it with the expected in diff --git a/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/smithy-rpcv2-output.json b/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/smithy-rpcv2-output.json index bc82842c02d5..c7523d9776f6 100644 --- a/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/smithy-rpcv2-output.json +++ b/test/protocol-tests-core/src/main/resources/software/amazon/awssdk/protocol/suites/cases/smithy-rpcv2-output.json @@ -663,5 +663,76 @@ "then": { "deserializedAs": {} } + }, + { + "description": "Parses simple RpcV2 Cbor errors.", + "given": { + "response": { + "status_code": 400, + "headers": { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + }, + "binaryBody": "v2ZfX3R5cGV4LnNtaXRoeS5wcm90b2NvbHRlc3RzLnJwY3YyQ2JvciNJbnZhbGlkR3JlZXRpbmdnTWVzc2FnZWJIaf8=" + } + }, + "when": { + "action": "errorUnmarshall", + "operation": "GreetingWithErrors", + "error": "InvalidGreeting" + }, + "then": { + "deserializedAs": { + "Message": "Hi" + } + } + }, + { + "description": "Parses a complex error with no message member", + "given": { + "response": { + "status_code": 400, + "headers": { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + }, + "binaryBody": "v2ZfX3R5cGV4K3NtaXRoeS5wcm90b2NvbHRlc3RzLnJwY3YyQ2JvciNDb21wbGV4RXJyb3JoVG9wTGV2ZWxpVG9wIGxldmVsZk5lc3RlZL9jRm9vY2Jhcv//" + } + }, + "when": { + "action": "errorUnmarshall", + "operation": "GreetingWithErrors", + "error": "ComplexError" + }, + "then": { + "deserializedAs": { + "TopLevel": "Top level", + "Nested": { + "Foo": "bar" + } + } + } + }, + { + "description": "Parses an empty complex error", + "given": { + "response": { + "status_code": 400, + "headers": { + "smithy-protocol": "rpc-v2-cbor", + "Content-Type": "application/cbor" + }, + "binaryBody": "v2ZfX3R5cGV4K3NtaXRoeS5wcm90b2NvbHRlc3RzLnJwY3YyQ2JvciNDb21wbGV4RXJyb3L/" + } + }, + "when": { + "action": "errorUnmarshall", + "operation": "GreetingWithErrors", + "error": "ComplexError" + }, + "then": { + "deserializedAs": { + } + } } ] \ No newline at end of file diff --git a/test/protocol-tests/src/main/resources/codegen-resources/sdkrpcv2/service-2.json b/test/protocol-tests/src/main/resources/codegen-resources/sdkrpcv2/service-2.json index 62357f2cf5f7..4d3c2ea4a0d7 100644 --- a/test/protocol-tests/src/main/resources/codegen-resources/sdkrpcv2/service-2.json +++ b/test/protocol-tests/src/main/resources/codegen-resources/sdkrpcv2/service-2.json @@ -95,6 +95,19 @@ "method": "POST", "requestUri": "/" } + }, + "GreetingWithErrors":{ + "name":"GreetingWithErrors", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "output":{"shape":"GreetingWithErrorsOutput"}, + "errors":[ + {"shape":"ComplexError"}, + {"shape":"InvalidGreeting"} + ], + "idempotent":true } }, "shapes": { @@ -679,6 +692,33 @@ "shape": "AllTypesUnionStructure" } } + }, + "GreetingWithErrorsOutput":{ + "type":"structure", + "members":{ + "greeting":{"shape":"String"} + } + }, + "ComplexError":{ + "type":"structure", + "members":{ + "TopLevel":{"shape":"String"}, + "Nested":{"shape":"ComplexNestedErrorData"} + }, + "exception":true + }, + "ComplexNestedErrorData":{ + "type":"structure", + "members":{ + "Foo":{"shape":"String"} + } + }, + "InvalidGreeting":{ + "type":"structure", + "members":{ + "Message":{"shape":"String"} + }, + "exception":true } } }