diff --git a/codegen/smithy-aws-typescript-codegen/src/main/java/software/amazon/smithy/aws/typescript/codegen/AddProtocolConfig.java b/codegen/smithy-aws-typescript-codegen/src/main/java/software/amazon/smithy/aws/typescript/codegen/AddProtocolConfig.java index e82e5e30b8837..a8b9bba8082c6 100644 --- a/codegen/smithy-aws-typescript-codegen/src/main/java/software/amazon/smithy/aws/typescript/codegen/AddProtocolConfig.java +++ b/codegen/smithy-aws-typescript-codegen/src/main/java/software/amazon/smithy/aws/typescript/codegen/AddProtocolConfig.java @@ -6,11 +6,13 @@ package software.amazon.smithy.aws.typescript.codegen; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.function.Consumer; import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait; import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait; +import software.amazon.smithy.aws.traits.protocols.AwsQueryCompatibleTrait; import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait; import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait; import software.amazon.smithy.aws.traits.protocols.RestJson1Trait; @@ -18,6 +20,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.traits.XmlNamespaceTrait; +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait; import software.amazon.smithy.typescript.codegen.LanguageTarget; import software.amazon.smithy.typescript.codegen.TypeScriptSettings; import software.amazon.smithy.typescript.codegen.TypeScriptWriter; @@ -60,6 +63,13 @@ public void addConfigInterfaceFields( // by the smithy client config interface. } + @Override + public List runAfter() { + return List.of( + software.amazon.smithy.typescript.codegen.integration.AddProtocolConfig.class.getCanonicalName() + ); + } + @Override public Map> getRuntimeConfigWriters( TypeScriptSettings settings, @@ -76,6 +86,7 @@ public Map> getRuntimeConfigWriters( .getTrait(XmlNamespaceTrait.class) .map(XmlNamespaceTrait::getUri) .orElse(""); + String awsQueryCompat = settings.getService(model).hasTrait(AwsQueryCompatibleTrait.class) ? "true" : "false"; switch (target) { case SHARED: @@ -148,9 +159,15 @@ public Map> getRuntimeConfigWriters( "AwsJson1_0Protocol", null, AwsDependency.AWS_SDK_CORE, "/protocols"); writer.write( - "new AwsJson1_0Protocol({ defaultNamespace: $S, serviceTarget: $S })", + """ + new AwsJson1_0Protocol({ + defaultNamespace: $S, + serviceTarget: $S, + awsQueryCompatible: $L + })""", namespace, - rpcTarget + rpcTarget, + awsQueryCompat ); } ); @@ -161,9 +178,32 @@ public Map> getRuntimeConfigWriters( "AwsJson1_1Protocol", null, AwsDependency.AWS_SDK_CORE, "/protocols"); writer.write( - "new AwsJson1_1Protocol({ defaultNamespace: $S, serviceTarget: $S })", + """ + new AwsJson1_1Protocol({ + defaultNamespace: $S, + serviceTarget: $S, + awsQueryCompatible: $L + })""", + namespace, + rpcTarget, + awsQueryCompat + ); + } + ); + } else if (Objects.equals(settings.getProtocol(), Rpcv2CborTrait.ID)) { + return MapUtils.of( + "protocol", writer -> { + writer.addImportSubmodule( + "AwsSmithyRpcV2CborProtocol", null, + AwsDependency.AWS_SDK_CORE, "/protocols"); + writer.write( + """ + new AwsSmithyRpcV2CborProtocol({ + defaultNamespace: $S, + awsQueryCompatible: $L + })""", namespace, - rpcTarget + awsQueryCompat ); } ); diff --git a/packages/core/src/submodules/protocols/ProtocolLib.ts b/packages/core/src/submodules/protocols/ProtocolLib.ts new file mode 100644 index 0000000000000..d209622770f64 --- /dev/null +++ b/packages/core/src/submodules/protocols/ProtocolLib.ts @@ -0,0 +1,163 @@ +import { ErrorSchema, NormalizedSchema, TypeRegistry } from "@smithy/core/schema"; +import type { + BodyLengthCalculator, + HttpResponse as IHttpResponse, + MetadataBearer, + ResponseMetadata, + SerdeFunctions, +} from "@smithy/types"; +import { calculateBodyLength } from "@smithy/util-body-length-browser"; + +/** + * @internal + */ +type ErrorMetadataBearer = MetadataBearer & { + $response: IHttpResponse; + $fault: "client" | "server"; +}; + +/** + * Shared code for Protocols. + * + * @internal + */ +export class ProtocolLib { + /** + * @param body - to be inspected. + * @param serdeContext - this is a subset type but in practice is the client.config having a property called bodyLengthChecker. + * + * @returns content-length value for the body if possible. + * @throws Error and should be caught and handled if not possible to determine length. + */ + public calculateContentLength(body: any, serdeContext?: SerdeFunctions) { + const bodyLengthCalculator: BodyLengthCalculator = + ( + serdeContext as SerdeFunctions & { + bodyLengthChecker?: BodyLengthCalculator; + } + )?.bodyLengthChecker ?? calculateBodyLength; + return String(bodyLengthCalculator(body)); + } + + /** + * This is only for REST protocols. + * + * @param defaultContentType - of the protocol. + * @param inputSchema - schema for which to determine content type. + * + * @returns content-type header value or undefined when not applicable. + */ + public resolveRestContentType(defaultContentType: string, inputSchema: NormalizedSchema): string | undefined { + const members = inputSchema.getMemberSchemas(); + const httpPayloadMember = Object.values(members).find((m) => { + return !!m.getMergedTraits().httpPayload; + }); + + if (httpPayloadMember) { + const mediaType = httpPayloadMember.getMergedTraits().mediaType as string; + if (mediaType) { + return mediaType; + } else if (httpPayloadMember.isStringSchema()) { + return "text/plain"; + } else if (httpPayloadMember.isBlobSchema()) { + return "application/octet-stream"; + } else { + return defaultContentType; + } + } else if (!inputSchema.isUnitSchema()) { + const hasBody = Object.values(members).find((m) => { + const { httpQuery, httpQueryParams, httpHeader, httpLabel, httpPrefixHeaders } = m.getMergedTraits(); + return !httpQuery && !httpQueryParams && !httpHeader && !httpLabel && httpPrefixHeaders === void 0; + }); + if (hasBody) { + return defaultContentType; + } + } + } + + /** + * Shared code for finding error schema or throwing an unmodeled base error. + * @returns error schema and error metadata. + * + * @throws ServiceBaseException or generic Error if no error schema could be found. + */ + public async getErrorSchemaOrThrowBaseException( + errorIdentifier: string, + defaultNamespace: string, + response: IHttpResponse, + dataObject: any, + metadata: ResponseMetadata, + getErrorSchema?: (registry: TypeRegistry, errorName: string) => ErrorSchema + ): Promise<{ errorSchema: ErrorSchema; errorMetadata: ErrorMetadataBearer }> { + let namespace = defaultNamespace; + let errorName = errorIdentifier; + if (errorIdentifier.includes("#")) { + [namespace, errorName] = errorIdentifier.split("#"); + } + + const errorMetadata: ErrorMetadataBearer = { + $metadata: metadata, + $response: response, + $fault: response.statusCode < 500 ? ("client" as const) : ("server" as const), + }; + + const registry = TypeRegistry.for(namespace); + + try { + const errorSchema = getErrorSchema?.(registry, errorName) ?? (registry.getSchema(errorIdentifier) as ErrorSchema); + return { errorSchema, errorMetadata }; + } catch (e) { + if (dataObject.Message) { + dataObject.message = dataObject.Message; + } + const baseExceptionSchema = TypeRegistry.for("smithy.ts.sdk.synthetic." + namespace).getBaseException(); + if (baseExceptionSchema) { + const ErrorCtor = baseExceptionSchema.ctor; + throw Object.assign(new ErrorCtor({ name: errorName }), errorMetadata, dataObject); + } + throw Object.assign(new Error(errorName), errorMetadata, dataObject); + } + } + + /** + * Reads the x-amzn-query-error header for awsQuery compatibility. + * + * @param output - values that will be assigned to an error object. + * @param response - from which to read awsQueryError headers. + */ + public setQueryCompatError(output: Record, response: IHttpResponse) { + const queryErrorHeader = response.headers?.["x-amzn-query-error"]; + + if (output !== undefined && queryErrorHeader != null) { + const [Code, Type] = queryErrorHeader.split(";"); + const entries = Object.entries(output); + const Error = { + Code, + Type, + } as any; + Object.assign(output, Error); + for (const [k, v] of entries) { + Error[k] = v; + } + delete Error.__type; + output.Error = Error; + } + } + + /** + * Assigns Error, Type, Code from the awsQuery error object to the output error object. + * @param queryCompatErrorData - query compat error object. + * @param errorData - canonical error object returned to the caller. + */ + public queryCompatOutput(queryCompatErrorData: any, errorData: any) { + if (queryCompatErrorData.Error) { + errorData.Error = queryCompatErrorData.Error; + } + if (queryCompatErrorData.Type) { + errorData.Type = queryCompatErrorData.Type; + } + if (queryCompatErrorData.Code) { + errorData.Code = queryCompatErrorData.Code; + } + } +} diff --git a/packages/core/src/submodules/protocols/cbor/AwsSmithyRpcV2CborProtocol.spec.ts b/packages/core/src/submodules/protocols/cbor/AwsSmithyRpcV2CborProtocol.spec.ts new file mode 100644 index 0000000000000..646806059cbdb --- /dev/null +++ b/packages/core/src/submodules/protocols/cbor/AwsSmithyRpcV2CborProtocol.spec.ts @@ -0,0 +1,88 @@ +import { cbor } from "@smithy/core/cbor"; +import { op, SCHEMA } from "@smithy/core/schema"; +import { error as registerError } from "@smithy/core/schema"; +import { HttpResponse } from "@smithy/protocol-http"; +import { describe, expect, test as it } from "vitest"; + +import { AwsSmithyRpcV2CborProtocol } from "./AwsSmithyRpcV2CborProtocol"; + +describe(AwsSmithyRpcV2CborProtocol.name, () => { + it("should support awsQueryCompatible", async () => { + const protocol = new AwsSmithyRpcV2CborProtocol({ + defaultNamespace: "ns", + awsQueryCompatible: true, + }); + + class MyQueryError extends Error {} + + registerError( + "ns", + "MyQueryError", + { error: "client" }, + ["Message", "Prop2"], + [SCHEMA.STRING, SCHEMA.NUMERIC], + MyQueryError + ); + + const body = cbor.serialize({ + Message: "oh no", + Prop2: 9999, + }); + + const error = await (async () => { + return protocol.deserializeResponse( + op("ns", "Operation", 0, "unit", "unit"), + {} as any, + new HttpResponse({ + statusCode: 400, + headers: { + "x-amzn-query-error": "MyQueryError;Client", + }, + body, + }) + ); + })().catch((e: any) => e); + + expect(error.$metadata).toEqual({ + cfId: undefined, + extendedRequestId: undefined, + httpStatusCode: 400, + requestId: undefined, + }); + + expect(error.$response).toEqual( + new HttpResponse({ + body, + headers: { + "x-amzn-query-error": "MyQueryError;Client", + }, + reason: undefined, + statusCode: 400, + }) + ); + + expect(error.Code).toEqual(MyQueryError.name); + expect(error.Error.Code).toEqual(MyQueryError.name); + + expect(error.Message).toEqual("oh no"); + expect(error.Prop2).toEqual(9999); + + expect(error.Error.Message).toEqual("oh no"); + expect(error.Error.Prop2).toEqual(9999); + + expect(error).toMatchObject({ + $fault: "client", + Message: "oh no", + message: "oh no", + Prop2: 9999, + Error: { + Code: "MyQueryError", + Message: "oh no", + Type: "Client", + Prop2: 9999, + }, + Type: "Client", + Code: "MyQueryError", + }); + }); +}); diff --git a/packages/core/src/submodules/protocols/cbor/AwsSmithyRpcV2CborProtocol.ts b/packages/core/src/submodules/protocols/cbor/AwsSmithyRpcV2CborProtocol.ts new file mode 100644 index 0000000000000..fec7f4e54d4f0 --- /dev/null +++ b/packages/core/src/submodules/protocols/cbor/AwsSmithyRpcV2CborProtocol.ts @@ -0,0 +1,96 @@ +import { loadSmithyRpcV2CborErrorCode, SmithyRpcV2CborProtocol } from "@smithy/core/cbor"; +import { NormalizedSchema } from "@smithy/core/schema"; +import type { + EndpointBearer, + HandlerExecutionContext, + HttpRequest, + HttpResponse, + OperationSchema, + ResponseMetadata, + SerdeFunctions, +} from "@smithy/types"; + +import { ProtocolLib } from "../ProtocolLib"; + +/** + * Extends the Smithy implementation to add AwsQueryCompatibility support. + * + * @alpha + */ +export class AwsSmithyRpcV2CborProtocol extends SmithyRpcV2CborProtocol { + private readonly awsQueryCompatible: boolean; + private readonly mixin = new ProtocolLib(); + + public constructor({ + defaultNamespace, + awsQueryCompatible, + }: { + defaultNamespace: string; + awsQueryCompatible?: boolean; + }) { + super({ defaultNamespace }); + this.awsQueryCompatible = !!awsQueryCompatible; + } + + /** + * @override + */ + public async serializeRequest( + operationSchema: OperationSchema, + input: Input, + context: HandlerExecutionContext & SerdeFunctions & EndpointBearer + ): Promise { + const request = await super.serializeRequest(operationSchema, input, context); + if (this.awsQueryCompatible) { + request.headers["x-amzn-query-mode"] = "true"; + } + return request; + } + + /** + * @override + */ + protected async handleError( + operationSchema: OperationSchema, + context: HandlerExecutionContext & SerdeFunctions, + response: HttpResponse, + dataObject: any, + metadata: ResponseMetadata + ): Promise { + if (this.awsQueryCompatible) { + this.mixin.setQueryCompatError(dataObject, response); + } + const errorName = loadSmithyRpcV2CborErrorCode(response, dataObject) ?? "Unknown"; + + const { errorSchema, errorMetadata } = await this.mixin.getErrorSchemaOrThrowBaseException( + errorName, + this.options.defaultNamespace, + response, + dataObject, + metadata + ); + + const ns = NormalizedSchema.of(errorSchema); + const message = dataObject.message ?? dataObject.Message ?? "Unknown"; + const exception = new errorSchema.ctor(message); + + const output = {} as any; + for (const [name, member] of ns.structIterator()) { + output[name] = this.deserializer.readValue(member, dataObject[name]); + } + + if (this.awsQueryCompatible) { + this.mixin.queryCompatOutput(dataObject, output); + } + + throw Object.assign( + exception, + errorMetadata, + { + $fault: ns.getMergedTraits().error, + message, + }, + output + ); + } +} diff --git a/packages/core/src/submodules/protocols/common.ts b/packages/core/src/submodules/protocols/common.ts index d8c9785258618..1ffa3bed22082 100644 --- a/packages/core/src/submodules/protocols/common.ts +++ b/packages/core/src/submodules/protocols/common.ts @@ -1,5 +1,6 @@ import { collectBody } from "@smithy/smithy-client"; import type { SerdeFunctions } from "@smithy/types"; +import { toUtf8 } from "@smithy/util-utf8"; export const collectBodyString = (streamBody: any, context: SerdeFunctions): Promise => - collectBody(streamBody, context).then((body) => context.utf8Encoder(body)); + collectBody(streamBody, context).then((body) => (context?.utf8Encoder ?? toUtf8)(body)); diff --git a/packages/core/src/submodules/protocols/index.ts b/packages/core/src/submodules/protocols/index.ts index a93942b0399e2..46678e88c13dc 100644 --- a/packages/core/src/submodules/protocols/index.ts +++ b/packages/core/src/submodules/protocols/index.ts @@ -1,3 +1,4 @@ +export * from "./cbor/AwsSmithyRpcV2CborProtocol"; export * from "./coercing-serializers"; export * from "./json/AwsJson1_0Protocol"; export * from "./json/AwsJson1_1Protocol"; diff --git a/packages/core/src/submodules/protocols/json/AwsJson1_0Protocol.ts b/packages/core/src/submodules/protocols/json/AwsJson1_0Protocol.ts index 8945f534443ae..ab29a88b907e3 100644 --- a/packages/core/src/submodules/protocols/json/AwsJson1_0Protocol.ts +++ b/packages/core/src/submodules/protocols/json/AwsJson1_0Protocol.ts @@ -5,10 +5,19 @@ import { AwsJsonRpcProtocol } from "./AwsJsonRpcProtocol"; * @see https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html#differences-between-awsjson1-0-and-awsjson1-1 */ export class AwsJson1_0Protocol extends AwsJsonRpcProtocol { - public constructor({ defaultNamespace, serviceTarget }: { defaultNamespace: string; serviceTarget: string }) { + public constructor({ + defaultNamespace, + serviceTarget, + awsQueryCompatible, + }: { + defaultNamespace: string; + serviceTarget: string; + awsQueryCompatible?: boolean; + }) { super({ defaultNamespace, serviceTarget, + awsQueryCompatible, }); } diff --git a/packages/core/src/submodules/protocols/json/AwsJson1_1Protocol.ts b/packages/core/src/submodules/protocols/json/AwsJson1_1Protocol.ts index 4b4745ee2046f..77e2064270424 100644 --- a/packages/core/src/submodules/protocols/json/AwsJson1_1Protocol.ts +++ b/packages/core/src/submodules/protocols/json/AwsJson1_1Protocol.ts @@ -5,10 +5,19 @@ import { AwsJsonRpcProtocol } from "./AwsJsonRpcProtocol"; * @see https://smithy.io/2.0/aws/protocols/aws-json-1_1-protocol.html#differences-between-awsjson1-0-and-awsjson1-1 */ export class AwsJson1_1Protocol extends AwsJsonRpcProtocol { - public constructor({ defaultNamespace, serviceTarget }: { defaultNamespace: string; serviceTarget: string }) { + public constructor({ + defaultNamespace, + serviceTarget, + awsQueryCompatible, + }: { + defaultNamespace: string; + serviceTarget: string; + awsQueryCompatible?: boolean; + }) { super({ defaultNamespace, serviceTarget, + awsQueryCompatible, }); } diff --git a/packages/core/src/submodules/protocols/json/AwsJsonRpcProtocol.spec.ts b/packages/core/src/submodules/protocols/json/AwsJsonRpcProtocol.spec.ts index d477d5eb71e04..3be2a02d54720 100644 --- a/packages/core/src/submodules/protocols/json/AwsJsonRpcProtocol.spec.ts +++ b/packages/core/src/submodules/protocols/json/AwsJsonRpcProtocol.spec.ts @@ -1,24 +1,26 @@ -import { SCHEMA } from "@smithy/core/schema"; +import { error as registerError, op, SCHEMA } from "@smithy/core/schema"; +import { HttpResponse } from "@smithy/protocol-http"; +import { fromUtf8 } from "@smithy/util-utf8"; import { describe, expect, test as it } from "vitest"; import { AwsJsonRpcProtocol } from "./AwsJsonRpcProtocol"; describe(AwsJsonRpcProtocol.name, () => { - it("has expected codec settings", async () => { - const protocol = new (class extends AwsJsonRpcProtocol { - constructor() { - super({ defaultNamespace: "", serviceTarget: "" }); - } + const protocol = new (class extends AwsJsonRpcProtocol { + constructor() { + super({ defaultNamespace: "ns", serviceTarget: "", awsQueryCompatible: true }); + } - getShapeId(): string { - throw new Error("Method not implemented."); - } + getShapeId(): string { + throw new Error("Method not implemented."); + } - protected getJsonRpcVersion(): "1.1" | "1.0" { - throw new Error("Method not implemented."); - } - })(); + protected getJsonRpcVersion(): "1.1" | "1.0" { + throw new Error("Method not implemented."); + } + })(); + it("has expected codec settings", async () => { const codec = protocol.getPayloadCodec(); expect(codec.settings).toEqual({ jsonName: false, @@ -28,4 +30,80 @@ describe(AwsJsonRpcProtocol.name, () => { }, }); }); + + it("should support awsQueryCompatible", async () => { + class MyQueryError extends Error {} + + registerError( + "ns", + "MyQueryError", + { error: "client" }, + ["Message", "Prop2"], + [SCHEMA.STRING, SCHEMA.NUMERIC], + MyQueryError + ); + + const body = fromUtf8( + JSON.stringify({ + Message: "oh no", + Prop2: 9999, + }) + ); + + const error = await (async () => { + return protocol.deserializeResponse( + op("ns", "Operation", 0, "unit", "unit"), + {} as any, + new HttpResponse({ + statusCode: 400, + headers: { + "x-amzn-query-error": "MyQueryError;Client", + }, + body, + }) + ); + })().catch((e: any) => e); + + expect(error.$metadata).toEqual({ + cfId: undefined, + extendedRequestId: undefined, + httpStatusCode: 400, + requestId: undefined, + }); + + expect(error.$response).toEqual( + new HttpResponse({ + body, + headers: { + "x-amzn-query-error": "MyQueryError;Client", + }, + reason: undefined, + statusCode: 400, + }) + ); + + expect(error.Code).toEqual(MyQueryError.name); + expect(error.Error.Code).toEqual(MyQueryError.name); + + expect(error.Message).toEqual("oh no"); + expect(error.Prop2).toEqual(9999); + + expect(error.Error.Message).toEqual("oh no"); + expect(error.Error.Prop2).toEqual(9999); + + expect(error).toMatchObject({ + $fault: "client", + Message: "oh no", + message: "oh no", + Prop2: 9999, + Error: { + Code: "MyQueryError", + Message: "oh no", + Type: "Client", + Prop2: 9999, + }, + Type: "Client", + Code: "MyQueryError", + }); + }); }); diff --git a/packages/core/src/submodules/protocols/json/AwsJsonRpcProtocol.ts b/packages/core/src/submodules/protocols/json/AwsJsonRpcProtocol.ts index e309c1a5043b2..115796935327e 100644 --- a/packages/core/src/submodules/protocols/json/AwsJsonRpcProtocol.ts +++ b/packages/core/src/submodules/protocols/json/AwsJsonRpcProtocol.ts @@ -1,6 +1,6 @@ import { RpcProtocol } from "@smithy/core/protocols"; -import { deref, ErrorSchema, NormalizedSchema, SCHEMA, TypeRegistry } from "@smithy/core/schema"; -import { +import { deref, NormalizedSchema, SCHEMA } from "@smithy/core/schema"; +import type { EndpointBearer, HandlerExecutionContext, HttpRequest, @@ -11,8 +11,8 @@ import { ShapeDeserializer, ShapeSerializer, } from "@smithy/types"; -import { calculateBodyLength } from "@smithy/util-body-length-browser"; +import { ProtocolLib } from "../ProtocolLib"; import { JsonCodec } from "./JsonCodec"; import { loadRestJsonErrorCode } from "./parseJsonBody"; @@ -23,9 +23,19 @@ export abstract class AwsJsonRpcProtocol extends RpcProtocol { protected serializer: ShapeSerializer; protected deserializer: ShapeDeserializer; protected serviceTarget: string; - private codec: JsonCodec; + private readonly codec: JsonCodec; + private readonly mixin = new ProtocolLib(); + private readonly awsQueryCompatible: boolean; - protected constructor({ defaultNamespace, serviceTarget }: { defaultNamespace: string; serviceTarget: string }) { + protected constructor({ + defaultNamespace, + serviceTarget, + awsQueryCompatible, + }: { + defaultNamespace: string; + serviceTarget: string; + awsQueryCompatible?: boolean; + }) { super({ defaultNamespace, }); @@ -39,6 +49,7 @@ export abstract class AwsJsonRpcProtocol extends RpcProtocol { }); this.serializer = this.codec.createSerializer(); this.deserializer = this.codec.createDeserializer(); + this.awsQueryCompatible = !!awsQueryCompatible; } public async serializeRequest( @@ -54,11 +65,14 @@ export abstract class AwsJsonRpcProtocol extends RpcProtocol { "content-type": `application/x-amz-json-${this.getJsonRpcVersion()}`, "x-amz-target": `${this.serviceTarget}.${NormalizedSchema.of(operationSchema).getName()}`, }); + if (this.awsQueryCompatible) { + request.headers["x-amzn-query-mode"] = "true"; + } if (deref(operationSchema.input) === "unit" || !request.body) { request.body = "{}"; } try { - request.headers["content-length"] = String(calculateBodyLength(request.body)); + request.headers["content-length"] = this.mixin.calculateContentLength(request.body, this.serdeContext); } catch (e) {} return request; } @@ -77,47 +91,33 @@ export abstract class AwsJsonRpcProtocol extends RpcProtocol { metadata: ResponseMetadata ): Promise { // loadRestJsonErrorCode is still used in JSON RPC. - const errorIdentifier = loadRestJsonErrorCode(response, dataObject) ?? "Unknown"; - - let namespace = this.options.defaultNamespace; - let errorName = errorIdentifier; - if (errorIdentifier.includes("#")) { - [namespace, errorName] = errorIdentifier.split("#"); + if (this.awsQueryCompatible) { + this.mixin.setQueryCompatError(dataObject, response); } + const errorIdentifier = loadRestJsonErrorCode(response, dataObject) ?? "Unknown"; - const errorMetadata = { - $metadata: metadata, - $response: response, - $fault: response.statusCode <= 500 ? ("client" as const) : ("server" as const), - }; - - const registry = TypeRegistry.for(namespace); - let errorSchema: ErrorSchema; - try { - errorSchema = registry.getSchema(errorIdentifier) as ErrorSchema; - } catch (e) { - if (dataObject.Message) { - dataObject.message = dataObject.Message; - } - const baseExceptionSchema = TypeRegistry.for("smithy.ts.sdk.synthetic." + namespace).getBaseException(); - if (baseExceptionSchema) { - const ErrorCtor = baseExceptionSchema.ctor; - throw Object.assign(new ErrorCtor({ name: errorName }), errorMetadata, dataObject); - } - throw Object.assign(new Error(errorName), errorMetadata, dataObject); - } + const { errorSchema, errorMetadata } = await this.mixin.getErrorSchemaOrThrowBaseException( + errorIdentifier, + this.options.defaultNamespace, + response, + dataObject, + metadata + ); const ns = NormalizedSchema.of(errorSchema); const message = dataObject.message ?? dataObject.Message ?? "Unknown"; const exception = new errorSchema.ctor(message); - await this.deserializeHttpMessage(errorSchema, context, response, dataObject); const output = {} as any; for (const [name, member] of ns.structIterator()) { const target = member.getMergedTraits().jsonName ?? name; output[name] = this.codec.createDeserializer().readObject(member, dataObject[target]); } + if (this.awsQueryCompatible) { + this.mixin.queryCompatOutput(dataObject, output); + } + throw Object.assign( exception, errorMetadata, diff --git a/packages/core/src/submodules/protocols/json/AwsRestJsonProtocol.ts b/packages/core/src/submodules/protocols/json/AwsRestJsonProtocol.ts index 03b36be59a734..63fb7354a7cd3 100644 --- a/packages/core/src/submodules/protocols/json/AwsRestJsonProtocol.ts +++ b/packages/core/src/submodules/protocols/json/AwsRestJsonProtocol.ts @@ -3,8 +3,8 @@ import { HttpInterceptingShapeDeserializer, HttpInterceptingShapeSerializer, } from "@smithy/core/protocols"; -import { ErrorSchema, NormalizedSchema, SCHEMA, TypeRegistry } from "@smithy/core/schema"; -import { +import { NormalizedSchema, SCHEMA } from "@smithy/core/schema"; +import type { EndpointBearer, HandlerExecutionContext, HttpRequest, @@ -15,8 +15,8 @@ import { ShapeDeserializer, ShapeSerializer, } from "@smithy/types"; -import { calculateBodyLength } from "@smithy/util-body-length-browser"; +import { ProtocolLib } from "../ProtocolLib"; import { JsonCodec, JsonSettings } from "./JsonCodec"; import { loadRestJsonErrorCode } from "./parseJsonBody"; @@ -27,6 +27,7 @@ export class AwsRestJsonProtocol extends HttpBindingProtocol { protected serializer: ShapeSerializer; protected deserializer: ShapeDeserializer; private readonly codec: JsonCodec; + private readonly mixin = new ProtocolLib(); public constructor({ defaultNamespace }: { defaultNamespace: string }) { super({ @@ -65,32 +66,11 @@ export class AwsRestJsonProtocol extends HttpBindingProtocol { ): Promise { const request = await super.serializeRequest(operationSchema, input, context); const inputSchema = NormalizedSchema.of(operationSchema.input); - const members = inputSchema.getMemberSchemas(); if (!request.headers["content-type"]) { - const httpPayloadMember = Object.values(members).find((m) => { - return !!m.getMergedTraits().httpPayload; - }); - - if (httpPayloadMember) { - const mediaType = httpPayloadMember.getMergedTraits().mediaType as string; - if (mediaType) { - request.headers["content-type"] = mediaType; - } else if (httpPayloadMember.isStringSchema()) { - request.headers["content-type"] = "text/plain"; - } else if (httpPayloadMember.isBlobSchema()) { - request.headers["content-type"] = "application/octet-stream"; - } else { - request.headers["content-type"] = this.getDefaultContentType(); - } - } else if (!inputSchema.isUnitSchema()) { - const hasBody = Object.values(members).find((m) => { - const { httpQuery, httpQueryParams, httpHeader, httpLabel, httpPrefixHeaders } = m.getMergedTraits(); - return !httpQuery && !httpQueryParams && !httpHeader && !httpLabel && httpPrefixHeaders === void 0; - }); - if (hasBody) { - request.headers["content-type"] = this.getDefaultContentType(); - } + const contentType = this.mixin.resolveRestContentType(this.getDefaultContentType(), inputSchema); + if (contentType) { + request.headers["content-type"] = contentType; } } @@ -100,8 +80,7 @@ export class AwsRestJsonProtocol extends HttpBindingProtocol { if (request.body) { try { - // todo(schema): use config.bodyLengthChecker or move that into serdeContext. - request.headers["content-length"] = String(calculateBodyLength(request.body)); + request.headers["content-length"] = this.mixin.calculateContentLength(request.body, this.serdeContext); } catch (e) {} } @@ -117,33 +96,13 @@ export class AwsRestJsonProtocol extends HttpBindingProtocol { ): Promise { const errorIdentifier = loadRestJsonErrorCode(response, dataObject) ?? "Unknown"; - let namespace = this.options.defaultNamespace; - let errorName = errorIdentifier; - if (errorIdentifier.includes("#")) { - [namespace, errorName] = errorIdentifier.split("#"); - } - - const errorMetadata = { - $metadata: metadata, - $response: response, - $fault: response.statusCode <= 500 ? ("client" as const) : ("server" as const), - }; - - const registry = TypeRegistry.for(namespace); - let errorSchema: ErrorSchema; - try { - errorSchema = registry.getSchema(errorIdentifier) as ErrorSchema; - } catch (e) { - if (dataObject.Message) { - dataObject.message = dataObject.Message; - } - const baseExceptionSchema = TypeRegistry.for("smithy.ts.sdk.synthetic." + namespace).getBaseException(); - if (baseExceptionSchema) { - const ErrorCtor = baseExceptionSchema.ctor; - throw Object.assign(new ErrorCtor({ name: errorName }), errorMetadata, dataObject); - } - throw Object.assign(new Error(errorName), errorMetadata, dataObject); - } + const { errorSchema, errorMetadata } = await this.mixin.getErrorSchemaOrThrowBaseException( + errorIdentifier, + this.options.defaultNamespace, + response, + dataObject, + metadata + ); const ns = NormalizedSchema.of(errorSchema); const message = dataObject.message ?? dataObject.Message ?? "Unknown"; diff --git a/packages/core/src/submodules/protocols/query/AwsQueryProtocol.ts b/packages/core/src/submodules/protocols/query/AwsQueryProtocol.ts index f695500dde550..4a65c265c8cdf 100644 --- a/packages/core/src/submodules/protocols/query/AwsQueryProtocol.ts +++ b/packages/core/src/submodules/protocols/query/AwsQueryProtocol.ts @@ -1,18 +1,18 @@ import { collectBody, RpcProtocol } from "@smithy/core/protocols"; import { deref, ErrorSchema, NormalizedSchema, SCHEMA, TypeRegistry } from "@smithy/core/schema"; -import { +import type { Codec, EndpointBearer, HandlerExecutionContext, HttpRequest, + HttpResponse as IHttpResponse, MetadataBearer, OperationSchema, ResponseMetadata, SerdeFunctions, } from "@smithy/types"; -import type { HttpResponse as IHttpResponse } from "@smithy/types/dist-types/http"; -import { calculateBodyLength } from "@smithy/util-body-length-browser"; +import { ProtocolLib } from "../ProtocolLib"; import { XmlShapeDeserializer } from "../xml/XmlShapeDeserializer"; import { QueryShapeSerializer } from "./QueryShapeSerializer"; @@ -22,6 +22,7 @@ import { QueryShapeSerializer } from "./QueryShapeSerializer"; export class AwsQueryProtocol extends RpcProtocol { protected serializer: QueryShapeSerializer; protected deserializer: XmlShapeDeserializer; + private readonly mixin = new ProtocolLib(); public constructor( public options: { @@ -81,7 +82,7 @@ export class AwsQueryProtocol extends RpcProtocol { } try { - request.headers["content-length"] = String(calculateBodyLength(request.body)); + request.headers["content-length"] = this.mixin.calculateContentLength(request.body, this.serdeContext); } catch (e) {} return request; } @@ -140,40 +141,19 @@ export class AwsQueryProtocol extends RpcProtocol { metadata: ResponseMetadata ): Promise { const errorIdentifier = this.loadQueryErrorCode(response, dataObject) ?? "Unknown"; - let namespace = this.options.defaultNamespace; - let errorName = errorIdentifier; - if (errorIdentifier.includes("#")) { - [namespace, errorName] = errorIdentifier.split("#"); - } - const errorData = this.loadQueryError(dataObject); - const errorMetadata = { - $metadata: metadata, - $response: response, - $fault: response.statusCode <= 500 ? ("client" as const) : ("server" as const), - }; - const registry = TypeRegistry.for(namespace); - let errorSchema: ErrorSchema; - - try { - errorSchema = registry.find( - (schema) => (NormalizedSchema.of(schema).getMergedTraits().awsQueryError as any)?.[0] === errorName - ) as ErrorSchema; - if (!errorSchema) { - errorSchema = registry.getSchema(errorIdentifier) as ErrorSchema; - } - } catch (e) { - if (errorData.Message) { - errorData.message = errorData.Message; - } - const baseExceptionSchema = TypeRegistry.for("smithy.ts.sdk.synthetic." + namespace).getBaseException(); - if (baseExceptionSchema) { - const ErrorCtor = baseExceptionSchema.ctor; - throw Object.assign(new ErrorCtor({ name: errorName }), errorMetadata, dataObject); - } - throw Object.assign(new Error(errorName), errorMetadata, errorData); - } + const { errorSchema, errorMetadata } = await this.mixin.getErrorSchemaOrThrowBaseException( + errorIdentifier, + this.options.defaultNamespace, + response, + errorData, + metadata, + (registry: TypeRegistry, errorName: string) => + registry.find( + (schema) => (NormalizedSchema.of(schema).getMergedTraits().awsQueryError as any)?.[0] === errorName + ) as ErrorSchema + ); const ns = NormalizedSchema.of(errorSchema); const message = this.loadQueryErrorMessage(dataObject); diff --git a/packages/core/src/submodules/protocols/xml/AwsRestXmlProtocol.ts b/packages/core/src/submodules/protocols/xml/AwsRestXmlProtocol.ts index 2d26a0c15ef86..b915d2ba6700c 100644 --- a/packages/core/src/submodules/protocols/xml/AwsRestXmlProtocol.ts +++ b/packages/core/src/submodules/protocols/xml/AwsRestXmlProtocol.ts @@ -3,7 +3,7 @@ import { HttpInterceptingShapeDeserializer, HttpInterceptingShapeSerializer, } from "@smithy/core/protocols"; -import { ErrorSchema, NormalizedSchema, OperationSchema, SCHEMA, TypeRegistry } from "@smithy/core/schema"; +import { NormalizedSchema, OperationSchema, SCHEMA } from "@smithy/core/schema"; import type { EndpointBearer, HandlerExecutionContext, @@ -15,8 +15,8 @@ import type { ShapeDeserializer, ShapeSerializer, } from "@smithy/types"; -import { calculateBodyLength } from "@smithy/util-body-length-browser"; +import { ProtocolLib } from "../ProtocolLib"; import { loadRestXmlErrorCode } from "./parseXmlBody"; import { XmlCodec, XmlSettings } from "./XmlCodec"; @@ -27,6 +27,7 @@ export class AwsRestXmlProtocol extends HttpBindingProtocol { private readonly codec: XmlCodec; protected serializer: ShapeSerializer; protected deserializer: ShapeDeserializer; + private readonly mixin = new ProtocolLib(); public constructor(options: { defaultNamespace: string; xmlNamespace: string }) { super(options); @@ -58,33 +59,12 @@ export class AwsRestXmlProtocol extends HttpBindingProtocol { context: HandlerExecutionContext & SerdeFunctions & EndpointBearer ): Promise { const request = await super.serializeRequest(operationSchema, input, context); - const ns = NormalizedSchema.of(operationSchema.input); - const members = ns.getMemberSchemas(); + const inputSchema = NormalizedSchema.of(operationSchema.input); if (!request.headers["content-type"]) { - const httpPayloadMember = Object.values(members).find((m) => { - return !!m.getMergedTraits().httpPayload; - }); - - if (httpPayloadMember) { - const mediaType = httpPayloadMember.getMergedTraits().mediaType as string; - if (mediaType) { - request.headers["content-type"] = mediaType; - } else if (httpPayloadMember.isStringSchema()) { - request.headers["content-type"] = "text/plain"; - } else if (httpPayloadMember.isBlobSchema()) { - request.headers["content-type"] = "application/octet-stream"; - } else { - request.headers["content-type"] = this.getDefaultContentType(); - } - } else if (!ns.isUnitSchema()) { - const hasBody = Object.values(members).find((m) => { - const { httpQuery, httpQueryParams, httpHeader, httpLabel, httpPrefixHeaders } = m.getMergedTraits(); - return !httpQuery && !httpQueryParams && !httpHeader && !httpLabel && httpPrefixHeaders === void 0; - }); - if (hasBody) { - request.headers["content-type"] = this.getDefaultContentType(); - } + const contentType = this.mixin.resolveRestContentType(this.getDefaultContentType(), inputSchema); + if (contentType) { + request.headers["content-type"] = contentType; } } @@ -96,8 +76,7 @@ export class AwsRestXmlProtocol extends HttpBindingProtocol { if (request.body) { try { - // todo(schema): use config.bodyLengthChecker or move that into serdeContext. - request.headers["content-length"] = String(calculateBodyLength(request.body)); + request.headers["content-length"] = this.mixin.calculateContentLength(request.body, this.serdeContext); } catch (e) {} } @@ -120,34 +99,14 @@ export class AwsRestXmlProtocol extends HttpBindingProtocol { metadata: ResponseMetadata ): Promise { const errorIdentifier = loadRestXmlErrorCode(response, dataObject) ?? "Unknown"; - let namespace = this.options.defaultNamespace; - let errorName = errorIdentifier; - if (errorIdentifier.includes("#")) { - [namespace, errorName] = errorIdentifier.split("#"); - } - - const errorMetadata = { - $metadata: metadata, - $response: response, - $fault: response.statusCode <= 500 ? ("client" as const) : ("server" as const), - }; - - const registry = TypeRegistry.for(namespace); - let errorSchema: ErrorSchema; - try { - errorSchema = registry.getSchema(errorIdentifier) as ErrorSchema; - } catch (e) { - if (dataObject.Message) { - dataObject.message = dataObject.Message; - } - const baseExceptionSchema = TypeRegistry.for("smithy.ts.sdk.synthetic." + namespace).getBaseException(); - if (baseExceptionSchema) { - const ErrorCtor = baseExceptionSchema.ctor; - throw Object.assign(new ErrorCtor({ name: errorName }), errorMetadata, dataObject); - } - throw Object.assign(new Error(errorName), errorMetadata, dataObject); - } + const { errorSchema, errorMetadata } = await this.mixin.getErrorSchemaOrThrowBaseException( + errorIdentifier, + this.options.defaultNamespace, + response, + dataObject, + metadata + ); const ns = NormalizedSchema.of(errorSchema); const message = diff --git a/private/aws-protocoltests-json-10-schema/src/runtimeConfig.shared.ts b/private/aws-protocoltests-json-10-schema/src/runtimeConfig.shared.ts index b1369d8a2e49e..18ef77410e6ce 100644 --- a/private/aws-protocoltests-json-10-schema/src/runtimeConfig.shared.ts +++ b/private/aws-protocoltests-json-10-schema/src/runtimeConfig.shared.ts @@ -33,7 +33,11 @@ export const getRuntimeConfig = (config: JSONRPC10ClientConfig) => { logger: config?.logger ?? new NoOpLogger(), protocol: config?.protocol ?? - new AwsJson1_0Protocol({ defaultNamespace: "aws.protocoltests.json10", serviceTarget: "JsonRpc10" }), + new AwsJson1_0Protocol({ + defaultNamespace: "aws.protocoltests.json10", + serviceTarget: "JsonRpc10", + awsQueryCompatible: false, + }), serviceId: config?.serviceId ?? "JSON RPC 10", urlParser: config?.urlParser ?? parseUrl, utf8Decoder: config?.utf8Decoder ?? fromUtf8, diff --git a/private/aws-protocoltests-json-schema/src/runtimeConfig.shared.ts b/private/aws-protocoltests-json-schema/src/runtimeConfig.shared.ts index c5c69b5bdf817..ff36005a76f85 100644 --- a/private/aws-protocoltests-json-schema/src/runtimeConfig.shared.ts +++ b/private/aws-protocoltests-json-schema/src/runtimeConfig.shared.ts @@ -33,7 +33,11 @@ export const getRuntimeConfig = (config: JsonProtocolClientConfig) => { logger: config?.logger ?? new NoOpLogger(), protocol: config?.protocol ?? - new AwsJson1_1Protocol({ defaultNamespace: "aws.protocoltests.json", serviceTarget: "JsonProtocol" }), + new AwsJson1_1Protocol({ + defaultNamespace: "aws.protocoltests.json", + serviceTarget: "JsonProtocol", + awsQueryCompatible: false, + }), serviceId: config?.serviceId ?? "Json Protocol", urlParser: config?.urlParser ?? parseUrl, utf8Decoder: config?.utf8Decoder ?? fromUtf8, diff --git a/private/aws-protocoltests-smithy-rpcv2-cbor-schema/package.json b/private/aws-protocoltests-smithy-rpcv2-cbor-schema/package.json index 3f8c92b6423fc..b55a708e947ac 100644 --- a/private/aws-protocoltests-smithy-rpcv2-cbor-schema/package.json +++ b/private/aws-protocoltests-smithy-rpcv2-cbor-schema/package.json @@ -20,6 +20,7 @@ "dependencies": { "@aws-crypto/sha256-browser": "5.2.0", "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "*", "@aws-sdk/middleware-host-header": "*", "@aws-sdk/middleware-logger": "*", "@aws-sdk/middleware-recursion-detection": "*", diff --git a/private/aws-protocoltests-smithy-rpcv2-cbor-schema/src/runtimeConfig.shared.ts b/private/aws-protocoltests-smithy-rpcv2-cbor-schema/src/runtimeConfig.shared.ts index 059f0767556ab..51a15879b9703 100644 --- a/private/aws-protocoltests-smithy-rpcv2-cbor-schema/src/runtimeConfig.shared.ts +++ b/private/aws-protocoltests-smithy-rpcv2-cbor-schema/src/runtimeConfig.shared.ts @@ -1,6 +1,6 @@ // smithy-typescript generated code +import { AwsSmithyRpcV2CborProtocol } from "@aws-sdk/core/protocols"; import { NoAuthSigner } from "@smithy/core"; -import { SmithyRpcV2CborProtocol } from "@smithy/core/cbor"; import { NoOpLogger } from "@smithy/smithy-client"; import { IdentityProviderConfig } from "@smithy/types"; import { parseUrl } from "@smithy/url-parser"; @@ -32,7 +32,12 @@ export const getRuntimeConfig = (config: RpcV2ProtocolClientConfig) => { }, ], logger: config?.logger ?? new NoOpLogger(), - protocol: config?.protocol ?? new SmithyRpcV2CborProtocol({ defaultNamespace: "smithy.protocoltests.rpcv2Cbor" }), + protocol: + config?.protocol ?? + new AwsSmithyRpcV2CborProtocol({ + defaultNamespace: "smithy.protocoltests.rpcv2Cbor", + awsQueryCompatible: false, + }), urlParser: config?.urlParser ?? parseUrl, utf8Decoder: config?.utf8Decoder ?? fromUtf8, utf8Encoder: config?.utf8Encoder ?? toUtf8, diff --git a/yarn.lock b/yarn.lock index c5e743ac48fa9..423c240161add 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1262,6 +1262,7 @@ __metadata: dependencies: "@aws-crypto/sha256-browser": "npm:5.2.0" "@aws-crypto/sha256-js": "npm:5.2.0" + "@aws-sdk/core": "npm:*" "@aws-sdk/middleware-host-header": "npm:*" "@aws-sdk/middleware-logger": "npm:*" "@aws-sdk/middleware-recursion-detection": "npm:*"