Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/middleware-flexible-checksums/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"@smithy/types": "^3.6.0",
"@smithy/util-middleware": "^3.0.8",
"@smithy/util-utf8": "^3.0.0",
"@smithy/util-stream": "^3.2.1",
"tslib": "^2.6.2"
},
"devDependencies": {
Expand Down
16 changes: 1 addition & 15 deletions packages/middleware-flexible-checksums/src/getChecksum.spec.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";

import { getChecksum } from "./getChecksum";
import { isStreaming } from "./isStreaming";
import { stringHasher } from "./stringHasher";

vi.mock("./isStreaming");
vi.mock("./stringHasher");

describe(getChecksum.name, () => {
const mockOptions = {
streamHasher: vi.fn(),
checksumAlgorithmFn: vi.fn(),
base64Encoder: vi.fn(),
};
Expand All @@ -26,21 +23,10 @@ describe(getChecksum.name, () => {
vi.clearAllMocks();
});

it("gets checksum from streamHasher if body is streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(true);
mockOptions.streamHasher.mockResolvedValue(mockRawOutput);
const checksum = await getChecksum(mockBody, mockOptions);
expect(checksum).toEqual(mockOutput);
expect(stringHasher).not.toHaveBeenCalled();
expect(mockOptions.streamHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody);
});

it("gets checksum from stringHasher if body is not streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(false);
it("gets checksum from stringHasher", async () => {
vi.mocked(stringHasher).mockResolvedValue(mockRawOutput);
const checksum = await getChecksum(mockBody, mockOptions);
expect(checksum).toEqual(mockOutput);
expect(mockOptions.streamHasher).not.toHaveBeenCalled();
expect(stringHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody);
});
});
13 changes: 3 additions & 10 deletions packages/middleware-flexible-checksums/src/getChecksum.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
import { ChecksumConstructor, Encoder, HashConstructor, StreamHasher } from "@smithy/types";
import { ChecksumConstructor, Encoder, HashConstructor } from "@smithy/types";

import { isStreaming } from "./isStreaming";
import { stringHasher } from "./stringHasher";

export interface GetChecksumDigestOptions {
streamHasher: StreamHasher<any>;
checksumAlgorithmFn: ChecksumConstructor | HashConstructor;
base64Encoder: Encoder;
}

export const getChecksum = async (
body: unknown,
{ streamHasher, checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions
) => {
const digest = isStreaming(body) ? streamHasher(checksumAlgorithmFn, body) : stringHasher(checksumAlgorithmFn, body);
return base64Encoder(await digest);
};
export const getChecksum = async (body: unknown, { checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions) =>
base64Encoder(await stringHasher(checksumAlgorithmFn, body));
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
import { HttpResponse } from "@smithy/protocol-http";
import { createChecksumStream } from "@smithy/util-stream";
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { getChecksum } from "./getChecksum";
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
import { validateChecksumFromResponse } from "./validateChecksumFromResponse";

vi.mock("@smithy/util-stream");
vi.mock("./getChecksum");
vi.mock("./getChecksumLocationName");
vi.mock("./getChecksumAlgorithmListForResponse");
vi.mock("./isStreaming");
vi.mock("./selectChecksumAlgorithmFunction");

describe(validateChecksumFromResponse.name, () => {
const mockConfig = {
streamHasher: vi.fn(),
base64Encoder: vi.fn(),
} as unknown as PreviouslyResolved;

const mockBody = {};
const mockBodyStream = { isStream: true };
const mockHeaders = {};
const mockResponse = {
body: mockBody,
Expand Down Expand Up @@ -50,6 +54,7 @@ describe(validateChecksumFromResponse.name, () => {
vi.mocked(getChecksumAlgorithmListForResponse).mockImplementation((responseAlgorithms) => responseAlgorithms);
vi.mocked(selectChecksumAlgorithmFunction).mockReturnValue(mockChecksumAlgorithmFn);
vi.mocked(getChecksum).mockResolvedValue(mockChecksum);
vi.mocked(createChecksumStream).mockReturnValue(mockBodyStream);
});

afterEach(() => {
Expand Down Expand Up @@ -85,31 +90,56 @@ describe(validateChecksumFromResponse.name, () => {
});

describe("successful validation", () => {
afterEach(() => {
const validateCalls = (isStream: boolean, checksumAlgoFn: ChecksumAlgorithm) => {
expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledTimes(1);
});

it("when checksum is populated for first algorithm", async () => {
if (isStream) {
expect(getChecksum).not.toHaveBeenCalled();
expect(createChecksumStream).toHaveBeenCalledTimes(1);
expect(createChecksumStream).toHaveBeenCalledWith({
expectedChecksum: mockChecksum,
checksumSourceLocation: checksumAlgoFn,
checksum: new mockChecksumAlgorithmFn(),
source: mockBody,
base64Encoder: mockConfig.base64Encoder,
});
} else {
expect(getChecksum).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledWith(mockBody, {
checksumAlgorithmFn: mockChecksumAlgorithmFn,
base64Encoder: mockConfig.base64Encoder,
});
expect(createChecksumStream).not.toHaveBeenCalled();
}
};

it.each([false, true])("when checksum is populated for first algorithm when streaming: %s", async (isStream) => {
vi.mocked(isStreaming).mockReturnValue(isStream);
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], mockChecksum);
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledWith(mockResponseAlgorithms[0]);
validateCalls(isStream, mockResponseAlgorithms[0]);
});

it("when checksum is populated for second algorithm", async () => {
it.each([false, true])("when checksum is populated for second algorithm when streaming: %s", async (isStream) => {
vi.mocked(isStreaming).mockReturnValue(isStream);
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[1], mockChecksum);
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
expect(getChecksumLocationName).toHaveBeenCalledTimes(2);
expect(getChecksumLocationName).toHaveBeenNthCalledWith(1, mockResponseAlgorithms[0]);
expect(getChecksumLocationName).toHaveBeenNthCalledWith(2, mockResponseAlgorithms[1]);
validateCalls(isStream, mockResponseAlgorithms[1]);
});
});

it("throw error if checksum value is not accurate", async () => {
it("throw error if checksum value is not accurate when not streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(false);

const incorrectChecksum = "incorrectChecksum";
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum);

try {
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
fail("should throw checksum mismatch error");
Expand All @@ -119,9 +149,28 @@ describe(validateChecksumFromResponse.name, () => {
` in response header "${mockResponseAlgorithms[0]}".`
);
}

expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledTimes(1);
expect(createChecksumStream).not.toHaveBeenCalled();
});

it("return if checksum value is not accurate when streaming, as error will be thrown when stream is consumed", async () => {
vi.mocked(isStreaming).mockReturnValue(true);

// This override does not matter for the purpose of unit test, but is kept for completeness.
const incorrectChecksum = "incorrectChecksum";
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum);

await validateChecksumFromResponse(responseWithChecksum, mockOptions);

expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksum).not.toHaveBeenCalled();
expect(createChecksumStream).toHaveBeenCalledTimes(1);
expect(responseWithChecksum.body).toBe(mockBodyStream);
});
});
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import { HttpResponse } from "@smithy/protocol-http";
import { Checksum } from "@smithy/types";
import { createChecksumStream } from "@smithy/util-stream";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { getChecksum } from "./getChecksum";
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";

export interface ValidateChecksumFromResponseOptions {
Expand All @@ -29,9 +32,20 @@ export const validateChecksumFromResponse = async (
const checksumFromResponse = responseHeaders[responseHeader];
if (checksumFromResponse) {
const checksumAlgorithmFn = selectChecksumAlgorithmFunction(algorithm as ChecksumAlgorithm, config);
const { streamHasher, base64Encoder } = config;
const checksum = await getChecksum(responseBody, { streamHasher, checksumAlgorithmFn, base64Encoder });
const { base64Encoder } = config;

if (isStreaming(responseBody)) {
response.body = createChecksumStream({
expectedChecksum: checksumFromResponse,
checksumSourceLocation: responseHeader,
checksum: new checksumAlgorithmFn() as Checksum,
source: responseBody,
base64Encoder,
});
return;
}

const checksum = await getChecksum(responseBody, { checksumAlgorithmFn, base64Encoder });
if (checksum === checksumFromResponse) {
// The checksum for response payload is valid.
break;
Expand Down
Loading