diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index 2ad6c87ddd..52a468d3ab 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -167,6 +167,7 @@ const lmStudioSchema = baseProviderSettingsSchema.extend({
 	lmStudioBaseUrl: z.string().optional(),
 	lmStudioDraftModelId: z.string().optional(),
 	lmStudioSpeculativeDecodingEnabled: z.boolean().optional(),
+	lmStudioTimeoutSeconds: z.number().min(30).max(3600).optional(),
 })
 
 const geminiSchema = apiModelIdProviderModelSchema.extend({
diff --git a/src/api/providers/__tests__/lmstudio.spec.ts b/src/api/providers/__tests__/lmstudio.spec.ts
index 0adebdeea7..57c1c00eca 100644
--- a/src/api/providers/__tests__/lmstudio.spec.ts
+++ b/src/api/providers/__tests__/lmstudio.spec.ts
@@ -7,6 +7,13 @@ vi.mock("openai", () => {
 		chat: {
 			completions: {
 				create: mockCreate.mockImplementation(async (options) => {
+					// Check if signal is aborted (for timeout tests)
+					if (options.signal?.aborted) {
+						const error = new Error("Request was aborted")
+						error.name = "AbortError"
+						throw error
+					}
+
 					if (!options.stream) {
 						return {
 							id: "test-completion",
@@ -27,6 +34,13 @@
 
 					return {
 						[Symbol.asyncIterator]: async function* () {
+							// Check if signal is aborted during streaming
+							if (options.signal?.aborted) {
+								const error = new Error("Request was aborted")
+								error.name = "AbortError"
+								throw error
+							}
+
 							yield {
 								choices: [
 									{
@@ -131,12 +145,17 @@
 	it("should complete prompt successfully", async () => {
 		const result = await handler.completePrompt("Test prompt")
 		expect(result).toBe("Test response")
-		expect(mockCreate).toHaveBeenCalledWith({
-			model: mockOptions.lmStudioModelId,
-			messages: [{ role: "user", content: "Test prompt" }],
-			temperature: 0,
-			stream: false,
-		})
+		expect(mockCreate).toHaveBeenCalledWith(
+			{
+				model: mockOptions.lmStudioModelId,
+				messages: [{ role: "user", content: "Test prompt" }],
+				temperature: 0,
+				stream: false,
+			},
+			expect.objectContaining({
+				signal: expect.any(AbortSignal),
+			}),
+		)
 	})
 
 	it("should handle API errors", async () => {
@@ -164,4 +183,83 @@
 		expect(modelInfo.info.contextWindow).toBe(128_000)
 	})
 })
+
+	describe("timeout functionality", () => {
+		it("should use default timeout of 600 seconds when not configured", () => {
+			const handlerWithoutTimeout = new LmStudioHandler({
+				apiModelId: "local-model",
+				lmStudioModelId: "local-model",
+				lmStudioBaseUrl: "http://localhost:1234",
+			})
+
+			// Verify that the handler was created successfully
+			expect(handlerWithoutTimeout).toBeInstanceOf(LmStudioHandler)
+		})
+
+		it("should use custom timeout when configured", () => {
+			const customTimeoutHandler = new LmStudioHandler({
+				apiModelId: "local-model",
+				lmStudioModelId: "local-model",
+				lmStudioBaseUrl: "http://localhost:1234",
+				lmStudioTimeoutSeconds: 120, // 2 minutes
+			})
+
+			// Verify that the handler was created successfully with custom timeout
+			expect(customTimeoutHandler).toBeInstanceOf(LmStudioHandler)
+		})
+
+		it("should handle AbortError and convert to timeout message", async () => {
+			// Mock an AbortError
+			const abortError = new Error("Request was aborted")
+			abortError.name = "AbortError"
+			mockCreate.mockRejectedValueOnce(abortError)
+
+			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
+				"LM Studio request timed out after 600 seconds",
+			)
+		})
+
+		it("should pass AbortSignal to OpenAI client", async () => {
+			const result = await handler.completePrompt("Test prompt")
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "local-model",
+					messages: [{ role: "user", content: "Test prompt" }],
+					temperature: 0,
+					stream: false,
+				}),
+				expect.objectContaining({
+					signal: expect.any(AbortSignal),
+				}),
+			)
+
+			expect(result).toBe("Test response")
+		})
+
+		it("should pass AbortSignal to streaming requests", async () => {
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello!" }]
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "local-model",
+					messages: expect.any(Array),
+					temperature: 0,
+					stream: true,
+				}),
+				expect.objectContaining({
+					signal: expect.any(AbortSignal),
+				}),
+			)
+
+			expect(chunks.length).toBeGreaterThan(0)
+		})
+	})
 })
diff --git a/src/api/providers/lm-studio.ts b/src/api/providers/lm-studio.ts
index 6c49920bd1..197360e0ef 100644
--- a/src/api/providers/lm-studio.ts
+++ b/src/api/providers/lm-studio.ts
@@ -15,6 +15,9 @@ import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import { getModels, getModelsFromCache } from "./fetchers/modelCache"
 
+// Default timeout for LM Studio requests (10 minutes)
+const LMSTUDIO_DEFAULT_TIMEOUT_SECONDS = 600
+
 export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	private client: OpenAI
@@ -73,7 +76,19 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 
 		let assistantText = ""
 
+		// Create AbortController with configurable timeout
+		const controller = new AbortController()
+		let timeoutId: NodeJS.Timeout | undefined
+
+		// Get timeout from settings or use default (10 minutes)
+		const timeoutSeconds = this.options.lmStudioTimeoutSeconds ?? LMSTUDIO_DEFAULT_TIMEOUT_SECONDS
+		const timeoutMs = timeoutSeconds * 1000
+
 		try {
+			timeoutId = setTimeout(() => {
+				controller.abort()
+			}, timeoutMs)
+
 			const params: OpenAI.Chat.ChatCompletionCreateParamsStreaming & { draft_model?: string } = {
 				model: this.getModel().id,
 				messages: openAiMessages,
@@ -85,7 +100,9 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 				params.draft_model = this.options.lmStudioDraftModelId
 			}
 
-			const results = await this.client.chat.completions.create(params)
+			const results = await this.client.chat.completions.create(params, {
+				signal: controller.signal,
+			})
 
 			const matcher = new XmlMatcher(
 				"think",
@@ -124,7 +141,20 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 				inputTokens,
 				outputTokens,
 			} as const
-		} catch (error) {
+
+			// Clear timeout after successful completion
+			clearTimeout(timeoutId)
+		} catch (error: unknown) {
+			// Clear timeout on error
+			clearTimeout(timeoutId)
+
+			// Check if this is an abort error (timeout)
+			if (error instanceof Error && error.name === "AbortError") {
+				throw new Error(
+					`LM Studio request timed out after ${timeoutSeconds} seconds. This can happen with large models that need more processing time. Try increasing the timeout in LM Studio settings or use a smaller model.`,
+				)
+			}
+
 			throw new Error(
 				"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Roo Code's prompts.",
 			)
@@ -147,7 +177,19 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
+		// Create AbortController with configurable timeout
+		const controller = new AbortController()
+		let timeoutId: NodeJS.Timeout | undefined
+
+		// Get timeout from settings or use default (10 minutes)
+		const timeoutSeconds = this.options.lmStudioTimeoutSeconds ?? LMSTUDIO_DEFAULT_TIMEOUT_SECONDS
+		const timeoutMs = timeoutSeconds * 1000
+
 		try {
+			timeoutId = setTimeout(() => {
+				controller.abort()
+			}, timeoutMs)
+
 			// Create params object with optional draft model
 			const params: any = {
 				model: this.getModel().id,
@@ -161,9 +203,25 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 				params.draft_model = this.options.lmStudioDraftModelId
 			}
 
-			const response = await this.client.chat.completions.create(params)
+			const response = await this.client.chat.completions.create(params, {
+				signal: controller.signal,
+			})
+
+			// Clear timeout after successful completion
+			clearTimeout(timeoutId)
+
 			return response.choices[0]?.message.content || ""
-		} catch (error) {
+		} catch (error: unknown) {
+			// Clear timeout on error
+			clearTimeout(timeoutId)
+
+			// Check if this is an abort error (timeout)
+			if (error instanceof Error && error.name === "AbortError") {
+				throw new Error(
+					`LM Studio request timed out after ${timeoutSeconds} seconds. This can happen with large models that need more processing time. Try increasing the timeout in LM Studio settings or use a smaller model.`,
+				)
+			}
+
 			throw new Error(
 				"Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Roo Code's prompts.",
 			)