
Commit 6245aa7

Add tests

1 parent c9713ea commit 6245aa7

2 files changed: +134, -1 lines

src/core/sliding-window/__tests__/sliding-window.test.ts (new file)

Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@
// npx jest src/core/sliding-window/__tests__/sliding-window.test.ts

import { Anthropic } from "@anthropic-ai/sdk"

import { ModelInfo } from "../../../shared/api"
import { truncateConversation, truncateConversationIfNeeded } from "../index"

describe("truncateConversation", () => {
	it("should retain the first message", () => {
		const messages: Anthropic.Messages.MessageParam[] = [
			{ role: "user", content: "First message" },
			{ role: "assistant", content: "Second message" },
			{ role: "user", content: "Third message" },
		]

		const result = truncateConversation(messages, 0.5)

		// With 2 messages after the first, 0.5 fraction means remove 1 message
		// But 1 is odd, so it rounds down to 0 (to make it even)
		expect(result.length).toBe(3) // First message + 2 remaining messages
		expect(result[0]).toEqual(messages[0])
		expect(result[1]).toEqual(messages[1])
		expect(result[2]).toEqual(messages[2])
	})

	it("should remove the specified fraction of messages (rounded to even number)", () => {
		const messages: Anthropic.Messages.MessageParam[] = [
			{ role: "user", content: "First message" },
			{ role: "assistant", content: "Second message" },
			{ role: "user", content: "Third message" },
			{ role: "assistant", content: "Fourth message" },
			{ role: "user", content: "Fifth message" },
		]

		// 4 messages excluding first, 0.5 fraction = 2 messages to remove
		// 2 is already even, so no rounding needed
		const result = truncateConversation(messages, 0.5)

		expect(result.length).toBe(3)
		expect(result[0]).toEqual(messages[0])
		expect(result[1]).toEqual(messages[3])
		expect(result[2]).toEqual(messages[4])
	})

	it("should round to an even number of messages to remove", () => {
		const messages: Anthropic.Messages.MessageParam[] = [
			{ role: "user", content: "First message" },
			{ role: "assistant", content: "Second message" },
			{ role: "user", content: "Third message" },
			{ role: "assistant", content: "Fourth message" },
			{ role: "user", content: "Fifth message" },
			{ role: "assistant", content: "Sixth message" },
			{ role: "user", content: "Seventh message" },
		]

		// 6 messages excluding first, 0.3 fraction = 1.8 messages to remove
		// 1.8 rounds down to 1, then to 0 to make it even
		const result = truncateConversation(messages, 0.3)

		expect(result.length).toBe(7) // No messages removed
		expect(result).toEqual(messages)
	})

	it("should handle edge case with fracToRemove = 0", () => {
		const messages: Anthropic.Messages.MessageParam[] = [
			{ role: "user", content: "First message" },
			{ role: "assistant", content: "Second message" },
			{ role: "user", content: "Third message" },
		]

		const result = truncateConversation(messages, 0)

		expect(result).toEqual(messages)
	})

	it("should handle edge case with fracToRemove = 1", () => {
		const messages: Anthropic.Messages.MessageParam[] = [
			{ role: "user", content: "First message" },
			{ role: "assistant", content: "Second message" },
			{ role: "user", content: "Third message" },
			{ role: "assistant", content: "Fourth message" },
		]

		// 3 messages excluding first, 1.0 fraction = 3 messages to remove
		// But 3 is odd, so it rounds down to 2 to make it even
		const result = truncateConversation(messages, 1)

		expect(result.length).toBe(2)
		expect(result[0]).toEqual(messages[0])
		expect(result[1]).toEqual(messages[3])
	})
})

describe("truncateConversationIfNeeded", () => {
	const createModelInfo = (contextWindow: number, supportsPromptCache: boolean, maxTokens?: number): ModelInfo => ({
		contextWindow,
		supportsPromptCache,
		maxTokens,
	})

	const messages: Anthropic.Messages.MessageParam[] = [
		{ role: "user", content: "First message" },
		{ role: "assistant", content: "Second message" },
		{ role: "user", content: "Third message" },
		{ role: "assistant", content: "Fourth message" },
		{ role: "user", content: "Fifth message" },
	]

	it("should not truncate if tokens are below threshold for prompt caching models", () => {
		const modelInfo = createModelInfo(200000, true, 50000)
		const totalTokens = 100000 // Below threshold
		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
		expect(result).toEqual(messages)
	})

	it("should not truncate if tokens are below threshold for non-prompt caching models", () => {
		const modelInfo = createModelInfo(200000, false)
		const totalTokens = 100000 // Below threshold
		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
		expect(result).toEqual(messages)
	})

	it("should use 80% of context window as threshold if it's greater than (contextWindow - buffer)", () => {
		const modelInfo = createModelInfo(50000, true) // Small context window
		const totalTokens = 40001 // Above 80% threshold (40000)
		const mockResult = [messages[0], messages[3], messages[4]]
		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
		expect(result).toEqual(mockResult)
	})
})
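
This commit only adds the tests; truncateConversation itself is unchanged and not shown in the diff. The assertions above nonetheless pin down its contract: keep the first message, drop fracToRemove of the rest, and round the number removed down to an even count so user/assistant pairs stay aligned. A minimal sketch consistent with those assertions (a hypothetical reconstruction, not the committed implementation):

function truncateConversation(
	messages: Anthropic.Messages.MessageParam[],
	fracToRemove: number,
): Anthropic.Messages.MessageParam[] {
	// Hypothetical sketch inferred from the tests above; the real function
	// in src/core/sliding-window/index.ts may differ in details.
	const rawCount = Math.floor((messages.length - 1) * fracToRemove)
	// Round down to an even number so user/assistant pairing is preserved.
	const evenCount = rawCount - (rawCount % 2)
	// Always retain the first message, then resume after the removed block.
	return [messages[0], ...messages.slice(evenCount + 1)]
}

For example, with the five-message fixture and fracToRemove = 0.5, this returns [first, fourth, fifth], matching the second test above.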

src/core/sliding-window/index.ts

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,5 @@
 import { Anthropic } from "@anthropic-ai/sdk"
+
 import { ModelInfo } from "../../shared/api"

 /**
@@ -85,7 +86,9 @@ function getTruncFractionForPromptCachingModels(modelInfo: ModelInfo): number {
  * @returns {number} The maximum number of tokens allowed for non-prompt caching models.
  */
 function getMaxTokensForNonPromptCachingModels(modelInfo: ModelInfo): number {
-	return Math.max(modelInfo.contextWindow - 40_000, modelInfo.contextWindow * 0.8)
+	// The buffer needs to be at least as large as `modelInfo.maxTokens`.
+	const buffer = modelInfo.maxTokens ? Math.max(40_000, modelInfo.maxTokens) : 40_000
+	return Math.max(modelInfo.contextWindow - buffer, modelInfo.contextWindow * 0.8)
 }

 /**
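
The effect of the new buffer is easiest to see with numbers. Under the old code, a hypothetical model with a 1,000,000-token context window and maxTokens of 100,000 was allowed max(1,000,000 - 40,000, 800,000) = 960,000 tokens, leaving only 40,000 tokens of headroom for a 100,000-token response. With the change, buffer = max(40,000, 100,000) = 100,000, and the limit drops to max(900,000, 800,000) = 900,000, so a full-length response always fits. For small context windows the 80% term still dominates, as in the 50,000-token test above: 0.8 × 50,000 = 40,000, which is why 40,001 total tokens triggers truncation there.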
