Skip to content

Commit 596ddac

Browse files
committed
Fix tests
1 parent c1dcbc2 commit 596ddac

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/api/providers/__tests__/vscode-lm.test.ts

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ describe("VsCodeLmHandler", () => {
134134
const mockModel = { ...mockLanguageModelChat }
135135
;(vscode.lm.selectChatModels as jest.Mock).mockResolvedValueOnce([mockModel])
136136
mockLanguageModelChat.countTokens.mockResolvedValue(10)
137+
138+
// Override the default client with our test client
139+
handler["client"] = mockLanguageModelChat
137140
})
138141

139142
it("should stream text responses", async () => {
@@ -229,12 +232,7 @@ describe("VsCodeLmHandler", () => {
229232

230233
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("API Error"))
231234

232-
await expect(async () => {
233-
const stream = handler.createMessage(systemPrompt, messages)
234-
for await (const _ of stream) {
235-
// consume stream
236-
}
237-
}).rejects.toThrow("API Error")
235+
await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error")
238236
})
239237
})
240238

@@ -253,6 +251,8 @@ describe("VsCodeLmHandler", () => {
253251
})
254252

255253
it("should return fallback model info when no client exists", () => {
254+
// Clear the client first
255+
handler["client"] = null
256256
const model = handler.getModel()
257257
expect(model.id).toBe("test-vendor/test-family")
258258
expect(model.info).toBeDefined()
@@ -276,6 +276,10 @@ describe("VsCodeLmHandler", () => {
276276
})(),
277277
})
278278

279+
// Override the default client with our test client to ensure it uses
280+
// the mock implementation rather than the default fallback
281+
handler["client"] = mockLanguageModelChat
282+
279283
const result = await handler.completePrompt("Test prompt")
280284
expect(result).toBe(responseText)
281285
expect(mockLanguageModelChat.sendRequest).toHaveBeenCalled()
@@ -287,9 +291,11 @@ describe("VsCodeLmHandler", () => {
287291

288292
mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("Completion failed"))
289293

290-
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
291-
"VSCode LM completion error: Completion failed",
292-
)
294+
// Make sure we're using the mock client
295+
handler["client"] = mockLanguageModelChat
296+
297+
const promise = handler.completePrompt("Test prompt")
298+
await expect(promise).rejects.toThrow("VSCode LM completion error: Completion failed")
293299
})
294300
})
295301
})

0 commit comments

Comments
 (0)