diff --git a/packages/middleware-recursion-detection/package.json b/packages/middleware-recursion-detection/package.json index af3c765c24f6..0d911040fdfd 100644 --- a/packages/middleware-recursion-detection/package.json +++ b/packages/middleware-recursion-detection/package.json @@ -24,6 +24,7 @@ "license": "Apache-2.0", "dependencies": { "@aws-sdk/types": "*", + "@aws/lambda-invoke-store": "^0.0.1", "@smithy/protocol-http": "^5.1.3", "@smithy/types": "^4.3.2", "tslib": "^2.6.2" diff --git a/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.spec.ts b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.spec.ts index 2707368499a7..7a06d8cf5f21 100644 --- a/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.spec.ts +++ b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.spec.ts @@ -1,3 +1,4 @@ +import { InvokeStore } from "@aws/lambda-invoke-store"; import { HttpRequest } from "@smithy/protocol-http"; import { afterAll, beforeEach, describe, expect, test as it, vi } from "vitest"; @@ -6,10 +7,12 @@ import { recursionDetectionMiddleware } from "./recursionDetectionMiddleware"; describe(recursionDetectionMiddleware.name, () => { const mockNextHandler = vi.fn(); const originEnv = process.env; + const TRACE_ID_HEADER_NAME = "X-Amzn-Trace-Id"; beforeEach(() => { vi.clearAllMocks(); + vi.spyOn(InvokeStore, "getXRayTraceId").mockImplementation(() => undefined); process.env = {}; }); @@ -17,20 +20,58 @@ describe(recursionDetectionMiddleware.name, () => { process.env = originEnv; }); - it(`sets ${TRACE_ID_HEADER_NAME} header when function name and trace id environmental variables are set`, async () => { - process.env = { - AWS_LAMBDA_FUNCTION_NAME: "some-function", - _X_AMZN_TRACE_ID: "some-trace-id", - }; - const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any); - await handler({ - input: {}, - request: new HttpRequest({}), + describe(`sets ${TRACE_ID_HEADER_NAME} header when function name and`, () => { + const mockTraceIdEnv = "trace-id-from-env"; + const mockTraceIdInvokeStore = "trace-id-from-invoke-store"; + + it("trace id environmental variables is set", async () => { + process.env = { + AWS_LAMBDA_FUNCTION_NAME: "some-function", + _X_AMZN_TRACE_ID: mockTraceIdEnv, + }; + const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any); + await handler({ + input: {}, + request: new HttpRequest({}), + }); + const { calls } = (mockNextHandler as any).mock; + expect(calls.length).toBe(1); + const { request } = mockNextHandler.mock.calls[0][0]; + expect(request.headers[TRACE_ID_HEADER_NAME]).toBe(mockTraceIdEnv); + }); + + it("trace id value is set in InvokeStore", async () => { + vi.spyOn(InvokeStore, "getXRayTraceId").mockImplementation(() => mockTraceIdInvokeStore); + process.env = { + AWS_LAMBDA_FUNCTION_NAME: "some-function", + }; + const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any); + await handler({ + input: {}, + request: new HttpRequest({}), + }); + const { calls } = (mockNextHandler as any).mock; + expect(calls.length).toBe(1); + const { request } = mockNextHandler.mock.calls[0][0]; + expect(request.headers[TRACE_ID_HEADER_NAME]).toBe(mockTraceIdInvokeStore); + }); + + it("favors trace id value from InvokeStore over that from env variable", async () => { + vi.spyOn(InvokeStore, "getXRayTraceId").mockImplementation(() => mockTraceIdInvokeStore); + process.env = { + AWS_LAMBDA_FUNCTION_NAME: "some-function", + _X_AMZN_TRACE_ID: mockTraceIdEnv, + }; + const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any); + await handler({ + input: {}, + request: new HttpRequest({}), + }); + const { calls } = (mockNextHandler as any).mock; + expect(calls.length).toBe(1); + const { request } = mockNextHandler.mock.calls[0][0]; + expect(request.headers[TRACE_ID_HEADER_NAME]).toBe(mockTraceIdInvokeStore); }); - const { calls } = (mockNextHandler as any).mock; - expect(calls.length).toBe(1); - const { request } = mockNextHandler.mock.calls[0][0]; - expect(request.headers[TRACE_ID_HEADER_NAME]).toBe("some-trace-id"); }); it(`should NOT set ${TRACE_ID_HEADER_NAME} header when function name environmental variable is NOT set`, async () => { diff --git a/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.ts b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.ts index 3b8f196889cf..21355280466d 100644 --- a/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.ts +++ b/packages/middleware-recursion-detection/src/recursionDetectionMiddleware.ts @@ -1,3 +1,4 @@ +import { InvokeStore } from "@aws/lambda-invoke-store"; import { HttpRequest } from "@smithy/protocol-http"; import { BuildHandler, @@ -31,7 +32,11 @@ export const recursionDetectionMiddleware = return next(args); } const functionName = process.env[ENV_LAMBDA_FUNCTION_NAME]; - const traceId = process.env[ENV_TRACE_ID]; + + const traceIdFromEnv = process.env[ENV_TRACE_ID]; + const traceIdFromInvokeStore = InvokeStore.getXRayTraceId(); + const traceId = traceIdFromInvokeStore ?? traceIdFromEnv; + const nonEmptyString = (str: unknown): str is string => typeof str === "string" && str.length > 0; if (nonEmptyString(functionName) && nonEmptyString(traceId)) { request.headers[TRACE_ID_HEADER_NAME] = traceId; diff --git a/yarn.lock b/yarn.lock index 8ee2d1996cd1..b2e3bf3514c8 100644 --- a/yarn.lock +++ b/yarn.lock @@ -24096,6 +24096,7 @@ __metadata: resolution: "@aws-sdk/middleware-recursion-detection@workspace:packages/middleware-recursion-detection" dependencies: "@aws-sdk/types": "npm:*" + "@aws/lambda-invoke-store": "npm:^0.0.1" "@smithy/protocol-http": "npm:^5.1.3" "@smithy/types": "npm:^4.3.2" "@tsconfig/recommended": "npm:1.0.1" @@ -25022,6 +25023,13 @@ __metadata: languageName: node linkType: hard +"@aws/lambda-invoke-store@npm:^0.0.1": + version: 0.0.1 + resolution: "@aws/lambda-invoke-store@npm:0.0.1" + checksum: 10c0/0bbf3060014a462177fb743e132e9b106a6743ad9cd905df4bd26e9ca8bfe2cc90473b03a79938fa908934e45e43f366f57af56a697991abda71d9ac92f5018f + languageName: node + linkType: hard + "@babel/code-frame@npm:^7.0.0, @babel/code-frame@npm:^7.12.13, @babel/code-frame@npm:^7.25.9, @babel/code-frame@npm:^7.26.0, @babel/code-frame@npm:^7.26.2": version: 7.26.2 resolution: "@babel/code-frame@npm:7.26.2"