Skip to content

Commit caa52cf

Browse files
committed
Ephemeral context
1 parent 72beb14 commit caa52cf

File tree

7 files changed

+84
-63
lines changed

7 files changed

+84
-63
lines changed

src/agent.ts

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import {
1616
MagmaSystemMessage,
1717
MagmaToolResultMessage,
1818
MagmaMessage,
19+
DecoratedExtras,
1920
} from './types';
2021
import {
2122
loadHooks,
@@ -59,13 +60,11 @@ type AgentProps = {
5960
general?: CallSettings;
6061
verbose?: boolean;
6162
messageContext?: number;
62-
stream?: boolean;
6363
sessionId?: string;
6464
};
6565

6666
export class MagmaAgent {
6767
verbose?: boolean;
68-
stream: boolean = false;
6968
public sessionId: string;
7069

7170
private provider: OpenRouterProvider = createOpenRouter();
@@ -84,7 +83,6 @@ export class MagmaAgent {
8483
constructor(args?: AgentProps) {
8584
this.messageContext = args?.messageContext ?? 20;
8685
this.verbose = args?.verbose ?? false;
87-
this.stream = args?.stream ?? false;
8886
this.sessionId = args?.sessionId ?? '';
8987

9088
args ??= {
@@ -146,6 +144,8 @@ export class MagmaAgent {
146144
userMessage: MagmaUserMessage;
147145
onTrace?: (trace: TraceEvent[]) => void;
148146
trigger?: false;
147+
ctx?: Record<string, any>;
148+
onStreamChunk?: (chunk: MagmaStreamChunk, extras: DecoratedExtras) => void;
149149
}): Promise<MagmaAssistantMessage | null>;
150150

151151
public async main(args: {
@@ -157,6 +157,8 @@ export class MagmaAgent {
157157
userMessage: MagmaUserMessage;
158158
onTrace?: (trace: TraceEvent[]) => void;
159159
trigger: true;
160+
ctx?: Record<string, any>;
161+
onStreamChunk?: (chunk: MagmaStreamChunk, extras: DecoratedExtras) => void;
160162
}): Promise<MagmaToolResultMessage | null>;
161163

162164
public async main(args: {
@@ -168,13 +170,16 @@ export class MagmaAgent {
168170
userMessage: MagmaUserMessage;
169171
onTrace?: (trace: TraceEvent[]) => void;
170172
trigger?: boolean;
173+
ctx?: Record<string, any>;
174+
onStreamChunk?: (chunk: MagmaStreamChunk, extras: DecoratedExtras) => void;
171175
}): Promise<MagmaAssistantMessage | MagmaToolResultMessage | null> {
172176
return (await this._main({
173177
config: args.config,
174178
send: args.send,
175179
userOrToolMessage: args.userMessage,
176180
onTrace: args.onTrace,
177181
trigger: args.trigger,
182+
ctx: args.ctx,
178183
})) as MagmaToolResultMessage | null;
179184
}
180185

@@ -190,6 +195,8 @@ export class MagmaAgent {
190195
messages?: Array<ModelMessage>;
191196
trace?: TraceEvent[];
192197
onTrace?: (trace: TraceEvent[]) => void;
198+
ctx?: Record<string, any>;
199+
onStreamChunk?: (chunk: MagmaStreamChunk, extras: DecoratedExtras) => void;
193200
}): Promise<AssistantModelMessage | MagmaToolResultMessage | null> {
194201
const {
195202
config,
@@ -198,6 +205,8 @@ export class MagmaAgent {
198205
send = () => {},
199206
userOrToolMessage,
200207
trigger = false,
208+
ctx = {},
209+
onStreamChunk = () => {},
201210
} = args;
202211
let originIndex = args.originIndex;
203212
const localMessages = [
@@ -234,6 +243,7 @@ export class MagmaAgent {
234243
trace,
235244
requestId,
236245
send,
246+
ctx,
237247
});
238248
} catch (error) {
239249
// If the preCompletion middleware fails, we should remove the last message
@@ -309,7 +319,7 @@ export class MagmaAgent {
309319
.filter((t) => t.enabled(this))
310320
.map((t) => [t.name, convertMagmaToolToAISDKTool(t)])
311321
),
312-
messages: [...this.getSystemPrompts(), ...completionMessages],
322+
messages: [...this.getSystemPrompts(ctx), ...completionMessages],
313323
abortSignal: this.abortControllers.get(requestId)?.signal,
314324
...configToUse.general,
315325
});
@@ -320,9 +330,7 @@ export class MagmaAgent {
320330
}
321331

322332
for await (const chunk of fullStream) {
323-
if (this.stream) {
324-
this.onStreamChunk(chunk, send);
325-
}
333+
onStreamChunk(chunk, { agent: this, send, ctx });
326334
}
327335

328336
// create the Asisstant message from the completion
@@ -349,6 +357,7 @@ export class MagmaAgent {
349357
trace,
350358
requestId,
351359
send,
360+
ctx,
352361
});
353362
} catch (error) {
354363
// If the onCompletion middleware fails, we should remove the last message
@@ -407,6 +416,7 @@ export class MagmaAgent {
407416
trace,
408417
requestId,
409418
send,
419+
ctx,
410420
});
411421
} catch (error) {
412422
// Remove the failing tool call message
@@ -450,6 +460,7 @@ export class MagmaAgent {
450460
trace,
451461
requestId,
452462
send,
463+
ctx,
453464
});
454465

455466
onToolExecutionMiddlewareResult =
@@ -458,6 +469,7 @@ export class MagmaAgent {
458469
trace,
459470
requestId,
460471
send,
472+
ctx,
461473
});
462474

463475
// If the abort controller is not active, return null
@@ -493,6 +505,7 @@ export class MagmaAgent {
493505
trace,
494506
requestId,
495507
send,
508+
ctx,
496509
});
497510
} catch (error) {
498511
// If the onMainFinish middleware fails, we should remove the offending message
@@ -614,12 +627,14 @@ export class MagmaAgent {
614627
trace,
615628
requestId,
616629
send,
630+
ctx,
617631
}: {
618632
message: AssistantModelMessage;
619633
allowList?: string[];
620634
trace: TraceEvent[];
621635
requestId: string;
622636
send: MagmaSendFunction;
637+
ctx: Record<string, any>;
623638
}): Promise<ToolModelMessage> {
624639
try {
625640
let toolResultMessage: ToolModelMessage = {
@@ -655,7 +670,7 @@ export class MagmaAgent {
655670
},
656671
});
657672

658-
let result = await tool.target(toolCall, send, this);
673+
let result = await tool.target(toolCall, { agent: this, send, ctx });
659674

660675
if (!result) {
661676
this.log(`No result returned for ${toolCall.toolName}()`);
@@ -747,11 +762,13 @@ export class MagmaAgent {
747762
trace,
748763
requestId,
749764
send,
765+
ctx,
750766
}: {
751767
message: UserModelMessage;
752768
trace: TraceEvent[];
753769
requestId: string;
754770
send: MagmaSendFunction;
771+
ctx: Record<string, any>;
755772
}): Promise<UserModelMessage> {
756773
// get preCompletion middleware
757774
const preCompletionMiddleware = this.middleware.filter(
@@ -788,11 +805,11 @@ export class MagmaAgent {
788805
},
789806
});
790807
// run the middleware on the text block
791-
const middlewareResult = (await mdlwr.action(
792-
textBlock.text,
808+
const middlewareResult = (await mdlwr.action(textBlock.text, {
809+
agent: this,
793810
send,
794-
this
795-
)) as string;
811+
ctx,
812+
})) as string;
796813
// if the middleware has a return value, we should update the text block in the result message
797814
if (middlewareResult !== undefined) {
798815
this.log(
@@ -854,11 +871,13 @@ export class MagmaAgent {
854871
trace,
855872
requestId,
856873
send,
874+
ctx,
857875
}: {
858876
message: AssistantModelMessage;
859877
trace: TraceEvent[];
860878
requestId: string;
861879
send: MagmaSendFunction;
880+
ctx: Record<string, any>;
862881
}): Promise<AssistantModelMessage | null> {
863882
// get onCompletion middleware
864883
const onCompletionMiddleware = this.middleware.filter((f) => f.trigger === 'onCompletion');
@@ -895,11 +914,11 @@ export class MagmaAgent {
895914
},
896915
});
897916
// run the middleware on the text block
898-
const middlewareResult = (await mdlwr.action(
899-
textBlock.text,
917+
const middlewareResult = (await mdlwr.action(textBlock.text, {
918+
agent: this,
900919
send,
901-
this
902-
)) as string;
920+
ctx,
921+
})) as string;
903922
// if the middleware has a return value, we should update the text block in the result message
904923
if (middlewareResult !== undefined) {
905924
this.log(
@@ -979,11 +998,13 @@ export class MagmaAgent {
979998
trace,
980999
requestId,
9811000
send,
1001+
ctx,
9821002
}: {
9831003
message: AssistantModelMessage;
9841004
trace: TraceEvent[];
9851005
requestId: string;
9861006
send: MagmaSendFunction;
1007+
ctx: Record<string, any>;
9871008
}): Promise<AssistantModelMessage | null> {
9881009
// get onMainFinish middleware
9891010
const onMainFinishMiddleware = this.middleware.filter((f) => f.trigger === 'onMainFinish');
@@ -1020,11 +1041,11 @@ export class MagmaAgent {
10201041
},
10211042
});
10221043
// run the middleware on the text block
1023-
const middlewareResult = (await mdlwr.action(
1024-
textBlock.text,
1044+
const middlewareResult = (await mdlwr.action(textBlock.text, {
1045+
agent: this,
10251046
send,
1026-
this
1027-
)) as string;
1047+
ctx,
1048+
})) as string;
10281049
// if the middleware has a return value, we should update the text block in the result message
10291050
if (middlewareResult !== undefined) {
10301051
this.log(
@@ -1104,11 +1125,13 @@ export class MagmaAgent {
11041125
trace,
11051126
requestId,
11061127
send,
1128+
ctx,
11071129
}: {
11081130
message: AssistantModelMessage;
11091131
trace: TraceEvent[];
11101132
requestId: string;
11111133
send: MagmaSendFunction;
1134+
ctx: Record<string, any>;
11121135
}): Promise<AssistantModelMessage | null> {
11131136
// get preToolExecution middleware
11141137
const preToolExecutionMiddleware = this.middleware.filter(
@@ -1145,11 +1168,11 @@ export class MagmaAgent {
11451168
},
11461169
});
11471170
// run the middleware on the tool call
1148-
const middlewareResult = (await mdlwr.action(
1149-
toolCall,
1171+
const middlewareResult = (await mdlwr.action(toolCall, {
1172+
agent: this,
11501173
send,
1151-
this
1152-
)) as MagmaToolCall;
1174+
ctx,
1175+
})) as MagmaToolCall;
11531176
// if the middleware has a return value, we should update the tool call in the result message
11541177
if (middlewareResult !== undefined) {
11551178
this.log(
@@ -1229,11 +1252,13 @@ export class MagmaAgent {
12291252
trace,
12301253
requestId,
12311254
send,
1255+
ctx,
12321256
}: {
12331257
message: ToolModelMessage;
12341258
trace: TraceEvent[];
12351259
requestId: string;
12361260
send: MagmaSendFunction;
1261+
ctx: Record<string, any>;
12371262
}): Promise<ToolModelMessage> {
12381263
// get onToolExecution middleware
12391264
const onToolExecutionMiddleware = this.middleware.filter(
@@ -1266,11 +1291,11 @@ export class MagmaAgent {
12661291
},
12671292
});
12681293
// run the middleware on the tool result
1269-
const middlewareResult = (await mdlwr.action(
1270-
toolResult,
1294+
const middlewareResult = (await mdlwr.action(toolResult, {
1295+
agent: this,
12711296
send,
1272-
this
1273-
)) as MagmaToolResult;
1297+
ctx,
1298+
})) as MagmaToolResult;
12741299
// if the middleware has a return value, we should update the tool result in the result message
12751300
if (middlewareResult !== undefined) {
12761301
resultContent[i] = middlewareResult;
@@ -1400,7 +1425,7 @@ export class MagmaAgent {
14001425

14011426
/* EVENT HANDLERS */
14021427

1403-
getSystemPrompts(): MagmaSystemMessage[] {
1428+
getSystemPrompts(ctx: Record<string, any>): MagmaSystemMessage[] {
14041429
return [];
14051430
}
14061431

@@ -1409,11 +1434,6 @@ export class MagmaAgent {
14091434
throw error;
14101435
}
14111436

1412-
onStreamChunk(chunk: MagmaStreamChunk, send: MagmaSendFunction): Promise<void> | void {
1413-
chunk;
1414-
return;
1415-
}
1416-
14171437
onUsageUpdate(usage: MagmaUsage): Promise<void> | void {
14181438
usage;
14191439
return;

0 commit comments

Comments
 (0)