Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/middleware-recursion-detection/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"license": "Apache-2.0",
"dependencies": {
"@aws-sdk/types": "*",
"@aws/lambda-invoke-store": "^0.0.1",
"@smithy/protocol-http": "^5.1.3",
"@smithy/types": "^4.3.2",
"tslib": "^2.6.2"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { InvokeStore } from "@aws/lambda-invoke-store";
import { HttpRequest } from "@smithy/protocol-http";
import { afterAll, beforeEach, describe, expect, test as it, vi } from "vitest";

Expand All @@ -6,31 +7,71 @@ import { recursionDetectionMiddleware } from "./recursionDetectionMiddleware";
describe(recursionDetectionMiddleware.name, () => {
const mockNextHandler = vi.fn();
const originEnv = process.env;

const TRACE_ID_HEADER_NAME = "X-Amzn-Trace-Id";

beforeEach(() => {
vi.clearAllMocks();
vi.spyOn(InvokeStore, "getXRayTraceId").mockImplementation(() => undefined);
process.env = {};
});

afterAll(() => {
process.env = originEnv;
});

it(`sets ${TRACE_ID_HEADER_NAME} header when function name and trace id environmental variables are set`, async () => {
process.env = {
AWS_LAMBDA_FUNCTION_NAME: "some-function",
_X_AMZN_TRACE_ID: "some-trace-id",
};
const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any);
await handler({
input: {},
request: new HttpRequest({}),
describe(`sets ${TRACE_ID_HEADER_NAME} header when function name and`, () => {
const mockTraceIdEnv = "trace-id-from-env";
const mockTraceIdInvokeStore = "trace-id-from-invoke-store";

it("trace id environmental variables is set", async () => {
process.env = {
AWS_LAMBDA_FUNCTION_NAME: "some-function",
_X_AMZN_TRACE_ID: mockTraceIdEnv,
};
const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any);
await handler({
input: {},
request: new HttpRequest({}),
});
const { calls } = (mockNextHandler as any).mock;
expect(calls.length).toBe(1);
const { request } = mockNextHandler.mock.calls[0][0];
expect(request.headers[TRACE_ID_HEADER_NAME]).toBe(mockTraceIdEnv);
});

it("trace id value is set in InvokeStore", async () => {
vi.spyOn(InvokeStore, "getXRayTraceId").mockImplementation(() => mockTraceIdInvokeStore);
process.env = {
AWS_LAMBDA_FUNCTION_NAME: "some-function",
};
const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any);
await handler({
input: {},
request: new HttpRequest({}),
});
const { calls } = (mockNextHandler as any).mock;
expect(calls.length).toBe(1);
const { request } = mockNextHandler.mock.calls[0][0];
expect(request.headers[TRACE_ID_HEADER_NAME]).toBe(mockTraceIdInvokeStore);
});

it("favors trace id value from InvokeStore over that from env variable", async () => {
vi.spyOn(InvokeStore, "getXRayTraceId").mockImplementation(() => mockTraceIdInvokeStore);
process.env = {
AWS_LAMBDA_FUNCTION_NAME: "some-function",
_X_AMZN_TRACE_ID: mockTraceIdEnv,
};
const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any);
await handler({
input: {},
request: new HttpRequest({}),
});
const { calls } = (mockNextHandler as any).mock;
expect(calls.length).toBe(1);
const { request } = mockNextHandler.mock.calls[0][0];
expect(request.headers[TRACE_ID_HEADER_NAME]).toBe(mockTraceIdInvokeStore);
});
const { calls } = (mockNextHandler as any).mock;
expect(calls.length).toBe(1);
const { request } = mockNextHandler.mock.calls[0][0];
expect(request.headers[TRACE_ID_HEADER_NAME]).toBe("some-trace-id");
});

it(`should NOT set ${TRACE_ID_HEADER_NAME} header when function name environmental variable is NOT set`, async () => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { InvokeStore } from "@aws/lambda-invoke-store";
import { HttpRequest } from "@smithy/protocol-http";
import {
BuildHandler,
Expand Down Expand Up @@ -31,7 +32,11 @@ export const recursionDetectionMiddleware =
return next(args);
}
const functionName = process.env[ENV_LAMBDA_FUNCTION_NAME];
const traceId = process.env[ENV_TRACE_ID];

const traceIdFromEnv = process.env[ENV_TRACE_ID];
const traceIdFromInvokeStore = InvokeStore.getXRayTraceId();
const traceId = traceIdFromInvokeStore ?? traceIdFromEnv;

const nonEmptyString = (str: unknown): str is string => typeof str === "string" && str.length > 0;
if (nonEmptyString(functionName) && nonEmptyString(traceId)) {
request.headers[TRACE_ID_HEADER_NAME] = traceId;
Expand Down
8 changes: 8 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -24096,6 +24096,7 @@ __metadata:
resolution: "@aws-sdk/middleware-recursion-detection@workspace:packages/middleware-recursion-detection"
dependencies:
"@aws-sdk/types": "npm:*"
"@aws/lambda-invoke-store": "npm:^0.0.1"
"@smithy/protocol-http": "npm:^5.1.3"
"@smithy/types": "npm:^4.3.2"
"@tsconfig/recommended": "npm:1.0.1"
Expand Down Expand Up @@ -25022,6 +25023,13 @@ __metadata:
languageName: node
linkType: hard

"@aws/lambda-invoke-store@npm:^0.0.1":
version: 0.0.1
resolution: "@aws/lambda-invoke-store@npm:0.0.1"
checksum: 10c0/0bbf3060014a462177fb743e132e9b106a6743ad9cd905df4bd26e9ca8bfe2cc90473b03a79938fa908934e45e43f366f57af56a697991abda71d9ac92f5018f
languageName: node
linkType: hard

"@babel/code-frame@npm:^7.0.0, @babel/code-frame@npm:^7.12.13, @babel/code-frame@npm:^7.25.9, @babel/code-frame@npm:^7.26.0, @babel/code-frame@npm:^7.26.2":
version: 7.26.2
resolution: "@babel/code-frame@npm:7.26.2"
Expand Down
Loading