diff --git a/src/api/providers/fetchers/__tests__/lmstudio.test.ts b/src/api/providers/fetchers/__tests__/lmstudio.test.ts
index 98fe5db32e..ff9a109e50 100644
--- a/src/api/providers/fetchers/__tests__/lmstudio.test.ts
+++ b/src/api/providers/fetchers/__tests__/lmstudio.test.ts
@@ -118,7 +118,7 @@ describe("LMStudio Fetcher", () => {
 			expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
 			expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
 			expect(mockListDownloadedModels).toHaveBeenCalledWith("llm")
-			expect(mockListLoaded).not.toHaveBeenCalled()
+			expect(mockListLoaded).toHaveBeenCalled() // we now call it to get context data

 			const expectedParsedModel = parseLMStudioModel(mockLLMInfo)
 			expect(result).toEqual({ [mockLLMInfo.path]: expectedParsedModel })
diff --git a/src/api/providers/fetchers/lmstudio.ts b/src/api/providers/fetchers/lmstudio.ts
index 4b7ece71ea..976822c67d 100644
--- a/src/api/providers/fetchers/lmstudio.ts
+++ b/src/api/providers/fetchers/lmstudio.ts
@@ -1,6 +1,38 @@
 import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types"
 import { LLM, LLMInfo, LLMInstanceInfo, LMStudioClient } from "@lmstudio/sdk"
 import axios from "axios"
+import { flushModels, getModels } from "./modelCache"
+
+const modelsWithLoadedDetails = new Set<string>()
+
+export const hasLoadedFullDetails = (modelId: string): boolean => {
+	return modelsWithLoadedDetails.has(modelId)
+}
+
+export const forceFullModelDetailsLoad = async (baseUrl: string, modelId: string): Promise<void> => {
+	try {
+		// test the connection to LM Studio first
+		// errors will be caught further down
+		await axios.get(`${baseUrl}/v1/models`)
+		const lmsUrl = baseUrl.replace(/^http:\/\//, "ws://").replace(/^https:\/\//, "wss://")
+
+		const client = new LMStudioClient({ baseUrl: lmsUrl })
+		await client.llm.model(modelId)
+		await flushModels("lmstudio")
+		await getModels({ provider: "lmstudio" }) // force cache update now
+
+		// Mark this model as having full details loaded
+		modelsWithLoadedDetails.add(modelId)
+	} catch (error) {
+		if (error.code === "ECONNREFUSED") {
+			console.warn(`Error connecting to LMStudio at ${baseUrl}`)
+		} else {
+			console.error(
+				`Error refreshing LMStudio model details: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`,
+			)
+		}
+	}
+}

 export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelInfo => {
 	// Handle both LLMInstanceInfo (from loaded models) and LLMInfo (from downloaded models)
@@ -19,6 +51,8 @@ export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelIn
 }

 export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Promise<Record<string, ModelInfo>> {
+	// clear the set of models that have full details loaded
+	modelsWithLoadedDetails.clear()
 	// clearing the input can leave an empty string; use the default in that case
 	baseUrl = baseUrl === "" ? "http://localhost:1234" : baseUrl
@@ -46,15 +80,15 @@ export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Prom
 		}
 	} catch (error) {
 		console.warn("Failed to list downloaded models, falling back to loaded models only")
+	}

+	// We want to list loaded models *anyway* since they provide valuable extra info (context size)
+	const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
+		return Promise.all(models.map((m) => m.getModelInfo()))
+	})) as Array<LLMInstanceInfo>
-		// Fall back to listing only loaded models
-		const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
-			return Promise.all(models.map((m) => m.getModelInfo()))
-		})) as Array<LLMInstanceInfo>
-
-		for (const lmstudioModel of loadedModels) {
-			models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
-		}
+	for (const lmstudioModel of loadedModels) {
+		models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
+		modelsWithLoadedDetails.add(lmstudioModel.modelKey)
 	}
 	} catch (error) {
 		if (error.code === "ECONNREFUSED") {
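Reviewer note: the fetcher no longer treats loaded models as a mere fallback; it always lists them and layers their entries over the downloaded ones, since only loaded instances expose runtime data such as context length. A minimal sketch of that merge order, using a hypothetical stand-in for ModelInfo:

```ts
// Sketch only: a loaded-model entry replaces the downloaded entry for the
// same key, so runtime data such as contextWindow wins once a model is loaded.
type InfoSketch = { contextWindow?: number; description?: string }

function mergeModels(
	downloaded: Record<string, InfoSketch>,
	loaded: Record<string, InfoSketch>,
): Record<string, InfoSketch> {
	return { ...downloaded, ...loaded }
}

const merged = mergeModels(
	{ "qwen2.5-7b": { description: "downloaded entry, no context size" } },
	{ "qwen2.5-7b": { contextWindow: 32768 } },
)
console.log(merged["qwen2.5-7b"].contextWindow) // 32768
```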
"http://localhost:1234" : baseUrl @@ -46,15 +80,15 @@ export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Prom } } catch (error) { console.warn("Failed to list downloaded models, falling back to loaded models only") + } + // We want to list loaded models *anyway* since they provide valuable extra info (context size) + const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => { + return Promise.all(models.map((m) => m.getModelInfo())) + })) as Array - // Fall back to listing only loaded models - const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => { - return Promise.all(models.map((m) => m.getModelInfo())) - })) as Array - - for (const lmstudioModel of loadedModels) { - models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel) - } + for (const lmstudioModel of loadedModels) { + models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel) + modelsWithLoadedDetails.add(lmstudioModel.modelKey) } } catch (error) { if (error.code === "ECONNREFUSED") { diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index fef700268d..dd6bc01ba1 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -47,7 +47,7 @@ async function readModels(router: RouterName): Promise */ export const getModels = async (options: GetModelsOptions): Promise => { const { provider } = options - let models = memoryCache.get(provider) + let models = getModelsFromCache(provider) if (models) { return models } @@ -113,3 +113,7 @@ export const getModels = async (options: GetModelsOptions): Promise export const flushModels = async (router: RouterName) => { memoryCache.del(router) } + +export function getModelsFromCache(provider: string) { + return memoryCache.get(provider) +} diff --git a/src/api/providers/lm-studio.ts b/src/api/providers/lm-studio.ts index f032e2d560..6c49920bd1 100644 --- a/src/api/providers/lm-studio.ts +++ b/src/api/providers/lm-studio.ts @@ -13,6 +13,7 @@ import { ApiStream } from "../transform/stream" import { BaseProvider } from "./base-provider" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import { getModels, getModelsFromCache } from "./fetchers/modelCache" export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions @@ -131,9 +132,17 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan } override getModel(): { id: string; info: ModelInfo } { - return { - id: this.options.lmStudioModelId || "", - info: openAiModelInfoSaneDefaults, + const models = getModelsFromCache("lmstudio") + if (models && this.options.lmStudioModelId && models[this.options.lmStudioModelId]) { + return { + id: this.options.lmStudioModelId, + info: models[this.options.lmStudioModelId], + } + } else { + return { + id: this.options.lmStudioModelId || "", + info: openAiModelInfoSaneDefaults, + } } } diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 280ab61a06..f8a33d175d 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -23,7 +23,6 @@ import { type TerminalActionPromptType, type HistoryItem, type CloudUserInfo, - type MarketplaceItem, requestyDefaultModelId, openRouterDefaultModelId, glamaDefaultModelId, @@ -41,7 +40,7 @@ import { supportPrompt } from "../../shared/support-prompt" import { GlobalFileNames } from "../../shared/globalFileNames" import { 
diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts
index 280ab61a06..f8a33d175d 100644
--- a/src/core/webview/ClineProvider.ts
+++ b/src/core/webview/ClineProvider.ts
@@ -23,7 +23,6 @@ import {
 	type TerminalActionPromptType,
 	type HistoryItem,
 	type CloudUserInfo,
-	type MarketplaceItem,
 	requestyDefaultModelId,
 	openRouterDefaultModelId,
 	glamaDefaultModelId,
@@ -41,7 +40,7 @@ import { supportPrompt } from "../../shared/support-prompt"
 import { GlobalFileNames } from "../../shared/globalFileNames"
 import { ExtensionMessage, MarketplaceInstalledMetadata } from "../../shared/ExtensionMessage"
 import { Mode, defaultModeSlug, getModeBySlug } from "../../shared/modes"
-import { experimentDefault, experiments, EXPERIMENT_IDS } from "../../shared/experiments"
+import { experimentDefault } from "../../shared/experiments"
 import { formatLanguage } from "../../shared/language"
 import { DEFAULT_WRITE_DELAY_MS } from "@roo-code/types"
 import { Terminal } from "../../integrations/terminal/Terminal"
@@ -71,6 +70,7 @@ import { WebviewMessage } from "../../shared/WebviewMessage"
 import { EMBEDDING_MODEL_PROFILES } from "../../shared/embeddingModels"
 import { ProfileValidator } from "../../shared/ProfileValidator"
 import { getWorkspaceGitInfo } from "../../utils/git"
+import { forceFullModelDetailsLoad, hasLoadedFullDetails } from "../../api/providers/fetchers/lmstudio"

 /**
  * https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
@@ -170,6 +170,9 @@ export class ClineProvider
 		// Add this cline instance into the stack that represents the order of all the called tasks.
 		this.clineStack.push(cline)

+		// Perform provider-specific setup tasks
+		await this.performPreparationTasks(cline)
+
 		// Ensure getState() resolves correctly.
 		const state = await this.getState()

@@ -178,6 +181,23 @@ export class ClineProvider
 		}
 	}

+	async performPreparationTasks(cline: Task) {
+		// LM Studio: force model loading so we can read its context size; we do it now since we're starting a task with that model selected
+		if (cline.apiConfiguration?.apiProvider === "lmstudio" && cline.apiConfiguration.lmStudioModelId) {
+			try {
+				if (!hasLoadedFullDetails(cline.apiConfiguration.lmStudioModelId)) {
+					await forceFullModelDetailsLoad(
+						cline.apiConfiguration.lmStudioBaseUrl ?? "http://localhost:1234",
+						cline.apiConfiguration.lmStudioModelId,
+					)
+				}
+			} catch (error) {
+				this.log(`Failed to load full model details for LM Studio: ${error}`)
+				vscode.window.showErrorMessage(error.message)
+			}
+		}
+	}
+
 	// Removes and destroys the top Cline instance (the current finished task),
 	// activating the previous one (resuming the parent task).
 	async removeClineFromStack() {
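performPreparationTasks relies on a load-at-most-once guard per model id (hasLoadedFullDetails / forceFullModelDetailsLoad above; in the patch the set lives in the fetcher module). The pattern in isolation, with a hypothetical helper that is not part of the patch:

```ts
// Sketch: skip the expensive load when details were already fetched this
// session; record success so later tasks reuse the cached result.
const alreadyLoaded = new Set<string>()

async function ensureFullDetails(modelId: string, load: (id: string) => Promise<void>): Promise<void> {
	if (alreadyLoaded.has(modelId)) return
	await load(modelId)
	alreadyLoaded.add(modelId)
}
```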
diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts
index 2e70f80f99..d19ab1e650 100644
--- a/src/core/webview/__tests__/ClineProvider.spec.ts
+++ b/src/core/webview/__tests__/ClineProvider.spec.ts
@@ -2840,6 +2840,33 @@ describe("ClineProvider - Router Models", () => {
 			},
 		})
 	})
+
+	test("handles requestLmStudioModels with proper response", async () => {
+		await provider.resolveWebviewView(mockWebviewView)
+		const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as any).mock.calls[0][0]
+
+		vi.spyOn(provider, "getState").mockResolvedValue({
+			apiConfiguration: {
+				lmStudioModelId: "model-1",
+				lmStudioBaseUrl: "http://localhost:1234",
+			},
+		} as any)
+
+		const mockModels = {
+			"model-1": { maxTokens: 4096, contextWindow: 8192, description: "Test model", supportsPromptCache: false },
+		}
+		const { getModels } = await import("../../../api/providers/fetchers/modelCache")
+		vi.mocked(getModels).mockResolvedValue(mockModels)
+
+		await messageHandler({
+			type: "requestLmStudioModels",
+		})
+
+		expect(getModels).toHaveBeenCalledWith({
+			provider: "lmstudio",
+			baseUrl: "http://localhost:1234",
+		})
+	})
 })

 describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => {
diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts
index 284ee98944..9a1683e464 100644
--- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts
+++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts
@@ -94,6 +94,48 @@ vi.mock("../../../utils/fs")
 vi.mock("../../../utils/path")
 vi.mock("../../../utils/globalContext")

+describe("webviewMessageHandler - requestLmStudioModels", () => {
+	beforeEach(() => {
+		vi.clearAllMocks()
+		mockClineProvider.getState = vi.fn().mockResolvedValue({
+			apiConfiguration: {
+				lmStudioModelId: "model-1",
+				lmStudioBaseUrl: "http://localhost:1234",
+			},
+		})
+	})
+
+	it("successfully fetches models from LMStudio", async () => {
+		const mockModels: ModelRecord = {
+			"model-1": {
+				maxTokens: 4096,
+				contextWindow: 8192,
+				supportsPromptCache: false,
+				description: "Test model 1",
+			},
+			"model-2": {
+				maxTokens: 8192,
+				contextWindow: 16384,
+				supportsPromptCache: false,
+				description: "Test model 2",
+			},
+		}
+
+		mockGetModels.mockResolvedValue(mockModels)
+
+		await webviewMessageHandler(mockClineProvider, {
+			type: "requestLmStudioModels",
+		})
+
+		expect(mockGetModels).toHaveBeenCalledWith({ provider: "lmstudio", baseUrl: "http://localhost:1234" })
+
+		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
+			type: "lmStudioModels",
+			lmStudioModels: mockModels,
+		})
+	})
+})
+
 describe("webviewMessageHandler - requestRouterModels", () => {
 	beforeEach(() => {
 		vi.clearAllMocks()
diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts
index 763e118125..8fbca9bbcf 100644
--- a/src/core/webview/webviewMessageHandler.ts
+++ b/src/core/webview/webviewMessageHandler.ts
@@ -584,7 +584,7 @@ export const webviewMessageHandler = async (
 				} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
 					provider.postMessageToWebview({
 						type: "lmStudioModels",
-						lmStudioModels: Object.keys(result.value.models),
+						lmStudioModels: result.value.models,
 					})
 				}
 			} else {
@@ -648,7 +648,7 @@ export const webviewMessageHandler = async (
 			if (Object.keys(lmStudioModels).length > 0) {
 				provider.postMessageToWebview({
 					type: "lmStudioModels",
-					lmStudioModels: Object.keys(lmStudioModels),
+					lmStudioModels: lmStudioModels,
 				})
 			}
 		} catch (error) {
diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts
index 67f8782e19..2e2d534c7a 100644
--- a/src/shared/ExtensionMessage.ts
+++ b/src/shared/ExtensionMessage.ts
@@ -16,7 +16,7 @@ import { GitCommit } from "../utils/git"
 import { McpServer } from "./mcp"
 import { Mode } from "./modes"
-import { RouterModels } from "./api"
+import { ModelRecord, RouterModels } from "./api"
 import type { MarketplaceItem } from "@roo-code/types"

 // Command interface for frontend/backend communication
@@ -146,7 +146,7 @@ export interface ExtensionMessage {
 	routerModels?: RouterModels
 	openAiModels?: string[]
 	ollamaModels?: string[]
-	lmStudioModels?: string[]
+	lmStudioModels?: ModelRecord
 	vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[]
 	huggingFaceModels?: Array<{
 		id: string
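The lmStudioModels payload changes from a list of ids to a full ModelRecord, so the webview receives metadata instead of bare names. A sketch of the before/after shapes, using a hypothetical ModelRecord-like type; consumers that only need ids can recover them with Object.keys:

```ts
type ModelRecordSketch = Record<string, { contextWindow: number; maxTokens?: number }>

const before: string[] = ["model-1", "model-2"]
const after: ModelRecordSketch = {
	"model-1": { contextWindow: 8192, maxTokens: 4096 },
	"model-2": { contextWindow: 16384 },
}

console.log(Object.keys(after)) // ["model-1", "model-2"], the same ids as `before`
```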
diff --git a/webview-ui/src/components/settings/providers/LMStudio.tsx b/webview-ui/src/components/settings/providers/LMStudio.tsx
index a907e43e1b..e3401aa62c 100644
--- a/webview-ui/src/components/settings/providers/LMStudio.tsx
+++ b/webview-ui/src/components/settings/providers/LMStudio.tsx
@@ -12,6 +12,7 @@ import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
 import { vscode } from "@src/utils/vscode"

 import { inputEventTransform } from "../transforms"
+import { ModelRecord } from "@roo/api"

 type LMStudioProps = {
 	apiConfiguration: ProviderSettings
@@ -21,7 +22,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 	const { t } = useAppTranslation()

-	const [lmStudioModels, setLmStudioModels] = useState<string[]>([])
+	const [lmStudioModels, setLmStudioModels] = useState<ModelRecord>({})
 	const routerModels = useRouterModels()

 	const handleInputChange = useCallback(
@@ -41,7 +42,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 			switch (message.type) {
 				case "lmStudioModels": {
-					const newModels = message.lmStudioModels ?? []
+					const newModels = message.lmStudioModels ?? {}
 					setLmStudioModels(newModels)
 				}
 				break
@@ -62,7 +63,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 		if (!selectedModel) return false

 		// Check if model exists in local LM Studio models
-		if (lmStudioModels.length > 0 && lmStudioModels.includes(selectedModel)) {
+		if (Object.keys(lmStudioModels).length > 0 && selectedModel in lmStudioModels) {
 			return false // Model is available locally
 		}
@@ -83,7 +84,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 		if (!draftModel) return false

 		// Check if model exists in local LM Studio models
-		if (lmStudioModels.length > 0 && lmStudioModels.includes(draftModel)) {
+		if (Object.keys(lmStudioModels).length > 0 && draftModel in lmStudioModels) {
 			return false // Model is available locally
 		}
@@ -125,15 +126,15 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 				)}
-			{lmStudioModels.length > 0 && (
+			{Object.keys(lmStudioModels).length > 0 && (
-					{lmStudioModels.map((model) => (
+					{Object.keys(lmStudioModels).map((model) => (
 							{model}
@@ -175,23 +176,23 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 				)}
-			{lmStudioModels.length > 0 && (
+			{Object.keys(lmStudioModels).length > 0 && (
 				<>
 						{t("settings:providers.lmStudio.selectDraftModel")}
-						{lmStudioModels.map((model) => (
+						{Object.keys(lmStudioModels).map((model) => (
 								{model}
 						))}
-					{lmStudioModels.length === 0 && (
+					{Object.keys(lmStudioModels).length === 0 && (
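The settings UI now keeps the whole ModelRecord in state and derives ids on demand, so the availability checks move from Array.includes to the `in` operator. A quick sanity sketch of the equivalence, assuming plain-object records:

```ts
const asArray = ["model-1"]
const asRecord: Record<string, unknown> = { "model-1": {} }

console.log(asArray.includes("model-1")) // true (old check on string[])
console.log("model-1" in asRecord) // true (new check on record keys)
```

One caveat worth noting in review: `in` also matches inherited keys such as `toString`. That is harmless here because the keys compared are model ids from user config, but `Object.hasOwn(asRecord, id)` would be the stricter check.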
diff --git a/webview-ui/src/components/ui/hooks/useLmStudioModels.ts b/webview-ui/src/components/ui/hooks/useLmStudioModels.ts
new file mode 100644
--- /dev/null
+++ b/webview-ui/src/components/ui/hooks/useLmStudioModels.ts
@@ -0,0 +1,39 @@
+import { useQuery } from "@tanstack/react-query"
+
+import { ModelRecord } from "@roo/api"
+import { ExtensionMessage } from "@roo/ExtensionMessage"
+
+import { vscode } from "@src/utils/vscode"
+
+const getLmStudioModels = async () =>
+	new Promise<ModelRecord>((resolve, reject) => {
+		const cleanup = () => {
+			window.removeEventListener("message", handler)
+		}
+
+		const timeout = setTimeout(() => {
+			cleanup()
+			reject(new Error("LM Studio models request timed out"))
+		}, 10000)
+
+		const handler = (event: MessageEvent) => {
+			const message: ExtensionMessage = event.data
+
+			if (message.type === "lmStudioModels") {
+				clearTimeout(timeout)
+				cleanup()
+
+				if (message.lmStudioModels) {
+					resolve(message.lmStudioModels)
+				} else {
+					reject(new Error("No LMStudio models in response"))
+				}
+			}
+		}
+
+		window.addEventListener("message", handler)
+		vscode.postMessage({ type: "requestLmStudioModels" })
+	})
+
+export const useLmStudioModels = (modelId?: string) =>
+	useQuery<ModelRecord>({ queryKey: ["lmStudioModels"], queryFn: () => (modelId ? getLmStudioModels() : {}) })
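The hook resolves a one-shot request/response over window messaging with a 10-second timeout, so callers see undefined data until the extension host answers. A hypothetical consumer, shown as a sketch rather than part of the patch:

```ts
import { useLmStudioModels } from "@src/components/ui/hooks/useLmStudioModels"

// Returns the configured model's context window once the
// requestLmStudioModels round trip resolves, else undefined.
export function useLmStudioContextWindow(modelId?: string): number | undefined {
	const models = useLmStudioModels(modelId)
	return modelId ? models.data?.[modelId]?.contextWindow : undefined
}
```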
diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
index 8dceb6e117..1f9eda0cbf 100644
--- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts
+++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
@@ -36,20 +36,24 @@ import {
 	claudeCodeModels,
 } from "@roo-code/types"

-import type { RouterModels } from "@roo/api"
+import type { ModelRecord, RouterModels } from "@roo/api"

 import { useRouterModels } from "./useRouterModels"
 import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders"
+import { useLmStudioModels } from "./useLmStudioModels"

 export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 	const provider = apiConfiguration?.apiProvider || "anthropic"
 	const openRouterModelId = provider === "openrouter" ? apiConfiguration?.openRouterModelId : undefined
+	const lmStudioModelId = provider === "lmstudio" ? apiConfiguration?.lmStudioModelId : undefined

 	const routerModels = useRouterModels()
 	const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId)
+	const lmStudioModels = useLmStudioModels(lmStudioModelId)

 	const { id, info } =
 		apiConfiguration &&
+		(typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") &&
 		typeof routerModels.data !== "undefined" &&
 		typeof openRouterModelProviders.data !== "undefined"
 			? getSelectedModel({
@@ -57,6 +61,7 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 					apiConfiguration,
 					routerModels: routerModels.data,
 					openRouterModelProviders: openRouterModelProviders.data,
+					lmStudioModels: lmStudioModels.data,
 				})
 			: { id: anthropicDefaultModelId, info: undefined }

@@ -64,8 +69,14 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 	return {
 		provider,
 		id,
 		info,
-		isLoading: routerModels.isLoading || openRouterModelProviders.isLoading,
-		isError: routerModels.isError || openRouterModelProviders.isError,
+		isLoading:
+			routerModels.isLoading ||
+			openRouterModelProviders.isLoading ||
+			(apiConfiguration?.lmStudioModelId && lmStudioModels!.isLoading),
+		isError:
+			routerModels.isError ||
+			openRouterModelProviders.isError ||
+			(apiConfiguration?.lmStudioModelId && lmStudioModels!.isError),
 	}
 }

@@ -74,11 +85,13 @@ function getSelectedModel({
 	apiConfiguration,
 	routerModels,
 	openRouterModelProviders,
+	lmStudioModels,
 }: {
 	provider: ProviderName
 	apiConfiguration: ProviderSettings
 	routerModels: RouterModels
 	openRouterModelProviders: Record<string, ModelInfo>
+	lmStudioModels: ModelRecord | undefined
 }): { id: string; info: ModelInfo | undefined } {
 	// the `undefined` case are used to show the invalid selection to prevent
 	// users from seeing the default model if their selection is invalid
@@ -204,7 +217,7 @@ function getSelectedModel({
 		case "lmstudio": {
 			const id = apiConfiguration.lmStudioModelId ?? ""
-			const info = routerModels.lmstudio && routerModels.lmstudio[id]
+			const info = lmStudioModels && lmStudioModels[apiConfiguration.lmStudioModelId!]
 			return {
 				id,
 				info: info || undefined,