Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/middleware-recursion-detection/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,12 @@
"downlevel-dts": "0.10.1",
"rimraf": "3.0.2",
"typescript": "~5.8.3"
},
"browser": {
"./dist-es/recursionDetectionMiddleware": "./dist-es/recursionDetectionMiddleware.no-op"
},
"react-native": {
"./dist-es/recursionDetectionMiddleware": "./dist-es/recursionDetectionMiddleware.no-op",
"./dist-cjs/recursionDetectionMiddleware": "./dist-cjs/recursionDetectionMiddleware.no-op"
}
}
12 changes: 12 additions & 0 deletions packages/middleware-recursion-detection/src/configuration.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { AbsoluteLocation, BuildHandlerOptions } from "@smithy/types";

/**
* @internal
*/
export const recursionDetectionMiddlewareOptions: BuildHandlerOptions & AbsoluteLocation = {
step: "build",
tags: ["RECURSION_DETECTION"],
name: "recursionDetectionMiddleware",
override: true,
priority: "low",
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { Pluggable } from "@smithy/types";

import { recursionDetectionMiddlewareOptions } from "./configuration";
import { recursionDetectionMiddleware } from "./recursionDetectionMiddleware";

// @internal
/**
* @internal
*/

export const getRecursionDetectionPlugin = (): Pluggable<any, any> => ({
applyToStack: (clientStack) => {
clientStack.add(recursionDetectionMiddleware(), recursionDetectionMiddlewareOptions);
},
});
74 changes: 2 additions & 72 deletions packages/middleware-recursion-detection/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,72 +1,2 @@
import { HttpRequest } from "@smithy/protocol-http";
import {
AbsoluteLocation,
BuildHandler,
BuildHandlerArguments,
BuildHandlerOptions,
BuildHandlerOutput,
BuildMiddleware,
MetadataBearer,
Pluggable,
} from "@smithy/types";

const TRACE_ID_HEADER_NAME = "X-Amzn-Trace-Id";
const ENV_LAMBDA_FUNCTION_NAME = "AWS_LAMBDA_FUNCTION_NAME";
const ENV_TRACE_ID = "_X_AMZN_TRACE_ID";

interface PreviouslyResolved {
runtime: string;
}

/**
* Inject to trace ID to request header to detect recursion invocation in Lambda.
* @internal
*/
export const recursionDetectionMiddleware =
(options: PreviouslyResolved): BuildMiddleware<any, any> =>
<Output extends MetadataBearer>(next: BuildHandler<any, Output>): BuildHandler<any, Output> =>
async (args: BuildHandlerArguments<any>): Promise<BuildHandlerOutput<Output>> => {
const { request } = args;
if (!HttpRequest.isInstance(request) || options.runtime !== "node") {
return next(args);
}
const traceIdHeader =
Object.keys(request.headers ?? {}).find((h) => h.toLowerCase() === TRACE_ID_HEADER_NAME.toLowerCase()) ??
TRACE_ID_HEADER_NAME;

if (request.headers.hasOwnProperty(traceIdHeader)) {
return next(args);
}
const functionName = process.env[ENV_LAMBDA_FUNCTION_NAME];
const traceId = process.env[ENV_TRACE_ID];
const nonEmptyString = (str: unknown): str is string => typeof str === "string" && str.length > 0;
if (nonEmptyString(functionName) && nonEmptyString(traceId)) {
request.headers[TRACE_ID_HEADER_NAME] = traceId;
}
return next({
...args,
request,
});
};

// @internal
/**
* @internal
*/
export const addRecursionDetectionMiddlewareOptions: BuildHandlerOptions & AbsoluteLocation = {
step: "build",
tags: ["RECURSION_DETECTION"],
name: "recursionDetectionMiddleware",
override: true,
priority: "low",
};

// @internal
/**
* @internal
*/
export const getRecursionDetectionPlugin = (options: PreviouslyResolved): Pluggable<any, any> => ({
applyToStack: (clientStack) => {
clientStack.add(recursionDetectionMiddleware(options), addRecursionDetectionMiddlewareOptions);
},
});
export * from "./getRecursionDetectionPlugin";
export * from "./recursionDetectionMiddleware";
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import {
BuildHandler,
BuildHandlerArguments,
BuildHandlerOutput,
BuildMiddleware,
MetadataBearer,
} from "@smithy/types";

/**
* No-op middleware for runtimes outside of Node.js
* @internal
*/
export const recursionDetectionMiddleware =
(): BuildMiddleware<any, any> =>
<Output extends MetadataBearer>(next: BuildHandler<any, Output>): BuildHandler<any, Output> =>
async (args: BuildHandlerArguments<any>): Promise<BuildHandlerOutput<Output>> =>
next(args);
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { HttpRequest } from "@smithy/protocol-http";
import { afterAll, beforeEach, describe, expect, test as it, vi } from "vitest";

import { recursionDetectionMiddleware } from "./index";
import { recursionDetectionMiddleware } from "./recursionDetectionMiddleware";

describe(recursionDetectionMiddleware.name, () => {
const mockNextHandler = vi.fn();
Expand All @@ -22,7 +22,7 @@ describe(recursionDetectionMiddleware.name, () => {
AWS_LAMBDA_FUNCTION_NAME: "some-function",
_X_AMZN_TRACE_ID: "some-trace-id",
};
const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any);
const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any);
await handler({
input: {},
request: new HttpRequest({}),
Expand All @@ -37,7 +37,7 @@ describe(recursionDetectionMiddleware.name, () => {
process.env = {
_X_AMZN_TRACE_ID: "some-trace-id",
};
const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any);
const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any);
await handler({
input: {},
request: new HttpRequest({}),
Expand All @@ -54,7 +54,7 @@ describe(recursionDetectionMiddleware.name, () => {
AWS_LAMBDA_FUNCTION_NAME: "some-function",
_X_AMZN_TRACE_ID: "some-trace-id",
};
const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any);
const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any);
await handler({
input: {},
request: new HttpRequest({
Expand All @@ -75,7 +75,7 @@ describe(recursionDetectionMiddleware.name, () => {
AWS_LAMBDA_FUNCTION_NAME: "some-function",
_X_AMZN_TRACE_ID: "some-trace-id",
};
const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any);
const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any);
await handler({
input: {},
request: new HttpRequest({
Expand All @@ -100,7 +100,7 @@ describe(recursionDetectionMiddleware.name, () => {
AWS_LAMBDA_FUNCTION_NAME: "some-function",
_X_AMZN_TRACE_ID: "some-trace-id",
};
const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any);
const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any);
await handler({
input: {},
request: new HttpRequest({
Expand All @@ -125,7 +125,7 @@ describe(recursionDetectionMiddleware.name, () => {
AWS_LAMBDA_FUNCTION_NAME: "some-function",
_X_AMZN_TRACE_ID: "some-trace-id",
};
const handler = recursionDetectionMiddleware({ runtime: "node" })(mockNextHandler, {} as any);
const handler = recursionDetectionMiddleware()(mockNextHandler, {} as any);
await handler({
input: {},
request: new HttpRequest({
Expand All @@ -144,21 +144,4 @@ describe(recursionDetectionMiddleware.name, () => {
expect(existingTraceHeader).toBeDefined();
expect(request.headers[existingTraceHeader!]).toBe("some-real-trace-id");
});

it("has no effect for browser runtime", async () => {
process.env = {
AWS_LAMBDA_FUNCTION_NAME: "some-function",
_X_AMZN_TRACE_ID: "some-trace-id",
};
const handler = recursionDetectionMiddleware({ runtime: "browser" })(mockNextHandler, {} as any);
await handler({
input: {},
request: new HttpRequest({}),
});

const { calls } = (mockNextHandler as any).mock;
expect(calls.length).toBe(1);
const { request } = mockNextHandler.mock.calls[0][0];
expect(request.headers[TRACE_ID_HEADER_NAME]).toBeUndefined();
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { HttpRequest } from "@smithy/protocol-http";
import {
BuildHandler,
BuildHandlerArguments,
BuildHandlerOutput,
BuildMiddleware,
MetadataBearer,
} from "@smithy/types";

const TRACE_ID_HEADER_NAME = "X-Amzn-Trace-Id";
const ENV_LAMBDA_FUNCTION_NAME = "AWS_LAMBDA_FUNCTION_NAME";
const ENV_TRACE_ID = "_X_AMZN_TRACE_ID";

/**
* Inject to trace ID to request header to detect recursion invocation in Lambda.
* @internal
*/
export const recursionDetectionMiddleware =
(): BuildMiddleware<any, any> =>
<Output extends MetadataBearer>(next: BuildHandler<any, Output>): BuildHandler<any, Output> =>
async (args: BuildHandlerArguments<any>): Promise<BuildHandlerOutput<Output>> => {
const { request } = args;
if (!HttpRequest.isInstance(request)) {
return next(args);
}
const traceIdHeader =
Object.keys(request.headers ?? {}).find((h) => h.toLowerCase() === TRACE_ID_HEADER_NAME.toLowerCase()) ??
TRACE_ID_HEADER_NAME;

if (request.headers.hasOwnProperty(traceIdHeader)) {
return next(args);
}
const functionName = process.env[ENV_LAMBDA_FUNCTION_NAME];
const traceId = process.env[ENV_TRACE_ID];
const nonEmptyString = (str: unknown): str is string => typeof str === "string" && str.length > 0;
if (nonEmptyString(functionName) && nonEmptyString(traceId)) {
request.headers[TRACE_ID_HEADER_NAME] = traceId;
}
return next({
...args,
request,
});
};
Loading