Skip to content

Implement CborErrorResponseUnmarshaller with proper error type mapping #3947

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using System.Collections.Generic;
using System.Formats.Cbor;
using System.IO;
using System.Net;

namespace Amazon.Extensions.CborProtocol.Internal.Transform
{
Expand All @@ -37,8 +38,75 @@ public class CborErrorResponseUnmarshaller : ICborUnmarshaller<ErrorResponse, Cb
/// <returns>An <c>ErrorResponse</c> object.</returns>
public ErrorResponse Unmarshall(CborUnmarshallerContext context)
{
// Placeholder until CBOR exception implementation.
return null;
var errorType = ErrorType.Unknown;

if (context.ResponseData.StatusCode == HttpStatusCode.BadRequest)
{
errorType = ErrorType.Sender;
}
else if (context.ResponseData.StatusCode == HttpStatusCode.InternalServerError)
{
errorType = ErrorType.Receiver;
}

var response = new ErrorResponse
{
Type = errorType,
StatusCode = context.ResponseData.StatusCode,
};

var reader = context.Reader;
reader.ReadStartMap();
while (reader.PeekState() != CborReaderState.EndMap)
{
string propertyName = reader.ReadTextString().ToLowerInvariant();
switch (propertyName)
{
case "__type":
{
context.AddPathSegment("__type");
var unmarshaller = CborStringUnmarshaller.Instance;
var type = unmarshaller.Unmarshall(context);
response.Code = SanitizeErrorType(type);
context.PopPathSegment();
break;
}
case "message":
{
context.AddPathSegment("message");
var unmarshaller = CborStringUnmarshaller.Instance;
response.Message = unmarshaller.Unmarshall(context);
context.PopPathSegment();
break;
}
default:
reader.SkipValue();
break;
}
}
Copy link
Preview

Copilot AI Aug 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CBOR map reading logic doesn't call reader.ReadEndMap() after the while loop, which could leave the reader in an inconsistent state. Consider adding reader.ReadEndMap() after the while loop.

Suggested change
}
}
reader.ReadEndMap();

Copilot uses AI. Check for mistakes.

Copy link
Member Author

@muhammad-othman muhammad-othman Aug 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I skipped that ReadEndMap() since in the next step we re-read the response anyway to generate the typed exception and reset the reader anyway and didn’t want to add an unnecessary step.


if (context.ResponseData.IsHeaderPresent(HeaderKeys.RequestIdHeader))
{
response.RequestId = context.ResponseData.GetHeaderValue(HeaderKeys.RequestIdHeader);
}

return response;
}

/// <summary>
/// Extracts the error type from a Smithy shape identifier string.
/// The input is expected to be in the format "namespace#ErrorType[:additionalInfo]".
/// Returns the error type portion (e.g., "ErrorType").
/// </summary>
private string SanitizeErrorType(string type)
{
int start = type.IndexOf('#');
start = start == -1 ? 0 : start + 1;

int end = type.IndexOf(':', start);
end = end == -1 ? type.Length : end;

Comment on lines +102 to +108
Copy link
Preview

Copilot AI Aug 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential ArgumentOutOfRangeException if 'type' parameter is null or if the calculated substring indices are invalid. The method should validate the input parameter and handle edge cases where start >= end.

Suggested change
{
int start = type.IndexOf('#');
start = start == -1 ? 0 : start + 1;
int end = type.IndexOf(':', start);
end = end == -1 ? type.Length : end;
{
if (string.IsNullOrEmpty(type))
{
return string.Empty;
}
int start = type.IndexOf('#');
start = start == -1 ? 0 : start + 1;
int end = type.IndexOf(':', start);
end = end == -1 ? type.Length : end;
// Validate indices
if (start < 0 || end > type.Length || start >= end)
{
return type.Trim();
}

Copilot uses AI. Check for mistakes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to include these checks since this means we received a malformed response from the service.

return type.Substring(start, end - start).Trim();
}

private static CborErrorResponseUnmarshaller instance;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ private void generateErrorResponseTests(OperationShape operation, OperationIndex
for (StructureShape error : index.getErrors(operation, service)) {
error.getTrait(HttpResponseTestsTrait.class).ifPresent(trait -> {
for (HttpResponseTestCase httpResponseTestCase : trait.getTestCasesFor(AppliesTo.CLIENT)) {
if(!ProtocolTestCustomizations.TestsToSkip.contains(httpResponseTestCase.getId())&&
!trait.getTestCasesFor(AppliesTo.CLIENT).getFirst().getProtocol().getName().toLowerCase().contains("cbor")) // Skip CBOR response tests until the unmarshallers are ready
if(!ProtocolTestCustomizations.TestsToSkip.contains(httpResponseTestCase.getId()))
generateErrorResponseTest(operation, error, httpResponseTestCase);
}
});
Expand Down
14 changes: 14 additions & 0 deletions generator/ServiceClientGeneratorLib/ExceptionShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,20 @@ public string Code
}
}

/// <summary>
/// Returns the original shape name of the exception specified in the json model.
/// This is used to find the exception type for CBOR as the exception response contains
/// the Shape ID rather than the error code.
/// https://smithy.io/2.0/additional-specs/protocols/smithy-rpc-v2.html#operation-error-serialization
/// </summary>
public string ShapeOriginalName
{
get
{
return base.Name;
}
}

/// <summary>
/// Determines if the exception is marked retryable
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,21 +398,47 @@ public override AmazonServiceException UnmarshallException(CborUnmarshallerConte
this.Write(" if (errorResponse.Code != null && errorResponse.Code.Equals(\"");

#line 179 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(exception.Code));
this.Write(this.ToStringHelper.ToStringWithCulture(exception.ShapeOriginalName));

#line default
#line hidden
this.Write("\"))\r\n {\r\n return ");
this.Write("\"))\r\n {\r\n");

#line 181 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"

if (exception.ShapeOriginalName != exception.Code)
{


#line default
#line hidden
this.Write(" errorResponse.Code = \"");

#line 185 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(exception.Code));

#line default
#line hidden
this.Write("\";\r\n");

#line 186 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"

}


#line default
#line hidden
this.Write(" return ");

#line 189 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(exception.Name));

#line default
#line hidden
this.Write("Unmarshaller.Instance.Unmarshall(contextCopy, errorResponse);\r\n }\r" +
"\n");

#line 183 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
#line 191 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"

}

Expand All @@ -421,7 +447,7 @@ public override AmazonServiceException UnmarshallException(CborUnmarshallerConte
#line hidden
this.Write(" }\r\n");

#line 187 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
#line 195 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"

if (this.Config.ServiceModel.IsAwsQueryCompatible)
{
Expand All @@ -432,15 +458,15 @@ public override AmazonServiceException UnmarshallException(CborUnmarshallerConte
#line hidden
this.Write(" return new ");

#line 192 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
#line 200 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(this.BaseException));

#line default
#line hidden
this.Write("(errorResponse.Message, errorResponse.InnerException, errorType, errorCode, error" +
"Response.RequestId, errorResponse.StatusCode);\r\n");

#line 193 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
#line 201 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"

}
else
Expand All @@ -451,15 +477,15 @@ public override AmazonServiceException UnmarshallException(CborUnmarshallerConte
#line hidden
this.Write(" return new ");

#line 198 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
#line 206 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
this.Write(this.ToStringHelper.ToStringWithCulture(this.BaseException));

#line default
#line hidden
this.Write("(errorResponse.Message, errorResponse.InnerException, errorResponse.Type, errorRe" +
"sponse.Code, errorResponse.RequestId, errorResponse.StatusCode);\r\n");

#line 199 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
#line 207 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"

}

Expand All @@ -468,7 +494,7 @@ public override AmazonServiceException UnmarshallException(CborUnmarshallerConte
#line hidden
this.Write(" }\r\n\r\n");

#line 204 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
#line 212 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"

if (payload != null && payload.Shape.IsStreaming)
{
Expand All @@ -489,7 +515,7 @@ public override bool HasStreamingProperty

");

#line 219 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
#line 227 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"

}
this.AddResponseSingletonMethod();
Expand All @@ -498,7 +524,7 @@ public override bool HasStreamingProperty
#line default
#line hidden

#line 223 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
#line 231 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"

if(isEventStreamOutput)
{
Expand All @@ -522,7 +548,7 @@ protected override bool ShouldReadEntireResponse(IWebResponseData response, bool
public override bool HasStreamingProperty => true;
");

#line 241 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"
#line 249 "C:\repos\aws-sdk-net-v4\generator\ServiceClientGeneratorLib\Generators\Marshallers\CborResponseUnmarshaller.tt"

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,16 @@ namespace <#=this.Config.Namespace #>.Model.Internal.MarshallTransformations
foreach (var exception in this.Operation.Exceptions)
{
#>
if (errorResponse.Code != null && errorResponse.Code.Equals("<#=exception.Code #>"))
if (errorResponse.Code != null && errorResponse.Code.Equals("<#=exception.ShapeOriginalName #>"))
{
<#
if (exception.ShapeOriginalName != exception.Code)
{
#>
errorResponse.Code = "<#=exception.Code#>";
<#
}
#>
return <#=exception.Name#>Unmarshaller.Instance.Unmarshall(contextCopy, errorResponse);
}
<#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,82 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Text;

namespace AWSSDK.ProtocolTests.RpcV2Protocol
{
[TestClass]
public class GreetingWithErrors
{
/// <summary>
/// Parses simple RpcV2 Cbor errors
/// </summary>
[TestMethod]
[TestCategory("ProtocolTest")]
[TestCategory("ErrorTest")]
[TestCategory("RpcV2Protocol")]
public void RpcV2CborInvalidGreetingErrorErrorResponse()
{
// Arrange
var webResponseData = new WebResponseData();
webResponseData.StatusCode = (HttpStatusCode)Enum.ToObject(typeof(HttpStatusCode), 400);
webResponseData.Headers["Content-Type"] = "application/cbor";
webResponseData.Headers["smithy-protocol"] = "rpc-v2-cbor";
byte[] bytes = Convert.FromBase64String("v2ZfX3R5cGV4LnNtaXRoeS5wcm90b2NvbHRlc3RzLnJwY3YyQ2JvciNJbnZhbGlkR3JlZXRpbmdnTWVzc2FnZWJIaf8=");
var stream = new MemoryStream(bytes);
var context = new CborUnmarshallerContext(stream,true,webResponseData);
// Act
var errorResponse = new GreetingWithErrorsResponseUnmarshaller().UnmarshallException(context, null, (HttpStatusCode)Enum.ToObject(typeof(HttpStatusCode), 400));
// Assert
Assert.IsInstanceOfType(errorResponse, typeof(InvalidGreetingException));
Assert.AreEqual(errorResponse.StatusCode,(HttpStatusCode)Enum.ToObject(typeof(HttpStatusCode), 400));
}

/// <summary>
/// Parses a complex error with no message member
/// </summary>
[TestMethod]
[TestCategory("ProtocolTest")]
[TestCategory("ErrorTest")]
[TestCategory("RpcV2Protocol")]
public void RpcV2CborComplexErrorErrorResponse()
{
// Arrange
var webResponseData = new WebResponseData();
webResponseData.StatusCode = (HttpStatusCode)Enum.ToObject(typeof(HttpStatusCode), 400);
webResponseData.Headers["Content-Type"] = "application/cbor";
webResponseData.Headers["smithy-protocol"] = "rpc-v2-cbor";
byte[] bytes = Convert.FromBase64String("v2ZfX3R5cGV4K3NtaXRoeS5wcm90b2NvbHRlc3RzLnJwY3YyQ2JvciNDb21wbGV4RXJyb3JoVG9wTGV2ZWxpVG9wIGxldmVsZk5lc3RlZL9jRm9vY2Jhcv//");
var stream = new MemoryStream(bytes);
var context = new CborUnmarshallerContext(stream,true,webResponseData);
// Act
var errorResponse = new GreetingWithErrorsResponseUnmarshaller().UnmarshallException(context, null, (HttpStatusCode)Enum.ToObject(typeof(HttpStatusCode), 400));
// Assert
Assert.IsInstanceOfType(errorResponse, typeof(ComplexErrorException));
Assert.AreEqual(errorResponse.StatusCode,(HttpStatusCode)Enum.ToObject(typeof(HttpStatusCode), 400));
}

[TestMethod]
[TestCategory("ProtocolTest")]
[TestCategory("ErrorTest")]
[TestCategory("RpcV2Protocol")]
public void RpcV2CborEmptyComplexErrorErrorResponse()
{
// Arrange
var webResponseData = new WebResponseData();
webResponseData.StatusCode = (HttpStatusCode)Enum.ToObject(typeof(HttpStatusCode), 400);
webResponseData.Headers["Content-Type"] = "application/cbor";
webResponseData.Headers["smithy-protocol"] = "rpc-v2-cbor";
byte[] bytes = Convert.FromBase64String("v2ZfX3R5cGV4K3NtaXRoeS5wcm90b2NvbHRlc3RzLnJwY3YyQ2JvciNDb21wbGV4RXJyb3L/");
var stream = new MemoryStream(bytes);
var context = new CborUnmarshallerContext(stream,true,webResponseData);
// Act
var errorResponse = new GreetingWithErrorsResponseUnmarshaller().UnmarshallException(context, null, (HttpStatusCode)Enum.ToObject(typeof(HttpStatusCode), 400));
// Assert
Assert.IsInstanceOfType(errorResponse, typeof(ComplexErrorException));
Assert.AreEqual(errorResponse.StatusCode,(HttpStatusCode)Enum.ToObject(typeof(HttpStatusCode), 400));
}

}
}