Skip to content
Closed
13 changes: 13 additions & 0 deletions packages/middleware-flexible-checksums/src/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import {
Encoder,
GetAwsChunkedEncodingStream,
HashConstructor,
Provider,
StreamCollector,
StreamHasher,
} from "@smithy/types";

import { RequestChecksumCalculation, ResponseChecksumValidation } from "./constants";

export interface PreviouslyResolved {
/**
* The function that will be used to convert binary data to a base64-encoded string.
Expand All @@ -31,6 +34,16 @@ export interface PreviouslyResolved {
*/
md5: ChecksumConstructor | HashConstructor;

/**
* Determines when a checksum will be calculated for request payloads
*/
requestChecksumCalculation: Provider<RequestChecksumCalculation>;

/**
* Determines when a checksum will be calculated for response payloads
*/
responseChecksumValidation: Provider<ResponseChecksumValidation>;

/**
* A constructor for a class implementing the {@link Hash} interface that computes SHA1 hashes.
* @internal
Expand Down
5 changes: 4 additions & 1 deletion packages/middleware-flexible-checksums/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ export const DEFAULT_RESPONSE_CHECKSUM_VALIDATION = RequestChecksumCalculation.W
* Checksum Algorithms supported by the SDK.
*/
export enum ChecksumAlgorithm {
/**
* @deprecated Use {@link ChecksumAlgorithm.CRC32} instead.
*/
MD5 = "MD5",
CRC32 = "CRC32",
CRC32C = "CRC32C",
Expand All @@ -70,7 +73,7 @@ export enum ChecksumLocation {
/**
* @internal
*/
export const DEFAULT_CHECKSUM_ALGORITHM = ChecksumAlgorithm.MD5;
export const DEFAULT_CHECKSUM_ALGORITHM = ChecksumAlgorithm.CRC32;

/**
* @internal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { HttpRequest } from "@smithy/protocol-http";
import { BuildHandlerArguments } from "@smithy/types";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { ChecksumAlgorithm, RequestChecksumCalculation } from "./constants";
import { flexibleChecksumsMiddleware } from "./flexibleChecksumsMiddleware";
import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest";
import { getChecksumLocationName } from "./getChecksumLocationName";
Expand All @@ -27,7 +27,9 @@ describe(flexibleChecksumsMiddleware.name, () => {
const mockChecksumLocationName = "mock-checksum-location-name";

const mockInput = {};
const mockConfig = {} as PreviouslyResolved;
const mockConfig = {
requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_REQUIRED),
} as PreviouslyResolved;
const mockMiddlewareConfig = { input: mockInput, requestChecksumRequired: false };

const mockBody = { body: "mockRequestBody" };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@ export const flexibleChecksumsMiddleware =
const { request } = args;
const { body: requestBody, headers } = request;
const { base64Encoder, streamHasher } = config;
const requestChecksumCalculation = await config.requestChecksumCalculation();
const { input, requestChecksumRequired, requestAlgorithmMember } = middlewareConfig;

const checksumAlgorithm = getChecksumAlgorithmForRequest(
input,
{
requestChecksumRequired,
requestAlgorithmMember,
requestChecksumCalculation,
},
!!context.isS3ExpressBucket
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { HttpRequest } from "@smithy/protocol-http";
import { DeserializeHandlerArguments } from "@smithy/types";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { ChecksumAlgorithm, ResponseChecksumValidation } from "./constants";
import { flexibleChecksumsResponseMiddleware } from "./flexibleChecksumsResponseMiddleware";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin";
Expand All @@ -23,7 +23,9 @@ describe(flexibleChecksumsResponseMiddleware.name, () => {
commandName: "mockCommandName",
};

const mockConfig = {} as PreviouslyResolved;
const mockConfig = {
responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_REQUIRED),
} as PreviouslyResolved;
const mockRequestValidationModeMember = "ChecksumEnabled";
const mockResponseAlgorithms = [ChecksumAlgorithm.CRC32, ChecksumAlgorithm.CRC32C];
const mockMiddlewareConfig = {
Expand Down Expand Up @@ -59,52 +61,66 @@ describe(flexibleChecksumsResponseMiddleware.name, () => {
});

describe("skips", () => {
it("if not an instance of HttpRequest", async () => {
const { isInstance } = HttpRequest;
(isInstance as unknown as jest.Mock).mockReturnValue(false);
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext);
it("if requestValidationModeMember is not defined", async () => {
const mockMwConfig = Object.assign({}, mockMiddlewareConfig) as FlexibleChecksumsMiddlewareConfig;
delete mockMwConfig.requestValidationModeMember;
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMwConfig)(mockNext, mockContext);
await handler(mockArgs);
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});

describe("response checksum", () => {
it("if requestValidationModeMember is not defined", async () => {
const mockMwConfig = Object.assign({}, mockMiddlewareConfig) as FlexibleChecksumsMiddlewareConfig;
delete mockMwConfig.requestValidationModeMember;
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMwConfig)(mockNext, mockContext);
await handler(mockArgs);
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
});
it("if requestValidationModeMember is not enabled in input", async () => {
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext);

it("if requestValidationModeMember is not enabled in input", async () => {
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext);
await handler({ ...mockArgs, input: {} });
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
});
const mockArgsWithoutEnabled = { ...mockArgs, input: {} };
await handler(mockArgsWithoutEnabled);
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
expect(mockNext).toHaveBeenCalledWith(mockArgsWithoutEnabled);
});

it("if checksum is for S3 whole-object multipart GET", async () => {
(isChecksumWithPartNumber as jest.Mock).mockReturnValue(true);
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {
clientName: "S3Client",
commandName: "GetObjectCommand",
});
await handler(mockArgs);
expect(isChecksumWithPartNumber).toHaveBeenCalledTimes(1);
expect(isChecksumWithPartNumber).toHaveBeenCalledWith(mockChecksum);
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
it("if checksum is for S3 whole-object multipart GET", async () => {
(isChecksumWithPartNumber as jest.Mock).mockReturnValue(true);
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {
clientName: "S3Client",
commandName: "GetObjectCommand",
});
await handler(mockArgs);
expect(isChecksumWithPartNumber).toHaveBeenCalledTimes(1);
expect(isChecksumWithPartNumber).toHaveBeenCalledWith(mockChecksum);
expect(validateChecksumFromResponse).not.toHaveBeenCalled();
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});
});

describe("validates checksum from response header", () => {
it("generic case", async () => {
it("if requestValidationModeMember is enabled in input", async () => {
const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext);

await handler(mockArgs);
expect(validateChecksumFromResponse).toHaveBeenCalledWith(mockResult.response, {
config: mockConfig,
responseAlgorithms: mockResponseAlgorithms,
});
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});

it(`if requestValidationModeMember is not enabled in input, but responseChecksumValidation returns ${ResponseChecksumValidation.WHEN_SUPPORTED}`, async () => {
const mockConfigWithResponseChecksumValidationSupported = {
...mockConfig,
responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_SUPPORTED),
};
const handler = flexibleChecksumsResponseMiddleware(
mockConfigWithResponseChecksumValidationSupported,
mockMiddlewareConfig
)(mockNext, mockContext);

await handler({ ...mockArgs, input: {} });
expect(validateChecksumFromResponse).toHaveBeenCalledWith(mockResult.response, {
config: mockConfigWithResponseChecksumValidationSupported,
responseAlgorithms: mockResponseAlgorithms,
});
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});

it("if checksum is for S3 GET without part number", async () => {
Expand All @@ -120,6 +136,7 @@ describe(flexibleChecksumsResponseMiddleware.name, () => {
config: mockConfig,
responseAlgorithms: mockResponseAlgorithms,
});
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});
});
});
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import { HttpRequest, HttpResponse } from "@smithy/protocol-http";
import { HttpResponse } from "@smithy/protocol-http";
import {
DeserializeHandler,
DeserializeHandlerArguments,
DeserializeHandlerOutput,
DeserializeMiddleware,
HandlerExecutionContext,
MetadataBearer,
RelativeMiddlewareOptions,
SerializeHandler,
SerializeHandlerArguments,
SerializeHandlerOutput,
SerializeMiddleware,
} from "@smithy/types";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { ChecksumAlgorithm, ResponseChecksumValidation } from "./constants";
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { isChecksumWithPartNumber } from "./isChecksumWithPartNumber";
Expand All @@ -37,8 +37,8 @@ export interface FlexibleChecksumsResponseMiddlewareConfig {
*/
export const flexibleChecksumsResponseMiddlewareOptions: RelativeMiddlewareOptions = {
name: "flexibleChecksumsResponseMiddleware",
toMiddleware: "deserializerMiddleware",
relation: "after",
toMiddleware: "serializerMiddleware",
relation: "before",
tags: ["BODY_CHECKSUM"],
override: true,
};
Expand All @@ -52,32 +52,38 @@ export const flexibleChecksumsResponseMiddleware =
(
config: PreviouslyResolved,
middlewareConfig: FlexibleChecksumsResponseMiddlewareConfig
): DeserializeMiddleware<any, any> =>
): SerializeMiddleware<any, any> =>
<Output extends MetadataBearer>(
next: DeserializeHandler<any, Output>,
next: SerializeHandler<any, Output>,
context: HandlerExecutionContext
): DeserializeHandler<any, Output> =>
async (args: DeserializeHandlerArguments<any>): Promise<DeserializeHandlerOutput<Output>> => {
if (!HttpRequest.isInstance(args.request)) {
return next(args);
): SerializeHandler<any, Output> =>
async (args: SerializeHandlerArguments<any>): Promise<SerializeHandlerOutput<Output>> => {
const input = args.input;
const { requestValidationModeMember, responseAlgorithms } = middlewareConfig;
const responseChecksumValidation = await config.responseChecksumValidation();

const isResponseChecksumValidationNeeded =
requestValidationModeMember &&
(input[requestValidationModeMember] === "ENABLED" ||
responseChecksumValidation === ResponseChecksumValidation.WHEN_SUPPORTED);

if (isResponseChecksumValidationNeeded) {
input[requestValidationModeMember] = "ENABLED";
}

const input = args.input;
const result = await next(args);

const response = result.response as HttpResponse;
let collectedStream: Uint8Array | undefined = undefined;

const { requestValidationModeMember, responseAlgorithms } = middlewareConfig;
// @ts-ignore Element implicitly has an 'any' type for input[requestValidationModeMember]
if (requestValidationModeMember && input[requestValidationModeMember] === "ENABLED") {
if (isResponseChecksumValidationNeeded) {
const { clientName, commandName } = context;
const isS3WholeObjectMultipartGetResponseChecksum =
clientName === "S3Client" &&
commandName === "GetObjectCommand" &&
getChecksumAlgorithmListForResponse(responseAlgorithms).every((algorithm: ChecksumAlgorithm) => {
const responseHeader = getChecksumLocationName(algorithm);
const checksumFromResponse = response.headers[responseHeader];
const checksumFromResponse = response.headers?.[responseHeader];
return !checksumFromResponse || isChecksumWithPartNumber(checksumFromResponse);
});
if (isS3WholeObjectMultipartGetResponseChecksum) {
Expand Down
Loading