Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/middleware-flexible-checksums/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"@smithy/types": "^3.6.0",
"@smithy/util-middleware": "^3.0.8",
"@smithy/util-utf8": "^3.0.0",
"@smithy/util-stream": "^3.2.1",
"tslib": "^2.6.2"
},
"devDependencies": {
Expand Down
16 changes: 1 addition & 15 deletions packages/middleware-flexible-checksums/src/getChecksum.spec.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";

import { getChecksum } from "./getChecksum";
import { isStreaming } from "./isStreaming";
import { stringHasher } from "./stringHasher";

vi.mock("./isStreaming");
vi.mock("./stringHasher");

describe(getChecksum.name, () => {
const mockOptions = {
streamHasher: vi.fn(),
checksumAlgorithmFn: vi.fn(),
base64Encoder: vi.fn(),
};
Expand All @@ -26,21 +23,10 @@ describe(getChecksum.name, () => {
vi.clearAllMocks();
});

it("gets checksum from streamHasher if body is streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(true);
mockOptions.streamHasher.mockResolvedValue(mockRawOutput);
const checksum = await getChecksum(mockBody, mockOptions);
expect(checksum).toEqual(mockOutput);
expect(stringHasher).not.toHaveBeenCalled();
expect(mockOptions.streamHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody);
});

it("gets checksum from stringHasher if body is not streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(false);
it("gets checksum from stringHasher", async () => {
vi.mocked(stringHasher).mockResolvedValue(mockRawOutput);
const checksum = await getChecksum(mockBody, mockOptions);
expect(checksum).toEqual(mockOutput);
expect(mockOptions.streamHasher).not.toHaveBeenCalled();
expect(stringHasher).toHaveBeenCalledWith(mockOptions.checksumAlgorithmFn, mockBody);
});
});
13 changes: 3 additions & 10 deletions packages/middleware-flexible-checksums/src/getChecksum.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
import { ChecksumConstructor, Encoder, HashConstructor, StreamHasher } from "@smithy/types";
import { ChecksumConstructor, Encoder, HashConstructor } from "@smithy/types";

import { isStreaming } from "./isStreaming";
import { stringHasher } from "./stringHasher";

export interface GetChecksumDigestOptions {
streamHasher: StreamHasher<any>;
checksumAlgorithmFn: ChecksumConstructor | HashConstructor;
base64Encoder: Encoder;
}

export const getChecksum = async (
body: unknown,
{ streamHasher, checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions
) => {
const digest = isStreaming(body) ? streamHasher(checksumAlgorithmFn, body) : stringHasher(checksumAlgorithmFn, body);
return base64Encoder(await digest);
};
export const getChecksum = async (body: unknown, { checksumAlgorithmFn, base64Encoder }: GetChecksumDigestOptions) =>
base64Encoder(await stringHasher(checksumAlgorithmFn, body));
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import { HttpResponse } from "@smithy/protocol-http";
import { createChecksumStream } from "@smithy/util-stream";
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { getChecksum } from "./getChecksum";
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
import { validateChecksumFromResponse } from "./validateChecksumFromResponse";

vi.mock("@smithy/util-stream");
vi.mock("./getChecksum");
vi.mock("./getChecksumLocationName");
vi.mock("./getChecksumAlgorithmListForResponse");
vi.mock("./isStreaming");
vi.mock("./selectChecksumAlgorithmFunction");

describe(validateChecksumFromResponse.name, () => {
const mockConfig = {
streamHasher: vi.fn(),
base64Encoder: vi.fn(),
} as unknown as PreviouslyResolved;

Expand Down Expand Up @@ -85,29 +88,41 @@ describe(validateChecksumFromResponse.name, () => {
});

describe("successful validation", () => {
afterEach(() => {
const validateCalls = (isStream: boolean) => {
expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledTimes(1);
});

it("when checksum is populated for first algorithm", async () => {
if (isStream) {
expect(getChecksum).not.toHaveBeenCalled();
expect(createChecksumStream).toHaveBeenCalledTimes(1);
} else {
expect(getChecksum).toHaveBeenCalledTimes(1);
expect(createChecksumStream).not.toHaveBeenCalled();
}
};

it.each([false, true])("when checksum is populated for first algorithm when streaming: %s", async (isStream) => {
vi.mocked(isStreaming).mockReturnValue(isStream);
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], mockChecksum);
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledWith(mockResponseAlgorithms[0]);
validateCalls(isStream);
});

it("when checksum is populated for second algorithm", async () => {
it.each([false, true])("when checksum is populated for second algorithm when streaming: %s", async (isStream) => {
vi.mocked(isStreaming).mockReturnValue(isStream);
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[1], mockChecksum);
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
expect(getChecksumLocationName).toHaveBeenCalledTimes(2);
expect(getChecksumLocationName).toHaveBeenNthCalledWith(1, mockResponseAlgorithms[0]);
expect(getChecksumLocationName).toHaveBeenNthCalledWith(2, mockResponseAlgorithms[1]);
validateCalls(isStream);
});
});

it("throw error if checksum value is not accurate", async () => {
it("throw error if checksum value is not accurate when not streaming", async () => {
vi.mocked(isStreaming).mockReturnValue(false);
const incorrectChecksum = "incorrectChecksum";
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum);
try {
Expand All @@ -123,5 +138,18 @@ describe(validateChecksumFromResponse.name, () => {
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksum).toHaveBeenCalledTimes(1);
expect(createChecksumStream).not.toHaveBeenCalled();
});

it("return if checksum value is not accurate when streaming, as error will be thrown when stream is consumed", async () => {
vi.mocked(isStreaming).mockReturnValue(true);
const incorrectChecksum = "incorrectChecksum";
const responseWithChecksum = getMockResponseWithHeader(mockResponseAlgorithms[0], incorrectChecksum);
await validateChecksumFromResponse(responseWithChecksum, mockOptions);
expect(getChecksumAlgorithmListForResponse).toHaveBeenCalledWith(mockResponseAlgorithms);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(getChecksum).not.toHaveBeenCalled();
expect(createChecksumStream).toHaveBeenCalledTimes(1);
});
});
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import { HttpResponse } from "@smithy/protocol-http";
import { Checksum } from "@smithy/types";
import { createChecksumStream } from "@smithy/util-stream";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { getChecksum } from "./getChecksum";
import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";

export interface ValidateChecksumFromResponseOptions {
Expand All @@ -29,9 +32,20 @@ export const validateChecksumFromResponse = async (
const checksumFromResponse = responseHeaders[responseHeader];
if (checksumFromResponse) {
const checksumAlgorithmFn = selectChecksumAlgorithmFunction(algorithm as ChecksumAlgorithm, config);
const { streamHasher, base64Encoder } = config;
const checksum = await getChecksum(responseBody, { streamHasher, checksumAlgorithmFn, base64Encoder });
const { base64Encoder } = config;

if (isStreaming(responseBody)) {
createChecksumStream({
expectedChecksum: checksumFromResponse,
checksumSourceLocation: responseHeader,
checksum: new checksumAlgorithmFn() as Checksum,
source: responseBody,
base64Encoder,
});
return;
}

const checksum = await getChecksum(responseBody, { checksumAlgorithmFn, base64Encoder });
if (checksum === checksumFromResponse) {
// The checksum for response payload is valid.
break;
Expand Down
Loading