Skip to content

Commit ae9a2f1

Browse files
committed
code review
1 parent 4f19605 commit ae9a2f1

File tree

3 files changed

+109
-31
lines changed

3 files changed

+109
-31
lines changed

src/core/webview/__tests__/messageEnhancer.test.ts

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,23 +259,46 @@ describe("MessageEnhancer", () => {
259259
describe("captureTelemetry", () => {
260260
it("should capture telemetry when TelemetryService is available", () => {
261261
const mockTaskId = "task-123"
262-
MessageEnhancer.captureTelemetry(mockTaskId)
262+
const mockCaptureEvent = vi.fn()
263+
vi.mocked(TelemetryService.instance).captureEvent = mockCaptureEvent
264+
265+
MessageEnhancer.captureTelemetry(mockTaskId, true)
263266

264267
expect(TelemetryService.hasInstance).toHaveBeenCalled()
265-
expect(TelemetryService.instance.capturePromptEnhanced).toHaveBeenCalledWith(mockTaskId)
268+
expect(mockCaptureEvent).toHaveBeenCalledWith(expect.any(String), {
269+
taskId: mockTaskId,
270+
includeTaskHistory: true,
271+
})
266272
})
267273

268274
it("should handle missing TelemetryService gracefully", () => {
269275
vi.mocked(TelemetryService).hasInstance = vi.fn().mockReturnValue(false)
270276

271277
// Should not throw
272-
expect(() => MessageEnhancer.captureTelemetry("task-123")).not.toThrow()
278+
expect(() => MessageEnhancer.captureTelemetry("task-123", true)).not.toThrow()
273279
})
274280

275281
it("should work without task ID", () => {
276-
MessageEnhancer.captureTelemetry()
282+
const mockCaptureEvent = vi.fn()
283+
vi.mocked(TelemetryService.instance).captureEvent = mockCaptureEvent
277284

278-
expect(TelemetryService.instance.capturePromptEnhanced).toHaveBeenCalledWith(undefined)
285+
MessageEnhancer.captureTelemetry(undefined, false)
286+
287+
expect(mockCaptureEvent).toHaveBeenCalledWith(expect.any(String), {
288+
includeTaskHistory: false,
289+
})
290+
})
291+
292+
it("should default includeTaskHistory to false when not provided", () => {
293+
const mockCaptureEvent = vi.fn()
294+
vi.mocked(TelemetryService.instance).captureEvent = mockCaptureEvent
295+
296+
MessageEnhancer.captureTelemetry("task-123")
297+
298+
expect(mockCaptureEvent).toHaveBeenCalledWith(expect.any(String), {
299+
taskId: "task-123",
300+
includeTaskHistory: false,
301+
})
279302
})
280303
})
281304

@@ -299,5 +322,44 @@ describe("MessageEnhancer", () => {
299322
expect(history).not.toContain("Tool use")
300323
expect(history.split("\n").length).toBe(3) // Only 3 valid messages
301324
})
325+
326+
it("should handle malformed messages gracefully", () => {
327+
const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {})
328+
329+
// Create messages that will cause errors when accessed
330+
const malformedMessages = [
331+
null,
332+
undefined,
333+
{ type: "ask" }, // Missing required properties
334+
"not an object",
335+
] as any
336+
337+
// Access private method through any type assertion for testing
338+
const history = (MessageEnhancer as any).extractTaskHistory(malformedMessages)
339+
340+
// Should return empty string and log error
341+
expect(history).toBe("")
342+
expect(consoleSpy).toHaveBeenCalledWith("Failed to extract task history:", expect.any(Error))
343+
344+
consoleSpy.mockRestore()
345+
})
346+
347+
it("should handle messages with circular references", () => {
348+
const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {})
349+
350+
// Create a message with circular reference
351+
const circularMessage: any = { type: "ask", text: "Test" }
352+
circularMessage.self = circularMessage
353+
354+
const messages = [circularMessage] as ClineMessage[]
355+
356+
// Access private method through any type assertion for testing
357+
const history = (MessageEnhancer as any).extractTaskHistory(messages)
358+
359+
// Should handle gracefully
360+
expect(history).toBe("User: Test")
361+
362+
consoleSpy.mockRestore()
363+
})
302364
})
303365
})

src/core/webview/messageEnhancer.ts

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { ProviderSettings, ClineMessage, GlobalState } from "@roo-code/types"
1+
import { ProviderSettings, ClineMessage, GlobalState, TelemetryEventName } from "@roo-code/types"
22
import { TelemetryService } from "@roo-code/telemetry"
33
import { supportPrompt } from "../../shared/support-prompt"
44
import { singleCompletionHandler } from "../../utils/single-completion-handler"
@@ -97,36 +97,47 @@ export class MessageEnhancer {
9797
* @returns Formatted task history string
9898
*/
9999
private static extractTaskHistory(messages: ClineMessage[]): string {
100-
const relevantMessages = messages
101-
.filter((msg) => {
102-
// Include user messages (type: "ask" with text) and assistant messages (type: "say" with say: "text")
103-
if (msg.type === "ask" && msg.text) {
104-
return true
105-
}
106-
if (msg.type === "say" && msg.say === "text" && msg.text) {
107-
return true
108-
}
109-
return false
110-
})
111-
.slice(-10) // Limit to last 10 messages to avoid context explosion
100+
try {
101+
const relevantMessages = messages
102+
.filter((msg) => {
103+
// Include user messages (type: "ask" with text) and assistant messages (type: "say" with say: "text")
104+
if (msg.type === "ask" && msg.text) {
105+
return true
106+
}
107+
if (msg.type === "say" && msg.say === "text" && msg.text) {
108+
return true
109+
}
110+
return false
111+
})
112+
.slice(-10) // Limit to last 10 messages to avoid context explosion
112113

113-
return relevantMessages
114-
.map((msg) => {
115-
const role = msg.type === "ask" ? "User" : "Assistant"
116-
const content = msg.text || ""
117-
// Truncate long messages
118-
return `${role}: ${content.slice(0, 500)}${content.length > 500 ? "..." : ""}`
119-
})
120-
.join("\n")
114+
return relevantMessages
115+
.map((msg) => {
116+
const role = msg.type === "ask" ? "User" : "Assistant"
117+
const content = msg.text || ""
118+
// Truncate long messages
119+
return `${role}: ${content.slice(0, 500)}${content.length > 500 ? "..." : ""}`
120+
})
121+
.join("\n")
122+
} catch (error) {
123+
// Log error but don't fail the enhancement
124+
console.error("Failed to extract task history:", error)
125+
return ""
126+
}
121127
}
122128

123129
/**
124130
* Captures telemetry for prompt enhancement
125131
* @param taskId Optional task ID for telemetry tracking
132+
* @param includeTaskHistory Whether task history was included in the enhancement
126133
*/
127-
static captureTelemetry(taskId?: string): void {
134+
static captureTelemetry(taskId?: string, includeTaskHistory?: boolean): void {
128135
if (TelemetryService.hasInstance()) {
129-
TelemetryService.instance.capturePromptEnhanced(taskId)
136+
// Use captureEvent directly to include the includeTaskHistory property
137+
TelemetryService.instance.captureEvent(TelemetryEventName.PROMPT_ENHANCED, {
138+
...(taskId && { taskId }),
139+
includeTaskHistory: includeTaskHistory ?? false,
140+
})
130141
}
131142
}
132143
}

src/core/webview/webviewMessageHandler.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,8 +1340,13 @@ export const webviewMessageHandler = async (
13401340
if (message.text) {
13411341
try {
13421342
const state = await provider.getState()
1343-
const { apiConfiguration, customSupportPrompts, listApiConfigMeta, enhancementApiConfigId } = state
1344-
const includeTaskHistoryInEnhance = (state as any).includeTaskHistoryInEnhance
1343+
const {
1344+
apiConfiguration,
1345+
customSupportPrompts,
1346+
listApiConfigMeta,
1347+
enhancementApiConfigId,
1348+
includeTaskHistoryInEnhance,
1349+
} = state
13451350

13461351
const currentCline = provider.getCurrentCline()
13471352
const result = await MessageEnhancer.enhanceMessage({
@@ -1357,7 +1362,7 @@ export const webviewMessageHandler = async (
13571362

13581363
if (result.success && result.enhancedText) {
13591364
// Capture telemetry for prompt enhancement
1360-
MessageEnhancer.captureTelemetry(currentCline?.taskId)
1365+
MessageEnhancer.captureTelemetry(currentCline?.taskId, includeTaskHistoryInEnhance)
13611366
await provider.postMessageToWebview({ type: "enhancedPrompt", text: result.enhancedText })
13621367
} else {
13631368
throw new Error(result.error || "Unknown error")

0 commit comments

Comments
 (0)