Skip to content

Add support for protocol test error cases + add cbor error cases #6207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.protocol.asserts.unmarshalling;


import static org.junit.Assert.fail;
import static org.unitils.reflectionassert.ReflectionAssert.assertReflectionEquals;

import com.fasterxml.jackson.databind.JsonNode;
import java.lang.reflect.Field;
import org.junit.Assert;
import software.amazon.awssdk.core.exception.SdkServiceException;
import software.amazon.awssdk.protocol.reflect.ShapeModelReflector;

public class UnmarshalledErrorAssertion extends UnmarshallingAssertion {
private final JsonNode expectedError;

public UnmarshalledErrorAssertion(JsonNode expectedError) {
this.expectedError = expectedError;
}

@Override
protected void doAssert(UnmarshallingTestContext context, Object actual) throws Exception {
if (!(actual instanceof SdkServiceException)) {
fail("Expected unmarshalled object to be an instance of SdkServiceException");
}
SdkServiceException actualException = (SdkServiceException) actual;
SdkServiceException expectedException = createExpectedResult(context);
for (Field field : expectedException.getClass().getDeclaredFields()) {
assertFieldEquals(field, actualException, expectedException);
}

if (expectedException.getMessage() != null) {
Assert.assertTrue(actualException.getMessage().startsWith(expectedException.getMessage()));
}
}

private SdkServiceException createExpectedResult(UnmarshallingTestContext context) {
return (SdkServiceException) new ShapeModelReflector(context.getModel(), context.getErrorName() + "Exception",
this.expectedError).createShapeObject();
}

private void assertFieldEquals(Field field, Object actual, Object expectedResult) throws
Exception {
field.setAccessible(true);
assertReflectionEquals(field.get(expectedResult), field.get(actual));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class UnmarshallingTestContext {
private IntermediateModel model;
private String operationName;
private String streamedResponse;
private String errorName;

public UnmarshallingTestContext withModel(IntermediateModel model) {
this.model = model;
Expand Down Expand Up @@ -58,4 +59,12 @@ public String getStreamedResponse() {
return streamedResponse;
}

public UnmarshallingTestContext withErrorName(String errorName) {
this.errorName = errorName;
return this;
}

public String getErrorName() {
return errorName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,22 @@
import com.fasterxml.jackson.databind.JsonNode;
import software.amazon.awssdk.protocol.asserts.marshalling.MarshallingAssertion;
import software.amazon.awssdk.protocol.asserts.marshalling.SerializedAs;
import software.amazon.awssdk.protocol.asserts.unmarshalling.UnmarshalledErrorAssertion;
import software.amazon.awssdk.protocol.asserts.unmarshalling.UnmarshalledResultAssertion;
import software.amazon.awssdk.protocol.asserts.unmarshalling.UnmarshallingAssertion;

public class Then {

private final MarshallingAssertion serializedAs;
private final UnmarshallingAssertion deserializedAs;
private final UnmarshallingAssertion errorDeserializedAs;

@JsonCreator
public Then(@JsonProperty("serializedAs") SerializedAs serializedAs,
@JsonProperty("deserializedAs") JsonNode deserializedAs) {
this.serializedAs = serializedAs;
this.deserializedAs = new UnmarshalledResultAssertion(deserializedAs);
this.errorDeserializedAs = new UnmarshalledErrorAssertion(deserializedAs);
}

/**
Expand All @@ -49,4 +52,11 @@ public UnmarshallingAssertion getUnmarshallingAssertion() {
return deserializedAs;
}

/**
*
* @return The assertion object to use for error unmarshalling tests
*/
public UnmarshallingAssertion getErrorUnmarshallingAssertion() {
return errorDeserializedAs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ public class When {
@JsonProperty(value = "operation")
private String operationName;

@JsonProperty(value = "error")
private String errorName;

public WhenAction getAction() {
return action;
}
Expand All @@ -41,4 +44,12 @@ public String getOperationName() {
public void setOperationName(String operationName) {
this.operationName = operationName;
}

public void setErrorName(String errorName) {
this.errorName = errorName;
}

public String getErrorName() {
return errorName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

public enum WhenAction {
MARSHALL("marshall"),
UNMARSHALL("unmarshall");
UNMARSHALL("unmarshall"),
ERROR_UNMARSHALL("errorUnmarshall");

private final String action;

Expand All @@ -31,6 +32,8 @@ public static WhenAction fromValue(String action) {
return MARSHALL;
case "unmarshall":
return UNMARSHALL;
case "errorUnmarshall":
return ERROR_UNMARSHALL;
default:
throw new IllegalArgumentException("Unsupported test action " + action);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ private void initializeFields(ShapeModel structureShape, JsonNode input,
Iterator<String> fieldNames = input.fieldNames();
while (fieldNames.hasNext()) {
String memberName = fieldNames.next();
// error structures have special case handling of "message"
if (structureShape.getErrorCode() != null && memberName.equalsIgnoreCase("message")) {
Method setter = shapeObject.getClass().getMethod("message", String.class);
setter.setAccessible(true);
setter.invoke(shapeObject, input.get(memberName).asText());
continue;
}

MemberModel memberModel = structureShape.getMemberByC2jName(memberName);
if (memberModel == null) {
throw new IllegalArgumentException("Member " + memberName + " was not found in the " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public void runTest(TestCase testCase) throws Exception {
marshallingTestRunner.runTest(testCase);
break;
case UNMARSHALL:
case ERROR_UNMARSHALL:
unmarshallingTestRunner.runTest(testCase);
break;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.github.tomakehurst.wiremock.client.ResponseDefinitionBuilder;
import com.github.tomakehurst.wiremock.client.WireMock;
import java.lang.reflect.InvocationTargetException;
import java.util.Base64;
import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel;
import software.amazon.awssdk.codegen.model.intermediate.Metadata;
Expand Down Expand Up @@ -52,6 +53,21 @@ class UnmarshallingTestRunner {

void runTest(TestCase testCase) throws Exception {
resetWireMock(testCase.getGiven().getResponse());

switch (testCase.getWhen().getAction()) {
case UNMARSHALL:
runUnmarshallTest(testCase);
break;
case ERROR_UNMARSHALL:
runErrorUnmarshallTest(testCase);
break;
default:
throw new IllegalArgumentException("UnmarshallingTestRunner unable to run test case for action "
+ testCase.getWhen().getAction());
}
}

private void runUnmarshallTest(TestCase testCase) throws Exception {
String operationName = testCase.getWhen().getOperationName();
ShapeModelReflector shapeModelReflector = createShapeModelReflector(testCase);
if (!hasStreamingMember(operationName)) {
Expand All @@ -60,12 +76,32 @@ void runTest(TestCase testCase) throws Exception {
} else {
CapturingResponseTransformer responseHandler = new CapturingResponseTransformer();
Object actualResult = clientReflector
.invokeStreamingMethod(testCase, shapeModelReflector.createShapeObject(), responseHandler);
.invokeStreamingMethod(testCase, shapeModelReflector.createShapeObject(), responseHandler);
testCase.getThen().getUnmarshallingAssertion()
.assertMatches(createContext(operationName, responseHandler.captured), actualResult);
}
}

private void runErrorUnmarshallTest(TestCase testCase) throws Exception {
String operationName = testCase.getWhen().getOperationName();
ShapeModelReflector shapeModelReflector = createShapeModelReflector(testCase);
try {
clientReflector.invokeMethod(testCase, shapeModelReflector.createShapeObject());
throw new IllegalStateException("Test case expected client to throw error");
} catch (InvocationTargetException t) {
String errorName = testCase.getWhen().getErrorName();
testCase.getThen().getErrorUnmarshallingAssertion().assertMatches(
createErrorContext(operationName, errorName), t.getCause());
}
}

private UnmarshallingTestContext createErrorContext(String operationName, String errorName) {
return new UnmarshallingTestContext()
.withModel(model)
.withOperationName(operationName)
.withErrorName(errorName);
}

/**
* {@link ResponseTransformer} that simply captures all the content as a String so we
* can compare it with the expected in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -663,5 +663,76 @@
"then": {
"deserializedAs": {}
}
},
{
"description": "Parses simple RpcV2 Cbor errors.",
"given": {
"response": {
"status_code": 400,
"headers": {
"smithy-protocol": "rpc-v2-cbor",
"Content-Type": "application/cbor"
},
"binaryBody": "v2ZfX3R5cGV4LnNtaXRoeS5wcm90b2NvbHRlc3RzLnJwY3YyQ2JvciNJbnZhbGlkR3JlZXRpbmdnTWVzc2FnZWJIaf8="
}
},
"when": {
"action": "errorUnmarshall",
"operation": "GreetingWithErrors",
"error": "InvalidGreeting"
},
"then": {
"deserializedAs": {
"Message": "Hi"
}
}
},
{
"description": "Parses a complex error with no message member",
"given": {
"response": {
"status_code": 400,
"headers": {
"smithy-protocol": "rpc-v2-cbor",
"Content-Type": "application/cbor"
},
"binaryBody": "v2ZfX3R5cGV4K3NtaXRoeS5wcm90b2NvbHRlc3RzLnJwY3YyQ2JvciNDb21wbGV4RXJyb3JoVG9wTGV2ZWxpVG9wIGxldmVsZk5lc3RlZL9jRm9vY2Jhcv//"
}
},
"when": {
"action": "errorUnmarshall",
"operation": "GreetingWithErrors",
"error": "ComplexError"
},
"then": {
"deserializedAs": {
"TopLevel": "Top level",
"Nested": {
"Foo": "bar"
}
}
}
},
{
"description": "Parses an empty complex error",
"given": {
"response": {
"status_code": 400,
"headers": {
"smithy-protocol": "rpc-v2-cbor",
"Content-Type": "application/cbor"
},
"binaryBody": "v2ZfX3R5cGV4K3NtaXRoeS5wcm90b2NvbHRlc3RzLnJwY3YyQ2JvciNDb21wbGV4RXJyb3L/"
}
},
"when": {
"action": "errorUnmarshall",
"operation": "GreetingWithErrors",
"error": "ComplexError"
},
"then": {
"deserializedAs": {
}
}
}
]
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,19 @@
"method": "POST",
"requestUri": "/"
}
},
"GreetingWithErrors":{
"name":"GreetingWithErrors",
"http":{
"method":"POST",
"requestUri":"/"
},
"output":{"shape":"GreetingWithErrorsOutput"},
"errors":[
{"shape":"ComplexError"},
{"shape":"InvalidGreeting"}
],
"idempotent":true
}
},
"shapes": {
Expand Down Expand Up @@ -679,6 +692,33 @@
"shape": "AllTypesUnionStructure"
}
}
},
"GreetingWithErrorsOutput":{
"type":"structure",
"members":{
"greeting":{"shape":"String"}
}
},
"ComplexError":{
"type":"structure",
"members":{
"TopLevel":{"shape":"String"},
"Nested":{"shape":"ComplexNestedErrorData"}
},
"exception":true
},
"ComplexNestedErrorData":{
"type":"structure",
"members":{
"Foo":{"shape":"String"}
}
},
"InvalidGreeting":{
"type":"structure",
"members":{
"Message":{"shape":"String"}
},
"exception":true
}
}
}