Skip to content

Commit 041c28d

Browse files
authored
fix: improve LM Studio model detection to show all downloaded models (#5047)
1 parent 8d94cf8 commit 041c28d

File tree

5 files changed

+86
-12
lines changed

5 files changed

+86
-12
lines changed

src/api/providers/fetchers/__tests__/lmstudio.test.ts

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import axios from "axios"
22
import { vi, describe, it, expect, beforeEach } from "vitest"
3-
import { LMStudioClient, LLM, LLMInstanceInfo } from "@lmstudio/sdk" // LLMInfo is a type
3+
import { LMStudioClient, LLM, LLMInstanceInfo, LLMInfo } from "@lmstudio/sdk"
44
import { getLMStudioModels, parseLMStudioModel } from "../lmstudio"
55
import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types" // ModelInfo is a type
66

@@ -11,12 +11,16 @@ const mockedAxios = axios as any
1111
// Mock @lmstudio/sdk
1212
const mockGetModelInfo = vi.fn()
1313
const mockListLoaded = vi.fn()
14+
const mockListDownloadedModels = vi.fn()
1415
vi.mock("@lmstudio/sdk", () => {
1516
return {
1617
LMStudioClient: vi.fn().mockImplementation(() => ({
1718
llm: {
1819
listLoaded: mockListLoaded,
1920
},
21+
system: {
22+
listDownloadedModels: mockListDownloadedModels,
23+
},
2024
})),
2125
}
2226
})
@@ -28,6 +32,7 @@ describe("LMStudio Fetcher", () => {
2832
MockedLMStudioClientConstructor.mockClear()
2933
mockListLoaded.mockClear()
3034
mockGetModelInfo.mockClear()
35+
mockListDownloadedModels.mockClear()
3136
})
3237

3338
describe("parseLMStudioModel", () => {
@@ -88,8 +93,40 @@ describe("LMStudio Fetcher", () => {
8893
trainedForToolUse: false, // Added
8994
}
9095

91-
it("should fetch and parse models successfully", async () => {
96+
it("should fetch downloaded models using system.listDownloadedModels", async () => {
97+
const mockLLMInfo: LLMInfo = {
98+
type: "llm" as const,
99+
modelKey: "mistralai/devstral-small-2505",
100+
format: "safetensors",
101+
displayName: "Devstral Small 2505",
102+
path: "mistralai/devstral-small-2505",
103+
sizeBytes: 13277565112,
104+
architecture: "mistral",
105+
vision: false,
106+
trainedForToolUse: false,
107+
maxContextLength: 131072,
108+
}
109+
110+
mockedAxios.get.mockResolvedValueOnce({ data: { status: "ok" } })
111+
mockListDownloadedModels.mockResolvedValueOnce([mockLLMInfo])
112+
113+
const result = await getLMStudioModels(baseUrl)
114+
115+
expect(mockedAxios.get).toHaveBeenCalledTimes(1)
116+
expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`)
117+
expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1)
118+
expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
119+
expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
120+
expect(mockListDownloadedModels).toHaveBeenCalledWith("llm")
121+
expect(mockListLoaded).not.toHaveBeenCalled()
122+
123+
const expectedParsedModel = parseLMStudioModel(mockLLMInfo)
124+
expect(result).toEqual({ [mockLLMInfo.path]: expectedParsedModel })
125+
})
126+
127+
it("should fall back to listLoaded when listDownloadedModels fails", async () => {
92128
mockedAxios.get.mockResolvedValueOnce({ data: { status: "ok" } })
129+
mockListDownloadedModels.mockRejectedValueOnce(new Error("Method not available"))
93130
mockListLoaded.mockResolvedValueOnce([{ getModelInfo: mockGetModelInfo }])
94131
mockGetModelInfo.mockResolvedValueOnce(mockRawModel)
95132

@@ -99,6 +136,7 @@ describe("LMStudio Fetcher", () => {
99136
expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/v1/models`)
100137
expect(MockedLMStudioClientConstructor).toHaveBeenCalledTimes(1)
101138
expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
139+
expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
102140
expect(mockListLoaded).toHaveBeenCalledTimes(1)
103141

104142
const expectedParsedModel = parseLMStudioModel(mockRawModel)

src/api/providers/fetchers/lmstudio.ts

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@ import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types"
22
import { LLM, LLMInfo, LLMInstanceInfo, LMStudioClient } from "@lmstudio/sdk"
33
import axios from "axios"
44

5-
export const parseLMStudioModel = (rawModel: LLMInstanceInfo): ModelInfo => {
5+
export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelInfo => {
6+
// Handle both LLMInstanceInfo (from loaded models) and LLMInfo (from downloaded models)
7+
const contextLength = "contextLength" in rawModel ? rawModel.contextLength : rawModel.maxContextLength
8+
69
const modelInfo: ModelInfo = Object.assign({}, lMStudioDefaultModelInfo, {
710
description: `${rawModel.displayName} - ${rawModel.path}`,
8-
contextWindow: rawModel.contextLength,
11+
contextWindow: contextLength,
912
supportsPromptCache: true,
1013
supportsImages: rawModel.vision,
1114
supportsComputerUse: false,
12-
maxTokens: rawModel.contextLength,
15+
maxTokens: contextLength,
1316
})
1417

1518
return modelInfo
@@ -33,12 +36,25 @@ export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Prom
3336
await axios.get(`${baseUrl}/v1/models`)
3437

3538
const client = new LMStudioClient({ baseUrl: lmsUrl })
36-
const response = (await client.llm.listLoaded().then((models: LLM[]) => {
37-
return Promise.all(models.map((m) => m.getModelInfo()))
38-
})) as Array<LLMInstanceInfo>
3939

40-
for (const lmstudioModel of response) {
41-
models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
40+
// First, try to get all downloaded models
41+
try {
42+
const downloadedModels = await client.system.listDownloadedModels("llm")
43+
for (const model of downloadedModels) {
44+
// Use the model path as the key since that's what users select
45+
models[model.path] = parseLMStudioModel(model)
46+
}
47+
} catch (error) {
48+
console.warn("Failed to list downloaded models, falling back to loaded models only")
49+
50+
// Fall back to listing only loaded models
51+
const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
52+
return Promise.all(models.map((m) => m.getModelInfo()))
53+
})) as Array<LLMInstanceInfo>
54+
55+
for (const lmstudioModel of loadedModels) {
56+
models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
57+
}
4258
}
4359
} catch (error) {
4460
if (error.code === "ECONNREFUSED") {

src/core/webview/webviewMessageHandler.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,9 @@ export const webviewMessageHandler = async (
448448
// Specific handler for Ollama models only
449449
const { apiConfiguration: ollamaApiConfig } = await provider.getState()
450450
try {
451+
// Flush cache first to ensure fresh models
452+
await flushModels("ollama")
453+
451454
const ollamaModels = await getModels({
452455
provider: "ollama",
453456
baseUrl: ollamaApiConfig.ollamaBaseUrl,
@@ -469,6 +472,9 @@ export const webviewMessageHandler = async (
469472
// Specific handler for LM Studio models only
470473
const { apiConfiguration: lmStudioApiConfig } = await provider.getState()
471474
try {
475+
// Flush cache first to ensure fresh models
476+
await flushModels("lmstudio")
477+
472478
const lmStudioModels = await getModels({
473479
provider: "lmstudio",
474480
baseUrl: lmStudioApiConfig.lmStudioBaseUrl,

webview-ui/src/components/settings/providers/LMStudio.tsx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { useCallback, useState, useMemo } from "react"
1+
import { useCallback, useState, useMemo, useEffect } from "react"
22
import { useEvent } from "react-use"
33
import { Trans } from "react-i18next"
44
import { Checkbox } from "vscrui"
@@ -9,6 +9,7 @@ import type { ProviderSettings } from "@roo-code/types"
99
import { useAppTranslation } from "@src/i18n/TranslationContext"
1010
import { ExtensionMessage } from "@roo/ExtensionMessage"
1111
import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
12+
import { vscode } from "@src/utils/vscode"
1213

1314
import { inputEventTransform } from "../transforms"
1415

@@ -49,6 +50,12 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
4950

5051
useEvent("message", onMessage)
5152

53+
// Refresh models on mount
54+
useEffect(() => {
55+
// Request fresh models - the handler now flushes cache automatically
56+
vscode.postMessage({ type: "requestLmStudioModels" })
57+
}, [])
58+
5259
// Check if the selected model exists in the fetched models
5360
const modelNotAvailable = useMemo(() => {
5461
const selectedModel = apiConfiguration?.lmStudioModelId

webview-ui/src/components/settings/providers/Ollama.tsx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { useState, useCallback, useMemo } from "react"
1+
import { useState, useCallback, useMemo, useEffect } from "react"
22
import { useEvent } from "react-use"
33
import { VSCodeTextField, VSCodeRadioGroup, VSCodeRadio } from "@vscode/webview-ui-toolkit/react"
44

@@ -8,6 +8,7 @@ import { ExtensionMessage } from "@roo/ExtensionMessage"
88

99
import { useAppTranslation } from "@src/i18n/TranslationContext"
1010
import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
11+
import { vscode } from "@src/utils/vscode"
1112

1213
import { inputEventTransform } from "../transforms"
1314

@@ -48,6 +49,12 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
4849

4950
useEvent("message", onMessage)
5051

52+
// Refresh models on mount
53+
useEffect(() => {
54+
// Request fresh models - the handler now flushes cache automatically
55+
vscode.postMessage({ type: "requestOllamaModels" })
56+
}, [])
57+
5158
// Check if the selected model exists in the fetched models
5259
const modelNotAvailable = useMemo(() => {
5360
const selectedModel = apiConfiguration?.ollamaModelId

0 commit comments

Comments (0)