Skip to content

Commit 23f3c07

Browse files
committed
Add middleware API
1 parent 9a49a4e commit 23f3c07

File tree

12 files changed

+1699
-6
lines changed

12 files changed

+1699
-6
lines changed

typescript-sdk/integrations/mastra/src/mastra.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ export class MastraAgent extends AbstractAgent {
5959
this.runtimeContext = runtimeContext ?? new RuntimeContext();
6060
}
6161

62-
protected run(input: RunAgentInput): Observable<BaseEvent> {
62+
public run(input: RunAgentInput): Observable<BaseEvent> {
6363
let messageId = randomUUID();
6464

6565
return new Observable<BaseEvent>((subscriber) => {

typescript-sdk/integrations/middleware-starter/src/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { AbstractAgent, BaseEvent, EventType, RunAgentInput } from "@ag-ui/clien
22
import { Observable } from "rxjs";
33

44
export class MiddlewareStarterAgent extends AbstractAgent {
5-
protected run(input: RunAgentInput): Observable<BaseEvent> {
5+
public run(input: RunAgentInput): Observable<BaseEvent> {
66
const messageId = Date.now().toString();
77
return new Observable<BaseEvent>((observer) => {
88
observer.next({

typescript-sdk/integrations/vercel-ai-sdk/src/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ export class VercelAISDKAgent extends AbstractAgent {
5555
this.toolChoice = toolChoice ?? "auto";
5656
}
5757

58-
protected run(input: RunAgentInput): Observable<BaseEvent> {
58+
public run(input: RunAgentInput): Observable<BaseEvent> {
5959
const finalMessages: Message[] = input.messages;
6060

6161
return new Observable<ProcessedEvent>((subscriber) => {

typescript-sdk/packages/client/src/agent/agent.ts

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import { LegacyRuntimeProtocolEvent } from "@/legacy/types";
1313
import { lastValueFrom } from "rxjs";
1414
import { transformChunks } from "@/chunks";
1515
import { AgentStateMutation, AgentSubscriber, runSubscribersWithMutation } from "./subscriber";
16+
import { Middleware, MiddlewareFunction, FunctionMiddleware } from "@/middleware";
1617

1718
export interface RunAgentResult {
1819
result: any;
@@ -27,6 +28,7 @@ export abstract class AbstractAgent {
2728
public state: State;
2829
public debug: boolean = false;
2930
public subscribers: AgentSubscriber[] = [];
31+
private middlewares: Middleware[] = [];
3032

3133
constructor({
3234
agentId,
@@ -53,7 +55,15 @@ export abstract class AbstractAgent {
5355
};
5456
}
5557

56-
protected abstract run(input: RunAgentInput): Observable<BaseEvent>;
58+
public use(...middlewares: (Middleware | MiddlewareFunction)[]): this {
59+
const normalizedMiddlewares = middlewares.map(m =>
60+
typeof m === 'function' ? new FunctionMiddleware(m) : m
61+
);
62+
this.middlewares.push(...normalizedMiddlewares);
63+
return this;
64+
}
65+
66+
public abstract run(input: RunAgentInput): Observable<BaseEvent>;
5767

5868
public async runAgent(
5969
parameters?: RunAgentParameters,
@@ -77,7 +87,21 @@ export abstract class AbstractAgent {
7787
await this.onInitialize(input, subscribers);
7888

7989
const pipeline = pipe(
80-
() => this.run(input),
90+
() => {
91+
// Build middleware chain using reduceRight
92+
if (this.middlewares.length === 0) {
93+
return this.run(input);
94+
}
95+
96+
const chainedAgent = this.middlewares.reduceRight(
97+
(nextAgent: AbstractAgent, middleware) => ({
98+
run: (i: RunAgentInput) => middleware.run(i, nextAgent)
99+
} as AbstractAgent),
100+
this // Original agent is the final 'next'
101+
);
102+
103+
return chainedAgent.run(input);
104+
},
81105
transformChunks(this.debug),
82106
verifyEvents(this.debug),
83107
(source$) => this.apply(input, source$, subscribers),
@@ -416,7 +440,23 @@ export abstract class AbstractAgent {
416440
this.agentId = this.agentId ?? uuidv4();
417441
const input = this.prepareRunAgentInput(config);
418442

419-
return this.run(input).pipe(
443+
// Build middleware chain for legacy bridge
444+
const runObservable = (() => {
445+
if (this.middlewares.length === 0) {
446+
return this.run(input);
447+
}
448+
449+
const chainedAgent = this.middlewares.reduceRight(
450+
(nextAgent: AbstractAgent, middleware) => ({
451+
run: (i: RunAgentInput) => middleware.run(i, nextAgent)
452+
} as AbstractAgent),
453+
this
454+
);
455+
456+
return chainedAgent.run(input);
457+
})();
458+
459+
return runObservable.pipe(
420460
transformChunks(this.debug),
421461
verifyEvents(this.debug),
422462
convertToLegacyEvents(this.threadId, input.runId, this.agentId),

0 commit comments

Comments
 (0)