diff --git a/packages/middleware-recursion-detection/package.json b/packages/middleware-recursion-detection/package.json index 4e07eee6d8d0..af3c765c24f6 100644 --- a/packages/middleware-recursion-detection/package.json +++ b/packages/middleware-recursion-detection/package.json @@ -53,5 +53,9 @@ "downlevel-dts": "0.10.1", "rimraf": "3.0.2", "typescript": "~5.8.3" - } + }, + "browser": { + "./dist-es/recursionDetectionMiddleware": "./dist-es/recursionDetectionMiddleware.browser" + }, + "react-native": {} } diff --git a/packages/middleware-recursion-detection/src/configuration.ts b/packages/middleware-recursion-detection/src/configuration.ts new file mode 100644 index 000000000000..7f7aba873338 --- /dev/null +++ b/packages/middleware-recursion-detection/src/configuration.ts @@ -0,0 +1,12 @@ +import { AbsoluteLocation, BuildHandlerOptions } from "@smithy/types"; + +/** + * @internal + */ +export const recursionDetectionMiddlewareOptions: BuildHandlerOptions & AbsoluteLocation = { + step: "build", + tags: ["RECURSION_DETECTION"], + name: "recursionDetectionMiddleware", + override: true, + priority: "low", +}; diff --git a/packages/middleware-recursion-detection/src/getRecursionDetectionPlugin.ts b/packages/middleware-recursion-detection/src/getRecursionDetectionPlugin.ts new file mode 100644 index 000000000000..d59fa78c5806 --- /dev/null +++ b/packages/middleware-recursion-detection/src/getRecursionDetectionPlugin.ts @@ -0,0 +1,13 @@ +import { Pluggable } from "@smithy/types"; + +import { recursionDetectionMiddlewareOptions } from "./configuration"; +import { recursionDetectionMiddleware } from "./recursionDetectionMiddleware"; + +/** + * @internal + */ +export const getRecursionDetectionPlugin = (options: any): Pluggable => ({ + applyToStack: (clientStack) => { + clientStack.add(recursionDetectionMiddleware(), recursionDetectionMiddlewareOptions); + }, +}); diff --git a/packages/middleware-recursion-detection/src/index.ts b/packages/middleware-recursion-detection/src/index.ts index 9030c34443c8..88e92b67b71e 100644 --- a/packages/middleware-recursion-detection/src/index.ts +++ b/packages/middleware-recursion-detection/src/index.ts @@ -1,72 +1,2 @@ -import { HttpRequest } from "@smithy/protocol-http"; -import { - AbsoluteLocation, - BuildHandler, - BuildHandlerArguments, - BuildHandlerOptions, - BuildHandlerOutput, - BuildMiddleware, - MetadataBearer, - Pluggable, -} from "@smithy/types"; - -const TRACE_ID_HEADER_NAME = "X-Amzn-Trace-Id"; -const ENV_LAMBDA_FUNCTION_NAME = "AWS_LAMBDA_FUNCTION_NAME"; -const ENV_TRACE_ID = "_X_AMZN_TRACE_ID"; - -interface PreviouslyResolved { - runtime: string; -} - -/** - * Inject to trace ID to request header to detect recursion invocation in Lambda. - * @internal - */ -export const recursionDetectionMiddleware = - (options: PreviouslyResolved): BuildMiddleware => - (next: BuildHandler): BuildHandler => - async (args: BuildHandlerArguments): Promise> => { - const { request } = args; - if (!HttpRequest.isInstance(request) || options.runtime !== "node") { - return next(args); - } - const traceIdHeader = - Object.keys(request.headers ?? {}).find((h) => h.toLowerCase() === TRACE_ID_HEADER_NAME.toLowerCase()) ?? - TRACE_ID_HEADER_NAME; - - if (request.headers.hasOwnProperty(traceIdHeader)) { - return next(args); - } - const functionName = process.env[ENV_LAMBDA_FUNCTION_NAME]; - const traceId = process.env[ENV_TRACE_ID]; - const nonEmptyString = (str: unknown): str is string => typeof str === "string" && str.length > 0; - if (nonEmptyString(functionName) && nonEmptyString(traceId)) { - request.headers[TRACE_ID_HEADER_NAME] = traceId; - } - return next({ - ...args, - request, - }); - }; - -// @internal -/** - * @internal - */ -export const addRecursionDetectionMiddlewareOptions: BuildHandlerOptions & AbsoluteLocation = { - step: "build", - tags: ["RECURSION_DETECTION"], - name: "recursionDetectionMiddleware", - override: true, - priority: "low", -}; - -// @internal -/** - * @internal - */ -export const getRecursionDetectionPlugin = (options: PreviouslyResolved): Pluggable => ({ - applyToStack: (clientStack) => { - clientStack.add(recursionDetectionMiddleware(options), addRecursionDetectionMiddlewareOptions); - }, -}); +export * from "./getRecursionDetectionPlugin"; +export * from "./recursionDetectionMiddleware"; diff --git a/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.browser.ts b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.browser.ts new file mode 100644 index 000000000000..6d598e6731e8 --- /dev/null +++ b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.browser.ts @@ -0,0 +1,17 @@ +import { + BuildHandler, + BuildHandlerArguments, + BuildHandlerOutput, + BuildMiddleware, + MetadataBearer, +} from "@smithy/types"; + +/** + * No-op middleware for runtimes outside of Node.js + * @internal + */ +export const recursionDetectionMiddleware = + (): BuildMiddleware => + (next: BuildHandler): BuildHandler => + async (args: BuildHandlerArguments): Promise> => + next(args); diff --git a/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.native.ts b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.native.ts new file mode 100644 index 000000000000..6d598e6731e8 --- /dev/null +++ b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.native.ts @@ -0,0 +1,17 @@ +import { + BuildHandler, + BuildHandlerArguments, + BuildHandlerOutput, + BuildMiddleware, + MetadataBearer, +} from "@smithy/types"; + +/** + * No-op middleware for runtimes outside of Node.js + * @internal + */ +export const recursionDetectionMiddleware = + (): BuildMiddleware => + (next: BuildHandler): BuildHandler => + async (args: BuildHandlerArguments): Promise> => + next(args); diff --git a/packages/middleware-recursion-detection/src/index.spec.ts b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.spec.ts similarity index 78% rename from packages/middleware-recursion-detection/src/index.spec.ts rename to packages/middleware-recursion-detection/src/recursionDetectionMiddleware.spec.ts index 8f5a81c2776d..2707368499a7 100644 --- a/packages/middleware-recursion-detection/src/index.spec.ts +++ b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.spec.ts @@ -1,7 +1,7 @@ import { HttpRequest } from "@smithy/protocol-http"; import { afterAll, beforeEach, describe, expect, test as it, vi } from "vitest"; -import { recursionDetectionMiddleware } from "./index"; +import { recursionDetectionMiddleware } from "./recursionDetectionMiddleware"; describe(recursionDetectionMiddleware.name, () => { const mockNextHandler = vi.fn(); @@ -22,7 +22,7 @@ describe(recursionDetectionMiddleware.name, () => { AWS_LAMBDA_FUNCTION_NAME: "some-function", _X_AMZN_TRACE_ID: "some-trace-id", }; - const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any); + const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any); await handler({ input: {}, request: new HttpRequest({}), @@ -37,7 +37,7 @@ describe(recursionDetectionMiddleware.name, () => { process.env = { _X_AMZN_TRACE_ID: "some-trace-id", }; - const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any); + const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any); await handler({ input: {}, request: new HttpRequest({}), @@ -54,7 +54,7 @@ describe(recursionDetectionMiddleware.name, () => { AWS_LAMBDA_FUNCTION_NAME: "some-function", _X_AMZN_TRACE_ID: "some-trace-id", }; - const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any); + const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any); await handler({ input: {}, request: new HttpRequest({ @@ -75,7 +75,7 @@ describe(recursionDetectionMiddleware.name, () => { AWS_LAMBDA_FUNCTION_NAME: "some-function", _X_AMZN_TRACE_ID: "some-trace-id", }; - const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any); + const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any); await handler({ input: {}, request: new HttpRequest({ @@ -100,7 +100,7 @@ describe(recursionDetectionMiddleware.name, () => { AWS_LAMBDA_FUNCTION_NAME: "some-function", _X_AMZN_TRACE_ID: "some-trace-id", }; - const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any); + const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any); await handler({ input: {}, request: new HttpRequest({ @@ -125,7 +125,7 @@ describe(recursionDetectionMiddleware.name, () => { AWS_LAMBDA_FUNCTION_NAME: "some-function", _X_AMZN_TRACE_ID: "some-trace-id", }; - const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any); + const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any); await handler({ input: {}, request: new HttpRequest({ @@ -144,21 +144,4 @@ describe(recursionDetectionMiddleware.name, () => { expect(existingTraceHeader).toBeDefined(); expect(request.headers[existingTraceHeader!]).toBe("some-real-trace-id"); }); - - it("has no effect for browser runtime", async () => { - process.env = { - AWS_LAMBDA_FUNCTION_NAME: "some-function", - _X_AMZN_TRACE_ID: "some-trace-id", - }; - const handler = recursionDetectionMiddleware({ runtime: "browser" })(mockNextHandler, {} as any); - await handler({ - input: {}, - request: new HttpRequest({}), - }); - - const { calls } = (mockNextHandler as any).mock; - expect(calls.length).toBe(1); - const { request } = mockNextHandler.mock.calls[0][0]; - expect(request.headers[TRACE_ID_HEADER_NAME]).toBeUndefined(); - }); }); diff --git a/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.ts b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.ts new file mode 100644 index 000000000000..3b8f196889cf --- /dev/null +++ b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.ts @@ -0,0 +1,43 @@ +import { HttpRequest } from "@smithy/protocol-http"; +import { + BuildHandler, + BuildHandlerArguments, + BuildHandlerOutput, + BuildMiddleware, + MetadataBearer, +} from "@smithy/types"; + +const TRACE_ID_HEADER_NAME = "X-Amzn-Trace-Id"; +const ENV_LAMBDA_FUNCTION_NAME = "AWS_LAMBDA_FUNCTION_NAME"; +const ENV_TRACE_ID = "_X_AMZN_TRACE_ID"; + +/** + * Inject to trace ID to request header to detect recursion invocation in Lambda. + * @internal + */ +export const recursionDetectionMiddleware = + (): BuildMiddleware => + (next: BuildHandler): BuildHandler => + async (args: BuildHandlerArguments): Promise> => { + const { request } = args; + if (!HttpRequest.isInstance(request)) { + return next(args); + } + const traceIdHeader = + Object.keys(request.headers ?? {}).find((h) => h.toLowerCase() === TRACE_ID_HEADER_NAME.toLowerCase()) ?? + TRACE_ID_HEADER_NAME; + + if (request.headers.hasOwnProperty(traceIdHeader)) { + return next(args); + } + const functionName = process.env[ENV_LAMBDA_FUNCTION_NAME]; + const traceId = process.env[ENV_TRACE_ID]; + const nonEmptyString = (str: unknown): str is string => typeof str === "string" && str.length > 0; + if (nonEmptyString(functionName) && nonEmptyString(traceId)) { + request.headers[TRACE_ID_HEADER_NAME] = traceId; + } + return next({ + ...args, + request, + }); + };