diff --git a/docs/concepts/middleware.mdx b/docs/concepts/middleware.mdx new file mode 100644 index 000000000..835d10566 --- /dev/null +++ b/docs/concepts/middleware.mdx @@ -0,0 +1,307 @@ +--- +title: "Middleware" +description: "Transform and intercept events in AG-UI agents" +--- + +# Middleware + +Middleware in AG-UI provides a powerful way to transform, filter, and augment the event streams that flow through agents. It enables you to add cross-cutting concerns like logging, authentication, rate limiting, and event filtering without modifying the core agent logic. + +## What is Middleware? + +Middleware sits between the agent execution and the event consumer, allowing you to: + +1. **Transform events** – Modify or enhance events as they flow through the pipeline +2. **Filter events** – Selectively allow or block certain events +3. **Add metadata** – Inject additional context or tracking information +4. **Handle errors** – Implement custom error recovery strategies +5. **Monitor execution** – Add logging, metrics, or debugging capabilities + +## How Middleware Works + +Middleware forms a chain where each middleware wraps the next, creating layers of functionality. When an agent runs, the event stream flows through each middleware in sequence. + +```typescript +import { AbstractAgent } from "@ag-ui/client" + +const agent = new MyAgent() + +// Middleware chain: logging -> auth -> filter -> agent +agent.use(loggingMiddleware, authMiddleware, filterMiddleware) + +// When agent runs, events flow through all middleware +await agent.runAgent() +``` + +## Function-Based Middleware + +For simple transformations, you can use function-based middleware. This is the most concise way to add middleware: + +```typescript +import { MiddlewareFunction } from "@ag-ui/client" +import { EventType } from "@ag-ui/core" + +const prefixMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + map(event => { + if (event.type === EventType.TEXT_MESSAGE_CHUNK) { + return { + ...event, + delta: `[AI]: ${event.delta}` + } + } + return event + }) + ) +} + +agent.use(prefixMiddleware) +``` + +## Class-Based Middleware + +For more complex scenarios requiring state or configuration, use class-based middleware: + +```typescript +import { Middleware } from "@ag-ui/client" +import { Observable } from "rxjs" +import { tap } from "rxjs/operators" + +class MetricsMiddleware extends Middleware { + private eventCount = 0 + + constructor(private metricsService: MetricsService) { + super() + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + const startTime = Date.now() + + return next.run(input).pipe( + tap(event => { + this.eventCount++ + this.metricsService.recordEvent(event.type) + }), + finalize(() => { + const duration = Date.now() - startTime + this.metricsService.recordDuration(duration) + this.metricsService.recordEventCount(this.eventCount) + }) + ) + } +} + +agent.use(new MetricsMiddleware(metricsService)) +``` + +## Built-in Middleware + +AG-UI provides several built-in middleware components for common use cases: + +### FilterToolCallsMiddleware + +Filter tool calls based on allowed or disallowed lists: + +```typescript +import { FilterToolCallsMiddleware } from "@ag-ui/client" + +// Only allow specific tools +const allowedFilter = new FilterToolCallsMiddleware({ + allowedToolCalls: ["search", "calculate"] +}) + +// Or block specific tools +const blockedFilter = new FilterToolCallsMiddleware({ + disallowedToolCalls: ["delete", "modify"] +}) + +agent.use(allowedFilter) +``` + +## Middleware Patterns + +### Logging Middleware + +```typescript +const loggingMiddleware: MiddlewareFunction = (input, next) => { + console.log("Request:", input.messages) + + return next.run(input).pipe( + tap(event => console.log("Event:", event.type)), + catchError(error => { + console.error("Error:", error) + throw error + }) + ) +} +``` + +### Authentication Middleware + +```typescript +class AuthMiddleware extends Middleware { + constructor(private apiKey: string) { + super() + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + // Add authentication to the context + const authenticatedInput = { + ...input, + context: [ + ...input.context, + { type: "auth", apiKey: this.apiKey } + ] + } + + return next.run(authenticatedInput) + } +} +``` + +### Rate Limiting Middleware + +```typescript +class RateLimitMiddleware extends Middleware { + private lastCall = 0 + + constructor(private minInterval: number) { + super() + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + const now = Date.now() + const timeSinceLastCall = now - this.lastCall + + if (timeSinceLastCall < this.minInterval) { + const delay = this.minInterval - timeSinceLastCall + return timer(delay).pipe( + switchMap(() => { + this.lastCall = Date.now() + return next.run(input) + }) + ) + } + + this.lastCall = now + return next.run(input) + } +} +``` + +## Combining Middleware + +You can combine multiple middleware to create sophisticated processing pipelines: + +```typescript +// Function middleware for simple logging +const logMiddleware: MiddlewareFunction = (input, next) => { + console.log(`Starting run ${input.runId}`) + return next.run(input) +} + +// Class middleware for authentication +const authMiddleware = new AuthMiddleware(apiKey) + +// Built-in middleware for filtering +const filterMiddleware = new FilterToolCallsMiddleware({ + allowedToolCalls: ["search", "summarize"] +}) + +// Apply all middleware in order +agent.use( + logMiddleware, // First: log the request + authMiddleware, // Second: add authentication + filterMiddleware // Third: filter tool calls +) +``` + +## Execution Order + +Middleware executes in the order it's added, with each middleware wrapping the next: + +1. First middleware receives the original input +2. It can modify the input before passing to the next middleware +3. Each middleware processes events from the next in the chain +4. The final middleware calls the actual agent + +```typescript +agent.use(middleware1, middleware2, middleware3) + +// Execution flow: +// → middleware1 +// → middleware2 +// → middleware3 +// → agent.run() +// ← events flow back through middleware3 +// ← events flow back through middleware2 +// ← events flow back through middleware1 +``` + +## Best Practices + +1. **Keep middleware focused** – Each middleware should have a single responsibility +2. **Handle errors gracefully** – Use RxJS error handling operators +3. **Avoid blocking operations** – Use async patterns for I/O operations +4. **Document side effects** – Clearly indicate if middleware modifies state +5. **Test middleware independently** – Write unit tests for each middleware +6. **Consider performance** – Be mindful of processing overhead in the event stream + +## Advanced Use Cases + +### Conditional Middleware + +Apply middleware based on runtime conditions: + +```typescript +const conditionalMiddleware: MiddlewareFunction = (input, next) => { + if (input.context.some(c => c.type === "debug")) { + // Apply debug logging + return next.run(input).pipe( + tap(event => console.debug(event)) + ) + } + return next.run(input) +} +``` + +### Event Transformation + +Transform specific event types: + +```typescript +const transformMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + map(event => { + if (event.type === EventType.TOOL_CALL_START) { + // Add timestamp to tool calls + return { + ...event, + metadata: { + ...event.metadata, + timestamp: Date.now() + } + } + } + return event + }) + ) +} +``` + +### Stream Control + +Control the flow of events: + +```typescript +const throttleMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + // Throttle text message chunks to prevent overwhelming the UI + throttleTime(50, undefined, { leading: true, trailing: true }) + ) +} +``` + +## Conclusion + +Middleware provides a flexible and powerful way to extend AG-UI agents without modifying their core logic. Whether you need simple event transformation or complex stateful processing, the middleware system offers the tools to build robust, maintainable agent applications. \ No newline at end of file diff --git a/docs/docs.json b/docs/docs.json index 664f672bd..f63ed72b7 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -40,6 +40,7 @@ "concepts/architecture", "concepts/events", "concepts/agents", + "concepts/middleware", "concepts/messages", "concepts/state", "concepts/tools" @@ -85,6 +86,7 @@ "sdk/js/client/overview", "sdk/js/client/abstract-agent", "sdk/js/client/http-agent", + "sdk/js/client/middleware", "sdk/js/client/subscriber" ] }, diff --git a/docs/sdk/js/client/abstract-agent.mdx b/docs/sdk/js/client/abstract-agent.mdx index 3738b0aae..7c6acafb5 100644 --- a/docs/sdk/js/client/abstract-agent.mdx +++ b/docs/sdk/js/client/abstract-agent.mdx @@ -95,6 +95,36 @@ subscribe(subscriber: AgentSubscriber): { unsubscribe: () => void } Returns an object with an `unsubscribe()` method to remove the subscriber when no longer needed. +### use() + +Adds middleware to the agent's event processing pipeline. + +```typescript +use(...middlewares: (Middleware | MiddlewareFunction)[]): this +``` + +Middleware can be either: +- **Function middleware**: Simple functions that transform the event stream +- **Class middleware**: Instances of the `Middleware` class for stateful operations + +```typescript +// Function middleware +agent.use((input, next) => { + console.log("Processing:", input.runId); + return next.run(input); +}); + +// Class middleware +agent.use(new FilterToolCallsMiddleware({ + allowedToolCalls: ["search"] +})); + +// Chain multiple middleware +agent.use(loggingMiddleware, authMiddleware, filterMiddleware); +``` + +Middleware executes in the order added, with each wrapping the next. See the [Middleware documentation](/sdk/js/client/middleware) for more details. + ### abortRun() Cancels the current agent execution. diff --git a/docs/sdk/js/client/middleware.mdx b/docs/sdk/js/client/middleware.mdx new file mode 100644 index 000000000..31462257b --- /dev/null +++ b/docs/sdk/js/client/middleware.mdx @@ -0,0 +1,408 @@ +--- +title: "Middleware" +description: "Event stream transformation and filtering for AG-UI agents" +--- + +# Middleware + +The middleware system in `@ag-ui/client` provides a powerful way to transform, filter, and augment event streams flowing through agents. Middleware can intercept and modify events, add logging, implement authentication, filter tool calls, and more. + +```typescript +import { Middleware, MiddlewareFunction, FilterToolCallsMiddleware } from "@ag-ui/client" +``` + +## Types + +### MiddlewareFunction + +A function that transforms the event stream. + +```typescript +type MiddlewareFunction = ( + input: RunAgentInput, + next: AbstractAgent +) => Observable +``` + +### Middleware + +Abstract base class for creating middleware. + +```typescript +abstract class Middleware { + abstract run( + input: RunAgentInput, + next: AbstractAgent + ): Observable +} +``` + +## Function-Based Middleware + +The simplest way to create middleware is with a function. Function middleware is ideal for stateless transformations. + +### Basic Example + +```typescript +const loggingMiddleware: MiddlewareFunction = (input, next) => { + console.log(`[${new Date().toISOString()}] Starting run ${input.runId}`) + + return next.run(input).pipe( + tap(event => console.log(`Event: ${event.type}`)), + finalize(() => console.log(`Run ${input.runId} completed`)) + ) +} + +agent.use(loggingMiddleware) +``` + +### Transforming Events + +```typescript +const prefixMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + map(event => { + if (event.type === EventType.TEXT_MESSAGE_CHUNK) { + return { + ...event, + delta: `[Assistant]: ${event.delta}` + } + } + return event + }) + ) +} +``` + +### Error Handling + +```typescript +const errorMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + catchError(error => { + console.error("Agent error:", error) + + // Return error event + return of({ + type: EventType.RUN_ERROR, + message: error.message + } as BaseEvent) + }) + ) +} +``` + +## Class-Based Middleware + +For stateful operations or complex logic, extend the `Middleware` class. + +### Basic Implementation + +```typescript +class CounterMiddleware extends Middleware { + private totalEvents = 0 + + run(input: RunAgentInput, next: AbstractAgent): Observable { + let runEvents = 0 + + return next.run(input).pipe( + tap(() => { + runEvents++ + this.totalEvents++ + }), + finalize(() => { + console.log(`Run events: ${runEvents}, Total: ${this.totalEvents}`) + }) + ) + } +} + +agent.use(new CounterMiddleware()) +``` + +### Configuration-Based Middleware + +```typescript +class AuthMiddleware extends Middleware { + constructor( + private apiKey: string, + private headerName: string = "Authorization" + ) { + super() + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + // Add authentication to context + const authenticatedInput = { + ...input, + context: [ + ...input.context, + { + type: "auth", + [this.headerName]: `Bearer ${this.apiKey}` + } + ] + } + + return next.run(authenticatedInput) + } +} + +agent.use(new AuthMiddleware(process.env.API_KEY)) +``` + +## Built-in Middleware + +### FilterToolCallsMiddleware + +Filters tool calls based on allowed or disallowed lists. + +```typescript +import { FilterToolCallsMiddleware } from "@ag-ui/client" +``` + +#### Configuration + +```typescript +type FilterToolCallsConfig = + | { allowedToolCalls: string[]; disallowedToolCalls?: never } + | { disallowedToolCalls: string[]; allowedToolCalls?: never } +``` + +#### Allow Specific Tools + +```typescript +const allowFilter = new FilterToolCallsMiddleware({ + allowedToolCalls: ["search", "calculate", "summarize"] +}) + +agent.use(allowFilter) +``` + +#### Block Specific Tools + +```typescript +const blockFilter = new FilterToolCallsMiddleware({ + disallowedToolCalls: ["delete", "modify", "execute"] +}) + +agent.use(blockFilter) +``` + +## Middleware Patterns + +### Timing Middleware + +```typescript +const timingMiddleware: MiddlewareFunction = (input, next) => { + const startTime = performance.now() + + return next.run(input).pipe( + finalize(() => { + const duration = performance.now() - startTime + console.log(`Execution time: ${duration.toFixed(2)}ms`) + }) + ) +} +``` + +### Rate Limiting + +```typescript +class RateLimitMiddleware extends Middleware { + private lastCall = 0 + + constructor(private minInterval: number) { + super() + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + const now = Date.now() + const elapsed = now - this.lastCall + + if (elapsed < this.minInterval) { + // Delay the execution + return timer(this.minInterval - elapsed).pipe( + switchMap(() => { + this.lastCall = Date.now() + return next.run(input) + }) + ) + } + + this.lastCall = now + return next.run(input) + } +} + +// Limit to one request per second +agent.use(new RateLimitMiddleware(1000)) +``` + +### Retry Logic + +```typescript +const retryMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + retry({ + count: 3, + delay: (error, retryCount) => { + console.log(`Retry attempt ${retryCount}`) + return timer(1000 * retryCount) // Exponential backoff + } + }) + ) +} +``` + +### Caching + +```typescript +class CacheMiddleware extends Middleware { + private cache = new Map() + + run(input: RunAgentInput, next: AbstractAgent): Observable { + const cacheKey = this.getCacheKey(input) + + if (this.cache.has(cacheKey)) { + console.log("Cache hit") + return from(this.cache.get(cacheKey)!) + } + + const events: BaseEvent[] = [] + + return next.run(input).pipe( + tap(event => events.push(event)), + finalize(() => { + this.cache.set(cacheKey, events) + }) + ) + } + + private getCacheKey(input: RunAgentInput): string { + // Create a cache key from the input + return JSON.stringify({ + messages: input.messages, + tools: input.tools.map(t => t.name) + }) + } +} +``` + +## Chaining Middleware + +Multiple middleware can be combined to create sophisticated processing pipelines. + +```typescript +// Create middleware instances +const logger = loggingMiddleware +const auth = new AuthMiddleware(apiKey) +const filter = new FilterToolCallsMiddleware({ + allowedToolCalls: ["search"] +}) +const rateLimit = new RateLimitMiddleware(1000) + +// Apply middleware in order +agent.use( + logger, // First: Log all events + auth, // Second: Add authentication + rateLimit, // Third: Apply rate limiting + filter // Fourth: Filter tool calls +) + +// Execution flow: +// logger → auth → rateLimit → filter → agent → filter → rateLimit → auth → logger +``` + +## Advanced Usage + +### Conditional Middleware + +```typescript +const debugMiddleware: MiddlewareFunction = (input, next) => { + const isDebug = input.context.some(c => c.type === "debug") + + if (!isDebug) { + return next.run(input) + } + + return next.run(input).pipe( + tap(event => { + console.debug("[DEBUG]", JSON.stringify(event, null, 2)) + }) + ) +} +``` + +### Event Filtering + +```typescript +const filterEventsMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + filter(event => { + // Only allow specific event types + return [ + EventType.RUN_STARTED, + EventType.TEXT_MESSAGE_CHUNK, + EventType.RUN_FINISHED + ].includes(event.type) + }) + ) +} +``` + +### Stream Manipulation + +```typescript +const bufferMiddleware: MiddlewareFunction = (input, next) => { + return next.run(input).pipe( + // Buffer text chunks and emit them in batches + bufferWhen(() => + interval(100).pipe( + filter(() => true) + ) + ), + map(events => events.flat()) + ) +} +``` + +## Best Practices + +1. **Single Responsibility**: Each middleware should focus on one concern +2. **Error Handling**: Always handle errors gracefully and consider recovery strategies +3. **Performance**: Be mindful of processing overhead in high-throughput scenarios +4. **State Management**: Use class-based middleware when state is required +5. **Testing**: Write unit tests for each middleware independently +6. **Documentation**: Document middleware behavior and side effects + +## TypeScript Support + +The middleware system is fully typed for excellent IDE support: + +```typescript +import { + Middleware, + MiddlewareFunction, + FilterToolCallsMiddleware +} from "@ag-ui/client" +import { RunAgentInput, BaseEvent, EventType } from "@ag-ui/core" + +// Type-safe middleware function +const typedMiddleware: MiddlewareFunction = ( + input: RunAgentInput, + next: AbstractAgent +): Observable => { + return next.run(input) +} + +// Type-safe middleware class +class TypedMiddleware extends Middleware { + run( + input: RunAgentInput, + next: AbstractAgent + ): Observable { + return next.run(input) + } +} +``` \ No newline at end of file diff --git a/docs/sdk/js/client/overview.mdx b/docs/sdk/js/client/overview.mdx index 0bfcf085f..3902715c5 100644 --- a/docs/sdk/js/client/overview.mdx +++ b/docs/sdk/js/client/overview.mdx @@ -61,6 +61,30 @@ Concrete implementation for HTTP-based agent connectivity: efficient event encoding format +## Middleware + +Transform and intercept event streams flowing through agents with a flexible +middleware system: + +- [Function Middleware](/sdk/js/client/middleware#function-based-middleware) - Simple + transformations with plain functions +- [Class Middleware](/sdk/js/client/middleware#class-based-middleware) - Stateful + middleware with configuration +- [Built-in Middleware](/sdk/js/client/middleware#built-in-middleware) - + FilterToolCallsMiddleware and more +- [Middleware Patterns](/sdk/js/client/middleware#middleware-patterns) - Common + use cases and examples + + + Powerful event stream transformation and filtering for AG-UI agents + + ## AgentSubscriber Event-driven subscriber system for handling agent lifecycle events and state diff --git a/sdks/typescript/packages/client/README.md b/sdks/typescript/packages/client/README.md index 1be36135a..fbc9c41db 100644 --- a/sdks/typescript/packages/client/README.md +++ b/sdks/typescript/packages/client/README.md @@ -19,6 +19,7 @@ yarn add @ag-ui/client - 📡 **Event streaming** – Full AG-UI event processing with validation and transformation - 🔄 **State management** – Automatic message/state tracking with reactive updates - 🪝 **Subscriber system** – Middleware-style hooks for logging, persistence, and custom logic +- 🎯 **Middleware support** – Transform and filter events with function or class-based middleware ## Quick example @@ -37,6 +38,32 @@ const result = await agent.runAgent({ console.log(result.newMessages); ``` +## Using Middleware + +```ts +import { HttpAgent, FilterToolCallsMiddleware } from "@ag-ui/client"; + +const agent = new HttpAgent({ + url: "https://api.example.com/agent", +}); + +// Add middleware to transform or filter events +agent.use( + // Function middleware for logging + (input, next) => { + console.log("Starting run:", input.runId); + return next.run(input); + }, + + // Class middleware for filtering tool calls + new FilterToolCallsMiddleware({ + allowedToolCalls: ["search", "calculate"] + }) +); + +await agent.runAgent(); +``` + ## Documentation - Concepts & architecture: [`docs/concepts`](https://docs.ag-ui.com/concepts/architecture) diff --git a/sdks/typescript/packages/client/src/agent/agent.ts b/sdks/typescript/packages/client/src/agent/agent.ts index 8e2f95610..acfabc64c 100644 --- a/sdks/typescript/packages/client/src/agent/agent.ts +++ b/sdks/typescript/packages/client/src/agent/agent.ts @@ -13,6 +13,7 @@ import { LegacyRuntimeProtocolEvent } from "@/legacy/types"; import { lastValueFrom } from "rxjs"; import { transformChunks } from "@/chunks"; import { AgentStateMutation, AgentSubscriber, runSubscribersWithMutation } from "./subscriber"; +import { Middleware, MiddlewareFunction, FunctionMiddleware } from "@/middleware"; export interface RunAgentResult { result: any; @@ -27,6 +28,9 @@ export abstract class AbstractAgent { public state: State; public debug: boolean = false; public subscribers: AgentSubscriber[] = []; + private middlewares: Middleware[] = []; + + public readonly maxVersion: string = "*"; constructor({ agentId, @@ -53,6 +57,14 @@ export abstract class AbstractAgent { }; } + public use(...middlewares: (Middleware | MiddlewareFunction)[]): this { + const normalizedMiddlewares = middlewares.map((middleware) => + typeof middleware === "function" ? new FunctionMiddleware(middleware) : middleware, + ); + this.middlewares.push(...normalizedMiddlewares); + return this; + } + abstract run(input: RunAgentInput): Observable; public async runAgent( @@ -77,7 +89,22 @@ export abstract class AbstractAgent { await this.onInitialize(input, subscribers); const pipeline = pipe( - () => this.run(input), + () => { + // Build middleware chain using reduceRight + if (this.middlewares.length === 0) { + return this.run(input); + } + + const chainedAgent = this.middlewares.reduceRight( + (nextAgent: AbstractAgent, middleware) => + ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + }) as AbstractAgent, + this, // Original agent is the final 'next' + ); + + return chainedAgent.run(input); + }, transformChunks(this.debug), verifyEvents(this.debug), (source$) => this.apply(input, source$, subscribers), @@ -416,7 +443,24 @@ export abstract class AbstractAgent { this.agentId = this.agentId ?? uuidv4(); const input = this.prepareRunAgentInput(config); - return this.run(input).pipe( + // Build middleware chain for legacy bridge + const runObservable = (() => { + if (this.middlewares.length === 0) { + return this.run(input); + } + + const chainedAgent = this.middlewares.reduceRight( + (nextAgent: AbstractAgent, middleware) => + ({ + run: (i: RunAgentInput) => middleware.run(i, nextAgent), + }) as AbstractAgent, + this, + ); + + return chainedAgent.run(input); + })(); + + return runObservable.pipe( transformChunks(this.debug), verifyEvents(this.debug), convertToLegacyEvents(this.threadId, input.runId, this.agentId), diff --git a/sdks/typescript/packages/client/src/middleware/__tests__/filter-tool-calls.test.ts b/sdks/typescript/packages/client/src/middleware/__tests__/filter-tool-calls.test.ts new file mode 100644 index 000000000..978bbc80c --- /dev/null +++ b/sdks/typescript/packages/client/src/middleware/__tests__/filter-tool-calls.test.ts @@ -0,0 +1,184 @@ +import { AbstractAgent } from "@/agent"; +import { FilterToolCallsMiddleware } from "@/middleware/filter-tool-calls"; +import { Middleware } from "@/middleware"; +import { + BaseEvent, + EventType, + RunAgentInput, + ToolCallStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, +} from "@ag-ui/core"; +import { Observable } from "rxjs"; + +describe("FilterToolCallsMiddleware", () => { + class ToolCallingAgent extends AbstractAgent { + run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + // Emit RUN_STARTED + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }); + + // Emit first tool call (calculator) + const toolCall1Id = "tool-call-1"; + subscriber.next({ + type: EventType.TOOL_CALL_START, + toolCallId: toolCall1Id, + toolCallName: "calculator", + parentMessageId: "message-1", + } as ToolCallStartEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_ARGS, + toolCallId: toolCall1Id, + delta: '{"operation": "add", "a": 5, "b": 3}', + } as ToolCallArgsEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_END, + toolCallId: toolCall1Id, + } as ToolCallEndEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_RESULT, + messageId: "tool-message-1", + toolCallId: toolCall1Id, + content: "8", + } as ToolCallResultEvent); + + // Emit second tool call (weather) + const toolCall2Id = "tool-call-2"; + subscriber.next({ + type: EventType.TOOL_CALL_START, + toolCallId: toolCall2Id, + toolCallName: "weather", + parentMessageId: "message-2", + } as ToolCallStartEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_ARGS, + toolCallId: toolCall2Id, + delta: '{"city": "New York"}', + } as ToolCallArgsEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_END, + toolCallId: toolCall2Id, + } as ToolCallEndEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_RESULT, + messageId: "tool-message-2", + toolCallId: toolCall2Id, + content: "Sunny, 72°F", + } as ToolCallResultEvent); + + // Emit third tool call (search) + const toolCall3Id = "tool-call-3"; + subscriber.next({ + type: EventType.TOOL_CALL_START, + toolCallId: toolCall3Id, + toolCallName: "search", + parentMessageId: "message-3", + } as ToolCallStartEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_ARGS, + toolCallId: toolCall3Id, + delta: '{"query": "TypeScript middleware"}', + } as ToolCallArgsEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_END, + toolCallId: toolCall3Id, + } as ToolCallEndEvent); + + subscriber.next({ + type: EventType.TOOL_CALL_RESULT, + messageId: "tool-message-3", + toolCallId: toolCall3Id, + content: "Results found...", + } as ToolCallResultEvent); + + // Emit RUN_FINISHED + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.complete(); + }); + } + } + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + it("should filter out disallowed tool calls", async () => { + const agent = new ToolCallingAgent(); + const middleware = new FilterToolCallsMiddleware({ + disallowedToolCalls: ["calculator", "search"], + }); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // Should have RUN_STARTED, weather tool events (4), and RUN_FINISHED + expect(events.length).toBe(6); + + // Check that we have RUN_STARTED + expect(events[0].type).toBe(EventType.RUN_STARTED); + + // Check that only weather tool calls are present + const toolCallStarts = events.filter((e) => e.type === EventType.TOOL_CALL_START) as ToolCallStartEvent[]; + expect(toolCallStarts.length).toBe(1); + expect(toolCallStarts[0].toolCallName).toBe("weather"); + + // Check that calculator and search are filtered out + const allToolNames = toolCallStarts.map((e) => e.toolCallName); + expect(allToolNames).not.toContain("calculator"); + expect(allToolNames).not.toContain("search"); + + // Check that we have RUN_FINISHED + expect(events[events.length - 1].type).toBe(EventType.RUN_FINISHED); + }); + + it("should allow only allowed tool calls when using allowlist", async () => { + const agent = new ToolCallingAgent(); + const middleware = new FilterToolCallsMiddleware({ + allowedToolCalls: ["calculator"], + }); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + // Should have RUN_STARTED, calculator tool events (4), and RUN_FINISHED + expect(events.length).toBe(6); + + const toolCallStarts = events.filter((e) => e.type === EventType.TOOL_CALL_START) as ToolCallStartEvent[]; + expect(toolCallStarts.length).toBe(1); + expect(toolCallStarts[0].toolCallName).toBe("calculator"); + }); +}); diff --git a/sdks/typescript/packages/client/src/middleware/__tests__/function-middleware.test.ts b/sdks/typescript/packages/client/src/middleware/__tests__/function-middleware.test.ts new file mode 100644 index 000000000..11f94035e --- /dev/null +++ b/sdks/typescript/packages/client/src/middleware/__tests__/function-middleware.test.ts @@ -0,0 +1,86 @@ +import { AbstractAgent } from "@/agent"; +import { FunctionMiddleware, MiddlewareFunction } from "@/middleware"; +import { BaseEvent, EventType, RunAgentInput } from "@ag-ui/core"; +import { Observable } from "rxjs"; + +describe("FunctionMiddleware", () => { + class TestAgent extends AbstractAgent { + run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.complete(); + }); + } + } + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + it("should allow function-based middleware to intercept events", async () => { + const agent = new TestAgent(); + + const middlewareFn: MiddlewareFunction = (middlewareInput, next) => { + return new Observable((subscriber) => { + const subscription = next.run(middlewareInput).subscribe({ + next: (event) => { + if (event.type === EventType.RUN_STARTED) { + subscriber.next({ + ...event, + metadata: { ...(event as any).metadata, fromMiddleware: true }, + }); + return; + } + + if (event.type === EventType.RUN_FINISHED) { + subscriber.next({ + ...event, + result: { success: true }, + }); + return; + } + + subscriber.next(event); + }, + error: (error) => subscriber.error(error), + complete: () => subscriber.complete(), + }); + + return () => subscription.unsubscribe(); + }); + }; + + const middleware = new FunctionMiddleware(middlewareFn); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + expect(events.length).toBe(2); + expect(events[0].type).toBe(EventType.RUN_STARTED); + expect((events[0] as any).metadata).toEqual({ fromMiddleware: true }); + expect(events[1].type).toBe(EventType.RUN_FINISHED); + expect((events[1] as any).result).toEqual({ success: true }); + }); +}); diff --git a/sdks/typescript/packages/client/src/middleware/__tests__/middleware-live-events.test.ts b/sdks/typescript/packages/client/src/middleware/__tests__/middleware-live-events.test.ts new file mode 100644 index 000000000..bedb8f8c3 --- /dev/null +++ b/sdks/typescript/packages/client/src/middleware/__tests__/middleware-live-events.test.ts @@ -0,0 +1,98 @@ +import { AbstractAgent } from "@/agent"; +import { Middleware } from "@/middleware"; +import { + BaseEvent, + EventType, + RunAgentInput, + TextMessageChunkEvent, + RunFinishedEvent, + RunStartedEvent, +} from "@ag-ui/core"; +import { Observable } from "rxjs"; + +describe("Middleware live events", () => { + class LiveEventAgent extends AbstractAgent { + run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + } as RunStartedEvent); + + subscriber.next({ + type: EventType.TEXT_MESSAGE_CHUNK, + messageId: "message-1", + role: "assistant", + delta: "Hello", + } as TextMessageChunkEvent); + + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + result: { success: true }, + } as RunFinishedEvent); + + subscriber.complete(); + }); + } + } + + class CustomMiddleware extends Middleware { + run(input: RunAgentInput, next: AbstractAgent): Observable { + return new Observable((subscriber) => { + const subscription = next.run(input).subscribe({ + next: (event) => { + if (event.type === EventType.RUN_STARTED) { + const started = event as RunStartedEvent; + subscriber.next({ + ...started, + metadata: { + ...(started.metadata ?? {}), + custom: true, + }, + }); + return; + } + + subscriber.next(event); + }, + error: (error) => subscriber.error(error), + complete: () => subscriber.complete(), + }); + + return () => subscription.unsubscribe(); + }); + } + } + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + it("should allow middleware to emit events before the agent", async () => { + const agent = new LiveEventAgent(); + const middleware = new CustomMiddleware(); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + expect(events.length).toBe(3); + expect(events[0].type).toBe(EventType.RUN_STARTED); + expect((events[0] as RunStartedEvent).metadata).toEqual({ custom: true }); + expect(events[1].type).toBe(EventType.TEXT_MESSAGE_CHUNK); + expect(events[2].type).toBe(EventType.RUN_FINISHED); + }); +}); diff --git a/sdks/typescript/packages/client/src/middleware/__tests__/middleware-usage-example.ts b/sdks/typescript/packages/client/src/middleware/__tests__/middleware-usage-example.ts new file mode 100644 index 000000000..209c5dc12 --- /dev/null +++ b/sdks/typescript/packages/client/src/middleware/__tests__/middleware-usage-example.ts @@ -0,0 +1,130 @@ +import { AbstractAgent } from "@/agent"; +import { + Middleware, + FunctionMiddleware, + MiddlewareFunction, + FilterToolCallsMiddleware, +} from "@/middleware"; +import { + BaseEvent, + EventType, + RunAgentInput, + TextMessageChunkEvent, + RunFinishedEvent, + RunStartedEvent, +} from "@ag-ui/core"; +import { Observable } from "rxjs"; + +/** + * Example agent that emits a simple conversation flow. + */ +class ExampleAgent extends AbstractAgent { + run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + } as RunStartedEvent); + + subscriber.next({ + type: EventType.TEXT_MESSAGE_CHUNK, + messageId: "message-1", + role: "assistant", + delta: "Hello! Let me calculate that for you.", + } as TextMessageChunkEvent); + + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + result: { answer: 42 }, + } as RunFinishedEvent); + + subscriber.complete(); + }); + } +} + +/** + * Example middleware that logs events as they pass through. + */ +class LoggingMiddleware extends Middleware { + run(input: RunAgentInput, next: AbstractAgent): Observable { + console.log("Middleware input:", input); + + return next.run(input); + } +} + +/** + * Example function-based middleware that modifies the result. + */ +const resultEnhancer: MiddlewareFunction = (input, next) => { + return new Observable((subscriber) => { + next.run(input).subscribe({ + next: (event) => { + if (event.type === EventType.RUN_FINISHED) { + subscriber.next({ + ...event, + result: { + ...(event as RunFinishedEvent).result, + enhanced: true, + }, + }); + } else { + subscriber.next(event); + } + }, + error: (error) => subscriber.error(error), + complete: () => subscriber.complete(), + }); + }); +}; + +const input: RunAgentInput = { + threadId: "example-thread", + runId: "example-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], +}; + +/** + * Example usage demonstrating middleware chaining. + */ +async function runExample() { + const agent = new ExampleAgent(); + + // Function-based middleware + agent.use(new FunctionMiddleware(resultEnhancer)); + + // Class-based middleware + agent.use(new LoggingMiddleware()); + + // Built-in middleware to filter tool calls + agent.use(new FilterToolCallsMiddleware({ disallowedToolCalls: ["calculator"] })); + + const events: BaseEvent[] = []; + await new Promise((resolve, reject) => { + agent.runAgent({}, { + onRunFinalized: ({ messages }) => { + console.log("Final messages:", messages); + }, + onRunFinishedEvent: ({ result }) => { + console.log("Run finished result:", result); + }, + }).then(({ newMessages, result }) => { + console.log("New messages:", newMessages); + console.log("Final result:", result); + resolve(); + }).catch(reject); + }); + + return events; +} + +// eslint-disable-next-line @typescript-eslint/no-floating-promises +runExample(); diff --git a/sdks/typescript/packages/client/src/middleware/__tests__/middleware-with-state.test.ts b/sdks/typescript/packages/client/src/middleware/__tests__/middleware-with-state.test.ts new file mode 100644 index 000000000..6ac9573c8 --- /dev/null +++ b/sdks/typescript/packages/client/src/middleware/__tests__/middleware-with-state.test.ts @@ -0,0 +1,83 @@ +import { AbstractAgent } from "@/agent"; +import { Middleware } from "@/middleware"; +import { + BaseEvent, + EventType, + RunAgentInput, + RunFinishedEvent, + TextMessageChunkEvent, +} from "@ag-ui/core"; +import { Observable } from "rxjs"; + +describe("Middleware runNextWithState", () => { + class StatefulAgent extends AbstractAgent { + run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.next({ + type: EventType.TEXT_MESSAGE_CHUNK, + messageId: "message-1", + role: "assistant", + delta: "Hello", + } as TextMessageChunkEvent); + + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + result: { success: true }, + } as RunFinishedEvent); + + subscriber.complete(); + }); + } + } + + class StateTrackingMiddleware extends Middleware { + run(input: RunAgentInput, next: AbstractAgent): Observable { + return this.runNextWithState(input, next).pipe((source) => { + return new Observable((subscriber) => { + source.subscribe({ + next: ({ event }) => subscriber.next(event), + complete: () => subscriber.complete(), + }); + }); + }); + } + } + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + it("should capture state changes after each event", async () => { + const agent = new StatefulAgent(); + const middleware = new StateTrackingMiddleware(); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + expect(events.length).toBe(5); + expect(events[0].type).toBe(EventType.RUN_STARTED); + expect(events[1].type).toBe(EventType.TEXT_MESSAGE_START); + expect(events[2].type).toBe(EventType.TEXT_MESSAGE_CONTENT); + expect(events[3].type).toBe(EventType.TEXT_MESSAGE_END); + expect(events[4].type).toBe(EventType.RUN_FINISHED); + }); +}); diff --git a/sdks/typescript/packages/client/src/middleware/__tests__/middleware.test.ts b/sdks/typescript/packages/client/src/middleware/__tests__/middleware.test.ts new file mode 100644 index 000000000..a07cfb870 --- /dev/null +++ b/sdks/typescript/packages/client/src/middleware/__tests__/middleware.test.ts @@ -0,0 +1,80 @@ +import { AbstractAgent } from "@/agent"; +import { Middleware } from "@/middleware"; +import { BaseEvent, EventType, RunAgentInput } from "@ag-ui/core"; +import { Observable } from "rxjs"; + +describe("Middleware", () => { + class TestAgent extends AbstractAgent { + run(input: RunAgentInput): Observable { + return new Observable((subscriber) => { + subscriber.next({ + type: EventType.RUN_STARTED, + threadId: input.threadId, + runId: input.runId, + }); + + subscriber.next({ + type: EventType.RUN_FINISHED, + threadId: input.threadId, + runId: input.runId, + result: { success: true }, + }); + + subscriber.complete(); + }); + } + } + + class TestMiddleware extends Middleware { + run(input: RunAgentInput, next: AbstractAgent): Observable { + return new Observable((subscriber) => { + const subscription = next.run(input).subscribe({ + next: (event) => { + if (event.type === EventType.RUN_STARTED) { + subscriber.next({ + ...event, + metadata: { ...(event as any).metadata, middleware: true }, + }); + return; + } + + subscriber.next(event); + }, + error: (error) => subscriber.error(error), + complete: () => subscriber.complete(), + }); + + return () => subscription.unsubscribe(); + }); + } + } + + const input: RunAgentInput = { + threadId: "test-thread", + runId: "test-run", + tools: [], + context: [], + forwardedProps: {}, + state: {}, + messages: [], + }; + + it("should allow middleware to modify the event stream", async () => { + const agent = new TestAgent(); + const middleware = new TestMiddleware(); + + const events: BaseEvent[] = []; + await new Promise((resolve) => { + middleware.run(input, agent).subscribe({ + next: (event) => events.push(event), + complete: () => resolve(), + }); + }); + + expect(events.length).toBe(2); + expect(events[0].type).toBe(EventType.RUN_STARTED); + expect((events[0] as any).metadata).toEqual({ middleware: true }); + expect(events[1].type).toBe(EventType.RUN_FINISHED); + expect((events[1] as any).result).toEqual({ success: true }); + }); +}); diff --git a/sdks/typescript/packages/client/src/middleware/filter-tool-calls.ts b/sdks/typescript/packages/client/src/middleware/filter-tool-calls.ts new file mode 100644 index 000000000..157ff6c3d --- /dev/null +++ b/sdks/typescript/packages/client/src/middleware/filter-tool-calls.ts @@ -0,0 +1,104 @@ +import { Middleware } from "./middleware"; +import { AbstractAgent } from "@/agent"; +import { + RunAgentInput, + BaseEvent, + EventType, + ToolCallStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, +} from "@ag-ui/core"; +import { Observable } from "rxjs"; +import { filter } from "rxjs/operators"; + +type FilterToolCallsConfig = + | { allowedToolCalls: string[]; disallowedToolCalls?: never } + | { disallowedToolCalls: string[]; allowedToolCalls?: never }; + +export class FilterToolCallsMiddleware extends Middleware { + private blockedToolCallIds = new Set(); + private readonly allowedTools?: Set; + private readonly disallowedTools?: Set; + + constructor(config: FilterToolCallsConfig) { + super(); + + // Runtime validation (belt and suspenders approach) + if (config.allowedToolCalls && config.disallowedToolCalls) { + throw new Error("Cannot specify both allowedToolCalls and disallowedToolCalls"); + } + + if (!config.allowedToolCalls && !config.disallowedToolCalls) { + throw new Error("Must specify either allowedToolCalls or disallowedToolCalls"); + } + + if (config.allowedToolCalls) { + this.allowedTools = new Set(config.allowedToolCalls); + } else if (config.disallowedToolCalls) { + this.disallowedTools = new Set(config.disallowedToolCalls); + } + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + // Use runNext which already includes transformChunks + return this.runNext(input, next).pipe( + filter((event) => { + // Handle TOOL_CALL_START events + if (event.type === EventType.TOOL_CALL_START) { + const toolCallStartEvent = event as ToolCallStartEvent; + const shouldFilter = this.shouldFilterTool(toolCallStartEvent.toolCallName); + + if (shouldFilter) { + // Track this tool call ID as blocked + this.blockedToolCallIds.add(toolCallStartEvent.toolCallId); + return false; // Filter out this event + } + + return true; // Allow this event + } + + // Handle TOOL_CALL_ARGS events + if (event.type === EventType.TOOL_CALL_ARGS) { + const toolCallArgsEvent = event as ToolCallArgsEvent; + return !this.blockedToolCallIds.has(toolCallArgsEvent.toolCallId); + } + + // Handle TOOL_CALL_END events + if (event.type === EventType.TOOL_CALL_END) { + const toolCallEndEvent = event as ToolCallEndEvent; + return !this.blockedToolCallIds.has(toolCallEndEvent.toolCallId); + } + + // Handle TOOL_CALL_RESULT events + if (event.type === EventType.TOOL_CALL_RESULT) { + const toolCallResultEvent = event as ToolCallResultEvent; + const isBlocked = this.blockedToolCallIds.has(toolCallResultEvent.toolCallId); + + if (isBlocked) { + // Clean up the blocked ID after the last event + this.blockedToolCallIds.delete(toolCallResultEvent.toolCallId); + return false; + } + + return true; + } + + // Allow all other events through + return true; + }), + ); + } + + private shouldFilterTool(toolName: string): boolean { + if (this.allowedTools) { + // If using allowed list, filter out tools NOT in the list + return !this.allowedTools.has(toolName); + } else if (this.disallowedTools) { + // If using disallowed list, filter out tools IN the list + return this.disallowedTools.has(toolName); + } + + return false; + } +} diff --git a/sdks/typescript/packages/client/src/middleware/index.ts b/sdks/typescript/packages/client/src/middleware/index.ts new file mode 100644 index 000000000..d60de5e3c --- /dev/null +++ b/sdks/typescript/packages/client/src/middleware/index.ts @@ -0,0 +1,3 @@ +export { Middleware, FunctionMiddleware } from "./middleware"; +export type { MiddlewareFunction } from "./middleware"; +export { FilterToolCallsMiddleware } from "./filter-tool-calls"; diff --git a/sdks/typescript/packages/client/src/middleware/middleware.ts b/sdks/typescript/packages/client/src/middleware/middleware.ts new file mode 100644 index 000000000..3d33c589a --- /dev/null +++ b/sdks/typescript/packages/client/src/middleware/middleware.ts @@ -0,0 +1,87 @@ +import { AbstractAgent } from "@/agent"; +import { RunAgentInput, BaseEvent, Message } from "@ag-ui/core"; +import { Observable, ReplaySubject } from "rxjs"; +import { concatMap } from "rxjs/operators"; +import { transformChunks } from "@/chunks"; +import { defaultApplyEvents } from "@/apply"; +import { structuredClone_ } from "@/utils"; + +export type MiddlewareFunction = ( + input: RunAgentInput, + next: AbstractAgent, +) => Observable; + +export interface EventWithState { + event: BaseEvent; + messages: Message[]; + state: any; +} + +export abstract class Middleware { + abstract run(input: RunAgentInput, next: AbstractAgent): Observable; + + /** + * Runs the next agent in the chain with automatic chunk transformation. + */ + protected runNext(input: RunAgentInput, next: AbstractAgent): Observable { + return next.run(input).pipe( + transformChunks(false), // Always transform chunks to full events + ); + } + + /** + * Runs the next agent and tracks state, providing current messages and state with each event. + * The messages and state represent the state AFTER the event has been applied. + */ + protected runNextWithState( + input: RunAgentInput, + next: AbstractAgent, + ): Observable { + let currentMessages = structuredClone_(input.messages || []); + let currentState = structuredClone_(input.state || {}); + + // Use a ReplaySubject to feed events one by one + const eventSubject = new ReplaySubject(); + + // Set up defaultApplyEvents to process events + const mutations$ = defaultApplyEvents(input, eventSubject, next, []); + + // Subscribe to track state changes + mutations$.subscribe((mutation) => { + if (mutation.messages !== undefined) { + currentMessages = mutation.messages; + } + if (mutation.state !== undefined) { + currentState = mutation.state; + } + }); + + return this.runNext(input, next).pipe( + concatMap(async (event) => { + // Feed the event to defaultApplyEvents and wait for it to process + eventSubject.next(event); + + // Give defaultApplyEvents a chance to process + await new Promise((resolve) => setTimeout(resolve, 0)); + + // Return event with current state + return { + event, + messages: structuredClone_(currentMessages), + state: structuredClone_(currentState), + }; + }), + ); + } +} + +// Wrapper class to convert a function into a Middleware instance +export class FunctionMiddleware extends Middleware { + constructor(private fn: MiddlewareFunction) { + super(); + } + + run(input: RunAgentInput, next: AbstractAgent): Observable { + return this.fn(input, next); + } +}