diff --git a/packages/middleware-recursion-detection/src/index.spec.ts b/packages/middleware-recursion-detection/src/index.spec.ts index 5b2452816021..8f5a81c2776d 100644 --- a/packages/middleware-recursion-detection/src/index.spec.ts +++ b/packages/middleware-recursion-detection/src/index.spec.ts @@ -70,6 +70,81 @@ describe(recursionDetectionMiddleware.name, () => { expect(request.headers[TRACE_ID_HEADER_NAME]).toBe("some-real-trace-id"); }); + it(`should NOT set ${TRACE_ID_HEADER_NAME} header when the header is already set with some other casing`, async () => { + process.env = { + AWS_LAMBDA_FUNCTION_NAME: "some-function", + _X_AMZN_TRACE_ID: "some-trace-id", + }; + const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any); + await handler({ + input: {}, + request: new HttpRequest({ + headers: { + ["x-AmZn-TrAcE-iD"]: "some-real-trace-id", + }, + }), + }); + + const { calls } = (mockNextHandler as any).mock; + expect(calls.length).toBe(1); + const { request } = mockNextHandler.mock.calls[0][0]; + const existingTraceHeader = Object.keys(request.headers).find( + (h) => h.toLowerCase() === TRACE_ID_HEADER_NAME.toLowerCase() + ); + expect(existingTraceHeader).toBeDefined(); + expect(request.headers[existingTraceHeader!]).toBe("some-real-trace-id"); + }); + + it(`should NOT set ${TRACE_ID_HEADER_NAME} header when the header is already set with alternating case`, async () => { + process.env = { + AWS_LAMBDA_FUNCTION_NAME: "some-function", + _X_AMZN_TRACE_ID: "some-trace-id", + }; + const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any); + await handler({ + input: {}, + request: new HttpRequest({ + headers: { + "X-aMzN-tRaCe-Id": "some-real-trace-id", + }, + }), + }); + + const { calls } = (mockNextHandler as any).mock; + expect(calls.length).toBe(1); + const { request } = mockNextHandler.mock.calls[0][0]; + const existingTraceHeader = Object.keys(request.headers).find( + (h) => h.toLowerCase() === TRACE_ID_HEADER_NAME.toLowerCase() + ); + expect(existingTraceHeader).toBeDefined(); + expect(request.headers[existingTraceHeader!]).toBe("some-real-trace-id"); + }); + + it(`should NOT set ${TRACE_ID_HEADER_NAME} header when the header is already set with all uppercase`, async () => { + process.env = { + AWS_LAMBDA_FUNCTION_NAME: "some-function", + _X_AMZN_TRACE_ID: "some-trace-id", + }; + const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any); + await handler({ + input: {}, + request: new HttpRequest({ + headers: { + "X-AMZN-TRACE-ID": "some-real-trace-id", + }, + }), + }); + + const { calls } = (mockNextHandler as any).mock; + expect(calls.length).toBe(1); + const { request } = mockNextHandler.mock.calls[0][0]; + const existingTraceHeader = Object.keys(request.headers).find( + (h) => h.toLowerCase() === TRACE_ID_HEADER_NAME.toLowerCase() + ); + expect(existingTraceHeader).toBeDefined(); + expect(request.headers[existingTraceHeader!]).toBe("some-real-trace-id"); + }); + it("has no effect for browser runtime", async () => { process.env = { AWS_LAMBDA_FUNCTION_NAME: "some-function", diff --git a/packages/middleware-recursion-detection/src/index.ts b/packages/middleware-recursion-detection/src/index.ts index 42fb46ecfdc9..9030c34443c8 100644 --- a/packages/middleware-recursion-detection/src/index.ts +++ b/packages/middleware-recursion-detection/src/index.ts @@ -27,14 +27,16 @@ export const recursionDetectionMiddleware = (next: BuildHandler): BuildHandler => async (args: BuildHandlerArguments): Promise> => { const { request } = args; - if ( - !HttpRequest.isInstance(request) || - options.runtime !== "node" || - request.headers.hasOwnProperty(TRACE_ID_HEADER_NAME) - ) { + if (!HttpRequest.isInstance(request) || options.runtime !== "node") { return next(args); } + const traceIdHeader = + Object.keys(request.headers ?? {}).find((h) => h.toLowerCase() === TRACE_ID_HEADER_NAME.toLowerCase()) ?? + TRACE_ID_HEADER_NAME; + if (request.headers.hasOwnProperty(traceIdHeader)) { + return next(args); + } const functionName = process.env[ENV_LAMBDA_FUNCTION_NAME]; const traceId = process.env[ENV_TRACE_ID]; const nonEmptyString = (str: unknown): str is string => typeof str === "string" && str.length > 0;