Skip to content

Commit 07b9121

Browse files
fix(summarization): resolve fraction trigger bug by using model profile for maxInputTokens (#189)
* fix(summarization): resolve fraction trigger bug by using model profile for maxInputTokens. The fraction-based trigger in createSummarizationMiddleware was always ineffective because maxInputTokens was never passed to the internal functions that needed it. Fixes #185. * fix type
1 parent d7cc7f4 commit 07b9121

File tree

2 files changed

+206
-28
lines changed

2 files changed

+206
-28
lines changed

libs/deepagents/src/middleware/summarization.test.ts

Lines changed: 128 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,21 @@ import type {
88
} from "../backends/protocol.js";
99
import { createMockBackend } from "./test.js";
1010

11-
// Mock the OpenAI module with a class constructor
12-
vi.mock("@langchain/openai", () => {
11+
// Mock the initChatModel function from langchain/chat_models/universal
12+
vi.mock("langchain/chat_models/universal", () => {
1313
return {
14-
ChatOpenAI: class MockChatOpenAI {
15-
constructor(_config: any) {}
16-
async invoke(_messages: any) {
17-
return {
18-
content: "This is a summary of the conversation.",
19-
};
20-
}
14+
initChatModel: async (_modelName: string) => {
15+
return {
16+
async invoke(_messages: any) {
17+
return {
18+
content: "This is a summary of the conversation.",
19+
};
20+
},
21+
// Mock profile with maxInputTokens for testing
22+
profile: {
23+
maxInputTokens: 128000,
24+
},
25+
};
2126
},
2227
};
2328
});
@@ -130,6 +135,120 @@ describe("createSummarizationMiddleware", () => {
130135
});
131136
});
132137

138+
describe("fraction trigger", () => {
139+
it("should trigger summarization when token count exceeds fraction of maxInputTokens", async () => {
140+
const mockBackend = createMockBackend();
141+
142+
// Create a mock model with profile containing low maxInputTokens
143+
const mockModelWithProfile = {
144+
profile: {
145+
maxInputTokens: 200, // Low threshold for testing (100 tokens = 50%)
146+
},
147+
async invoke(_messages: any) {
148+
return {
149+
content: "This is a summary of the conversation.",
150+
};
151+
},
152+
};
153+
154+
const middleware = createSummarizationMiddleware({
155+
model: mockModelWithProfile as any,
156+
backend: mockBackend,
157+
trigger: { type: "fraction", value: 0.5 }, // 50% of maxInputTokens
158+
keep: { type: "messages", value: 2 },
159+
});
160+
161+
// Create messages with enough content to exceed 100 tokens (50% of 200)
162+
const messages = Array.from(
163+
{ length: 10 },
164+
(_, i) =>
165+
new HumanMessage({
166+
content: `Message ${i} with some extra content to increase token count`,
167+
}),
168+
);
169+
170+
// @ts-expect-error - typing issue
171+
const result = await middleware.beforeModel?.({ messages });
172+
173+
expect(result).toBeDefined();
174+
expect(result?.messages).toBeDefined();
175+
// Should have summary message + 2 preserved messages
176+
expect(result?.messages.length).toBe(3);
177+
});
178+
179+
it("should not trigger fraction-based summarization when model has no profile", async () => {
180+
const mockBackend = createMockBackend();
181+
182+
// Create a mock model WITHOUT a profile (no maxInputTokens)
183+
const mockModelWithoutProfile = {
184+
async invoke(_messages: any) {
185+
return {
186+
content: "This is a summary of the conversation.",
187+
};
188+
},
189+
// No profile property
190+
};
191+
192+
const middleware = createSummarizationMiddleware({
193+
model: mockModelWithoutProfile as any,
194+
backend: mockBackend,
195+
trigger: { type: "fraction", value: 0.5 },
196+
keep: { type: "messages", value: 2 },
197+
// maxInputTokens is NOT provided and model has no profile
198+
});
199+
200+
// Create messages with content
201+
const messages = Array.from(
202+
{ length: 10 },
203+
(_, i) =>
204+
new HumanMessage({
205+
content: `Message ${i} with some extra content`,
206+
}),
207+
);
208+
209+
// @ts-expect-error - typing issue
210+
const result = await middleware.beforeModel?.({ messages });
211+
212+
// Without maxInputTokens (no explicit option and no model profile), fraction trigger should not fire
213+
expect(result).toBeUndefined();
214+
});
215+
216+
it("should not trigger when token count is below fraction threshold", async () => {
217+
const mockBackend = createMockBackend();
218+
219+
// Create a mock model with high maxInputTokens in profile
220+
const mockModelWithHighLimit = {
221+
profile: {
222+
maxInputTokens: 100000, // Very high threshold
223+
},
224+
async invoke(_messages: any) {
225+
return {
226+
content: "This is a summary of the conversation.",
227+
};
228+
},
229+
};
230+
231+
const middleware = createSummarizationMiddleware({
232+
model: mockModelWithHighLimit as any,
233+
backend: mockBackend,
234+
trigger: { type: "fraction", value: 0.9 }, // 90% of maxInputTokens
235+
keep: { type: "messages", value: 2 },
236+
});
237+
238+
// Create just a few short messages
239+
const messages = [
240+
new HumanMessage({ content: "Hello" }),
241+
new AIMessage({ content: "Hi" }),
242+
];
243+
244+
// @ts-expect-error - typing issue
245+
const result = await middleware.beforeModel?.({ messages });
246+
247+
// Token count is far below 90% of 100000, so should not trigger
248+
expect(result).toBeUndefined();
249+
});
250+
});
251+
133252
describe("keep policy", () => {
134253
it("should preserve specified number of recent messages", async () => {
135254
const mockBackend = createMockBackend();

libs/deepagents/src/middleware/summarization.ts

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ import {
5353
} from "langchain";
5454
import { getBufferString } from "@langchain/core/messages";
5555
import type { BaseChatModel } from "@langchain/core/language_models/chat_models";
56-
import { ChatOpenAI } from "@langchain/openai";
56+
import { initChatModel } from "langchain/chat_models/universal";
5757

5858
import type { BackendProtocol, BackendFactory } from "../backends/protocol.js";
5959
import type { StateBackend } from "../backends/state.js";
@@ -259,14 +259,44 @@ export function createSummarizationMiddleware(
259259
return `${historyPathPrefix}/${id}.md`;
260260
}
261261

262+
/**
263+
* Cached resolved model to avoid repeated initChatModel calls
264+
*/
265+
let cachedModel: BaseChatModel | undefined = undefined;
266+
262267
/**
263268
* Resolve the chat model.
269+
* Uses initChatModel to support any model provider from a string name.
270+
* The resolved model is cached for subsequent calls.
264271
*/
265-
function getChatModel(): BaseChatModel {
272+
async function getChatModel(): Promise<BaseChatModel> {
273+
if (cachedModel) {
274+
return cachedModel;
275+
}
276+
266277
if (typeof model === "string") {
267-
return new ChatOpenAI({ modelName: model });
278+
cachedModel = await initChatModel(model);
279+
} else {
280+
cachedModel = model;
281+
}
282+
return cachedModel;
283+
}
284+
285+
/**
286+
* Get the max input tokens from the resolved model's profile.
287+
* Similar to Python's _get_profile_limits.
288+
*/
289+
function getMaxInputTokens(resolvedModel: BaseChatModel): number | undefined {
290+
const profile = resolvedModel.profile;
291+
if (
292+
profile &&
293+
typeof profile === "object" &&
294+
"maxInputTokens" in profile &&
295+
typeof profile.maxInputTokens === "number"
296+
) {
297+
return profile.maxInputTokens;
268298
}
269-
return model;
299+
return undefined;
270300
}
271301

272302
/**
@@ -533,9 +563,10 @@ export function createSummarizationMiddleware(
533563
/**
534564
* Create summary of messages.
535565
*/
536-
async function createSummary(messages: BaseMessage[]): Promise<string> {
537-
const chatModel = getChatModel();
538-
566+
async function createSummary(
567+
messages: BaseMessage[],
568+
chatModel: BaseChatModel,
569+
): Promise<string> {
539570
// Trim messages if too long
540571
let messagesToSummarize = messages;
541572
const tokens = countTokensApproximately(messages);
@@ -605,29 +636,49 @@ ${summary}
605636
return undefined;
606637
}
607638

608-
// Step 1: Truncate args if configured
639+
/**
640+
* Resolve the chat model and get max input tokens from profile
641+
*/
642+
const resolvedModel = await getChatModel();
643+
const maxInputTokens = getMaxInputTokens(resolvedModel);
644+
645+
/**
646+
* Step 1: Truncate args if configured
647+
*/
609648
const { messages: truncatedMessages, modified: argsWereTruncated } =
610-
truncateArgs(messages);
649+
truncateArgs(messages, maxInputTokens);
611650

612-
// Step 2: Check if summarization should happen
651+
/**
652+
* Step 2: Check if summarization should happen
653+
*/
613654
const totalTokens = countTokensApproximately(truncatedMessages);
614655
const shouldDoSummarization = shouldSummarize(
615656
truncatedMessages,
616657
totalTokens,
658+
maxInputTokens,
617659
);
618660

619-
// If only truncation happened (no summarization)
661+
/**
662+
* If only truncation happened (no summarization)
663+
*/
620664
if (argsWereTruncated && !shouldDoSummarization) {
621665
return { messages: truncatedMessages };
622666
}
623667

624-
// If no truncation and no summarization
668+
/**
669+
* If no truncation and no summarization
670+
*/
625671
if (!shouldDoSummarization) {
626672
return undefined;
627673
}
628674

629-
// Step 3: Perform summarization
630-
const cutoffIndex = determineCutoffIndex(truncatedMessages);
675+
/**
676+
* Step 3: Perform summarization
677+
*/
678+
const cutoffIndex = determineCutoffIndex(
679+
truncatedMessages,
680+
maxInputTokens,
681+
);
631682
if (cutoffIndex <= 0) {
632683
if (argsWereTruncated) {
633684
return { messages: truncatedMessages };
@@ -638,7 +689,9 @@ ${summary}
638689
const messagesToSummarize = truncatedMessages.slice(0, cutoffIndex);
639690
const preservedMessages = truncatedMessages.slice(cutoffIndex);
640691

641-
// Offload to backend first
692+
/**
693+
* Offload to backend first
694+
*/
642695
const resolvedBackend = getBackend(state);
643696
const filePath = await offloadToBackend(
644697
resolvedBackend,
@@ -647,14 +700,20 @@ ${summary}
647700
);
648701

649702
if (filePath === null) {
650-
// Offloading failed - don't proceed with summarization
703+
/**
704+
* Offloading failed - don't proceed with summarization
705+
*/
651706
return undefined;
652707
}
653708

654-
// Generate summary
655-
const summary = await createSummary(messagesToSummarize);
709+
/**
710+
* Generate summary
711+
*/
712+
const summary = await createSummary(messagesToSummarize, resolvedModel);
656713

657-
// Build summary message
714+
/**
715+
* Build summary message
716+
*/
658717
const summaryMessage = buildSummaryMessage(summary, filePath);
659718

660719
return {

0 commit comments

Comments (0)