|
3 | 3 | import { Anthropic } from "@anthropic-ai/sdk" |
4 | 4 | import type { ModelInfo } from "@roo-code/types" |
5 | 5 | import { TelemetryService } from "@roo-code/telemetry" |
| 6 | +import { vi } from "vitest" |
6 | 7 |
|
7 | 8 | import { BaseProvider } from "../../../api/providers/base-provider" |
8 | 9 | import { ApiMessage } from "../../task-persistence/apiMessages" |
9 | 10 | import { summarizeConversation, getMessagesSinceLastSummary, N_MESSAGES_TO_KEEP } from "../index" |
10 | 11 |
|
11 | 12 | // Create a mock ApiHandler for testing |
12 | 13 | class MockApiHandler extends BaseProvider { |
13 | | - createMessage(): any { |
| 14 | + createMessage(systemPrompt?: string, messages?: any[]): any { |
14 | 15 | // Mock implementation for testing - returns an async iterable stream |
15 | 16 | const mockStream = { |
16 | 17 | async *[Symbol.asyncIterator]() { |
@@ -176,7 +177,7 @@ describe("Condense", () => { |
176 | 177 | it("should handle empty summary from API gracefully", async () => { |
177 | 178 | // Mock handler that returns empty summary |
178 | 179 | class EmptyMockApiHandler extends MockApiHandler { |
179 | | - override createMessage(): any { |
| 180 | + override createMessage(systemPrompt?: string, messages?: any[]): any { |
180 | 181 | const mockStream = { |
181 | 182 | async *[Symbol.asyncIterator]() { |
182 | 183 | yield { type: "text", text: "" } |
@@ -204,6 +205,87 @@ describe("Condense", () => { |
204 | 205 | expect(result.messages).toEqual(messages) |
205 | 206 | expect(result.cost).toBeGreaterThan(0) |
206 | 207 | }) |
| 208 | + |
| 209 | + it("should include the initial ask in the summarization input", async () => { |
| 210 | + const initialAsk = "Please help me implement a new authentication system" |
| 211 | + const messages: ApiMessage[] = [ |
| 212 | + { role: "user", content: initialAsk }, |
| 213 | + { role: "assistant", content: "I'll help you implement an authentication system" }, |
| 214 | + { role: "user", content: "Let's start with JWT tokens" }, |
| 215 | + { role: "assistant", content: "Setting up JWT authentication" }, |
| 216 | + { role: "user", content: "Add refresh token support" }, |
| 217 | + { role: "assistant", content: "Adding refresh token logic" }, |
| 218 | + { role: "user", content: "Include rate limiting" }, |
| 219 | + { role: "assistant", content: "Implementing rate limiting" }, |
| 220 | + { role: "user", content: "Add tests" }, |
| 221 | + ] |
| 222 | + |
| 223 | + // Create a spy to capture what's sent to createMessage |
| 224 | + let capturedMessages: any[] = [] |
| 225 | + class SpyApiHandler extends MockApiHandler { |
| 226 | + override createMessage(systemPrompt?: string, messages?: any[]): any { |
| 227 | + capturedMessages = messages || [] |
| 228 | + return super.createMessage(systemPrompt, messages) |
| 229 | + } |
| 230 | + } |
| 231 | + |
| 232 | + const spyHandler = new SpyApiHandler() |
| 233 | + await summarizeConversation(messages, spyHandler, "System prompt", taskId, 5000, false) |
| 234 | + |
| 235 | + // Verify the initial ask is included in the messages sent for summarization |
| 236 | + expect(capturedMessages.length).toBeGreaterThan(0) |
| 237 | + |
| 238 | + // The first user message in the captured messages should be the initial ask |
| 239 | + const firstUserMessage = capturedMessages.find((msg) => msg.role === "user") |
| 240 | + expect(firstUserMessage).toBeDefined() |
| 241 | + expect(firstUserMessage.content).toBe(initialAsk) |
| 242 | + |
| 243 | + // Verify all messages except the last N are included |
| 244 | + const expectedMessagesToSummarize = messages.slice(0, -N_MESSAGES_TO_KEEP) |
| 245 | + // The last message in capturedMessages is the summarization request, so we exclude it |
| 246 | + const actualSummarizedMessages = capturedMessages.slice(0, -1) |
| 247 | + |
| 248 | + // Check that we have the right number of messages |
| 249 | + expect(actualSummarizedMessages.length).toBe(expectedMessagesToSummarize.length) |
| 250 | + |
| 251 | + // Verify the content matches |
| 252 | + for (let i = 0; i < expectedMessagesToSummarize.length; i++) { |
| 253 | + expect(actualSummarizedMessages[i].role).toBe(expectedMessagesToSummarize[i].role) |
| 254 | + expect(actualSummarizedMessages[i].content).toBe(expectedMessagesToSummarize[i].content) |
| 255 | + } |
| 256 | + }) |
| 257 | + |
| 258 | + it("should include initial ask with slash command in summarization", async () => { |
| 259 | + const slashCommand = "/prr #456 - Implement feature X" |
| 260 | + const messages: ApiMessage[] = [ |
| 261 | + { role: "user", content: slashCommand }, |
| 262 | + { role: "assistant", content: "Working on PR #456" }, |
| 263 | + { role: "user", content: "Add error handling" }, |
| 264 | + { role: "assistant", content: "Adding error handling" }, |
| 265 | + { role: "user", content: "Include logging" }, |
| 266 | + { role: "assistant", content: "Adding logging" }, |
| 267 | + { role: "user", content: "Write documentation" }, |
| 268 | + { role: "assistant", content: "Writing docs" }, |
| 269 | + { role: "user", content: "Final review" }, |
| 270 | + ] |
| 271 | + |
| 272 | + // Spy on the API handler to verify what's being sent |
| 273 | + let capturedMessages: any[] = [] |
| 274 | + class SpyApiHandler extends MockApiHandler { |
| 275 | + override createMessage(systemPrompt?: string, messages?: any[]): any { |
| 276 | + capturedMessages = messages || [] |
| 277 | + return super.createMessage(systemPrompt, messages) |
| 278 | + } |
| 279 | + } |
| 280 | + |
| 281 | + const spyHandler = new SpyApiHandler() |
| 282 | + await summarizeConversation(messages, spyHandler, "System prompt", taskId, 5000, false) |
| 283 | + |
| 284 | + // Verify the slash command is in the summarization input |
| 285 | + const firstMessage = capturedMessages[0] |
| 286 | + expect(firstMessage.role).toBe("user") |
| 287 | + expect(firstMessage.content).toBe(slashCommand) |
| 288 | + }) |
207 | 289 | }) |
208 | 290 |
|
209 | 291 | describe("getMessagesSinceLastSummary", () => { |
|
0 commit comments