
Commit 1a013b4

pwilkin, ellipsis-dev[bot], and daniel-lxs authored
fix: LM Studio model context length (#5075) (#6183)
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Co-authored-by: Daniel <[email protected]>
Co-authored-by: Daniel Riccio <[email protected]>
1 parent 74672fa commit 1a013b4

File tree

12 files changed: +223 -32 lines changed


src/api/providers/fetchers/__tests__/lmstudio.test.ts

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ describe("LMStudio Fetcher", () => {
 			expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
 			expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
 			expect(mockListDownloadedModels).toHaveBeenCalledWith("llm")
-			expect(mockListLoaded).not.toHaveBeenCalled()
+			expect(mockListLoaded).toHaveBeenCalled() // we now call it to get context data

 			const expectedParsedModel = parseLMStudioModel(mockLLMInfo)
 			expect(result).toEqual({ [mockLLMInfo.path]: expectedParsedModel })

src/api/providers/fetchers/lmstudio.ts

Lines changed: 42 additions & 8 deletions
@@ -1,6 +1,38 @@
 import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types"
 import { LLM, LLMInfo, LLMInstanceInfo, LMStudioClient } from "@lmstudio/sdk"
 import axios from "axios"
+import { flushModels, getModels } from "./modelCache"
+
+const modelsWithLoadedDetails = new Set<string>()
+
+export const hasLoadedFullDetails = (modelId: string): boolean => {
+	return modelsWithLoadedDetails.has(modelId)
+}
+
+export const forceFullModelDetailsLoad = async (baseUrl: string, modelId: string): Promise<void> => {
+	try {
+		// test the connection to LM Studio first
+		// errors will be caught further down
+		await axios.get(`${baseUrl}/v1/models`)
+		const lmsUrl = baseUrl.replace(/^http:\/\//, "ws://").replace(/^https:\/\//, "wss://")
+
+		const client = new LMStudioClient({ baseUrl: lmsUrl })
+		await client.llm.model(modelId)
+		await flushModels("lmstudio")
+		await getModels({ provider: "lmstudio" }) // force cache update now
+
+		// Mark this model as having full details loaded
+		modelsWithLoadedDetails.add(modelId)
+	} catch (error) {
+		if (error.code === "ECONNREFUSED") {
+			console.warn(`Error connecting to LMStudio at ${baseUrl}`)
+		} else {
+			console.error(
+				`Error refreshing LMStudio model details: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`,
+			)
+		}
+	}
+}

 export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelInfo => {
 	// Handle both LLMInstanceInfo (from loaded models) and LLMInfo (from downloaded models)
@@ -19,6 +51,8 @@ export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelIn
 }

 export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Promise<Record<string, ModelInfo>> {
+	// clear the set of models that have full details loaded
+	modelsWithLoadedDetails.clear()
 	// clearing the input can leave an empty string; use the default in that case
 	baseUrl = baseUrl === "" ? "http://localhost:1234" : baseUrl

@@ -46,15 +80,15 @@ export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Prom
 			}
 		} catch (error) {
 			console.warn("Failed to list downloaded models, falling back to loaded models only")
+		}
+		// We want to list loaded models *anyway* since they provide valuable extra info (context size)
+		const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
+			return Promise.all(models.map((m) => m.getModelInfo()))
+		})) as Array<LLMInstanceInfo>

-			// Fall back to listing only loaded models
-			const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
-				return Promise.all(models.map((m) => m.getModelInfo()))
-			})) as Array<LLMInstanceInfo>
-
-			for (const lmstudioModel of loadedModels) {
-				models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
-			}
+		for (const lmstudioModel of loadedModels) {
+			models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
+			modelsWithLoadedDetails.add(lmstudioModel.modelKey)
 		}
 	} catch (error) {
 		if (error.code === "ECONNREFUSED") {
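Together, the two new exports let callers force LM Studio to actually load a model — the point at which the SDK exposes full instance info, including context length — and check whether that has already been done for a given model. A minimal usage sketch, assuming the module paths from this diff and a running LM Studio instance (the model id below is made up):

// Sketch only: drives the exports added above; the model id is hypothetical.
import { forceFullModelDetailsLoad, hasLoadedFullDetails } from "./lmstudio"

async function ensureContextLength(baseUrl: string, modelId: string): Promise<void> {
	if (hasLoadedFullDetails(modelId)) {
		return // instance info (and thus context length) already cached
	}
	// Loads the model via the LM Studio SDK, flushes the "lmstudio" cache
	// entry, and refetches so the real context window is available right away.
	await forceFullModelDetailsLoad(baseUrl, modelId)
}

void ensureContextLength("http://localhost:1234", "qwen2.5-7b-instruct")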

src/api/providers/fetchers/modelCache.ts

Lines changed: 5 additions & 1 deletion
@@ -47,7 +47,7 @@ async function readModels(router: RouterName): Promise<ModelRecord | undefined>
  */
 export const getModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
 	const { provider } = options
-	let models = memoryCache.get<ModelRecord>(provider)
+	let models = getModelsFromCache(provider)
 	if (models) {
 		return models
 	}
@@ -113,3 +113,7 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
 export const flushModels = async (router: RouterName) => {
 	memoryCache.del(router)
 }
+
+export function getModelsFromCache(provider: string) {
+	return memoryCache.get<ModelRecord>(provider)
+}
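`getModels` stays the async read-through path, while the new `getModelsFromCache` gives synchronous, fetch-free access to whatever is already cached. A self-contained sketch of that split, assuming `memoryCache` is a node-cache instance (consistent with the `get`/`del` calls in this file, though the diff does not show its construction) and with the fetcher stubbed out:

// Standalone sketch of the two-entry-point cache; the fetcher is a stand-in.
import NodeCache from "node-cache"

type ModelRecord = Record<string, { contextWindow: number }>
const memoryCache = new NodeCache({ stdTTL: 5 * 60 })

// Synchronous read: safe in hot paths (e.g. LmStudioHandler.getModel()).
function getModelsFromCache(provider: string): ModelRecord | undefined {
	return memoryCache.get<ModelRecord>(provider)
}

// Async read-through: only hits the provider on a cache miss.
async function getModels(provider: string, fetch: () => Promise<ModelRecord>): Promise<ModelRecord> {
	const cached = getModelsFromCache(provider)
	if (cached) return cached
	const models = await fetch()
	memoryCache.set(provider, models)
	return models
}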

src/api/providers/lm-studio.ts

Lines changed: 12 additions & 3 deletions
@@ -13,6 +13,7 @@ import { ApiStream } from "../transform/stream"

 import { BaseProvider } from "./base-provider"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { getModels, getModelsFromCache } from "./fetchers/modelCache"

 export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
@@ -131,9 +132,17 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
 	}

 	override getModel(): { id: string; info: ModelInfo } {
-		return {
-			id: this.options.lmStudioModelId || "",
-			info: openAiModelInfoSaneDefaults,
+		const models = getModelsFromCache("lmstudio")
+		if (models && this.options.lmStudioModelId && models[this.options.lmStudioModelId]) {
+			return {
+				id: this.options.lmStudioModelId,
+				info: models[this.options.lmStudioModelId],
+			}
+		} else {
+			return {
+				id: this.options.lmStudioModelId || "",
+				info: openAiModelInfoSaneDefaults,
+			}
 		}
 	}


src/core/webview/ClineProvider.ts

Lines changed: 21 additions & 0 deletions
@@ -70,6 +70,7 @@ import { WebviewMessage } from "../../shared/WebviewMessage"
 import { EMBEDDING_MODEL_PROFILES } from "../../shared/embeddingModels"
 import { ProfileValidator } from "../../shared/ProfileValidator"
 import { getWorkspaceGitInfo } from "../../utils/git"
+import { forceFullModelDetailsLoad, hasLoadedFullDetails } from "../../api/providers/fetchers/lmstudio"

 /**
  * https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
@@ -163,6 +164,9 @@ export class ClineProvider
 		// Add this cline instance into the stack that represents the order of all the called tasks.
 		this.clineStack.push(cline)

+		// Perform special setup provider specific tasks
+		await this.performPreparationTasks(cline)
+
 		// Ensure getState() resolves correctly.
 		const state = await this.getState()

@@ -171,6 +175,23 @@ export class ClineProvider
 		}
 	}

+	async performPreparationTasks(cline: Task) {
+		// LMStudio: we need to force model loading in order to read its context size; we do it now since we're starting a task with that model selected
+		if (cline.apiConfiguration && cline.apiConfiguration.apiProvider === "lmstudio") {
+			try {
+				if (!hasLoadedFullDetails(cline.apiConfiguration.lmStudioModelId!)) {
+					await forceFullModelDetailsLoad(
+						cline.apiConfiguration.lmStudioBaseUrl ?? "http://localhost:1234",
+						cline.apiConfiguration.lmStudioModelId!,
+					)
+				}
+			} catch (error) {
+				this.log(`Failed to load full model details for LM Studio: ${error}`)
+				vscode.window.showErrorMessage(error.message)
+			}
+		}
+	}
+
 	// Removes and destroys the top Cline instance (the current finished task),
 	// activating the previous one (resuming the parent task).
 	async removeClineFromStack() {
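The net effect is that starting a task against LM Studio now triggers a one-time, per-model warm-up before the first request. A condensed restatement of that guard, using the imports from this diff (the surrounding provider wiring is elided):

// Condensed sketch mirroring performPreparationTasks above.
import { forceFullModelDetailsLoad, hasLoadedFullDetails } from "../../api/providers/fetchers/lmstudio"

async function prepareLmStudioModel(baseUrl: string | undefined, modelId: string): Promise<void> {
	// Only force a (potentially slow) model load once per model id.
	if (!hasLoadedFullDetails(modelId)) {
		await forceFullModelDetailsLoad(baseUrl ?? "http://localhost:1234", modelId)
	}
	// From here on, LmStudioHandler.getModel() reads the refreshed cache entry
	// and reports the model's true context window instead of generic defaults.
}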

src/core/webview/__tests__/ClineProvider.spec.ts

Lines changed: 28 additions & 0 deletions
@@ -16,6 +16,7 @@ import { Task, TaskOptions } from "../../task/Task"
 import { safeWriteJson } from "../../../utils/safeWriteJson"

 import { ClineProvider } from "../ClineProvider"
+import { AsyncInvokeOutputDataConfig } from "@aws-sdk/client-bedrock-runtime"

 // Mock setup must come before imports
 vi.mock("../../prompts/sections/custom-instructions")
@@ -2840,6 +2841,33 @@ describe("ClineProvider - Router Models", () => {
 			},
 		})
 	})
+
+	test("handles requestLmStudioModels with proper response", async () => {
+		await provider.resolveWebviewView(mockWebviewView)
+		const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as any).mock.calls[0][0]
+
+		vi.spyOn(provider, "getState").mockResolvedValue({
+			apiConfiguration: {
+				lmStudioModelId: "model-1",
+				lmStudioBaseUrl: "http://localhost:1234",
+			},
+		} as any)
+
+		const mockModels = {
+			"model-1": { maxTokens: 4096, contextWindow: 8192, description: "Test model", supportsPromptCache: false },
+		}
+		const { getModels } = await import("../../../api/providers/fetchers/modelCache")
+		vi.mocked(getModels).mockResolvedValue(mockModels)
+
+		await messageHandler({
+			type: "requestLmStudioModels",
+		})
+
+		expect(getModels).toHaveBeenCalledWith({
+			provider: "lmstudio",
+			baseUrl: "http://localhost:1234",
+		})
+	})
 })

 describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => {

src/core/webview/__tests__/webviewMessageHandler.spec.ts

Lines changed: 42 additions & 0 deletions
@@ -94,6 +94,48 @@ vi.mock("../../../utils/fs")
 vi.mock("../../../utils/path")
 vi.mock("../../../utils/globalContext")

+describe("webviewMessageHandler - requestLmStudioModels", () => {
+	beforeEach(() => {
+		vi.clearAllMocks()
+		mockClineProvider.getState = vi.fn().mockResolvedValue({
+			apiConfiguration: {
+				lmStudioModelId: "model-1",
+				lmStudioBaseUrl: "http://localhost:1234",
+			},
+		})
+	})
+
+	it("successfully fetches models from LMStudio", async () => {
+		const mockModels: ModelRecord = {
+			"model-1": {
+				maxTokens: 4096,
+				contextWindow: 8192,
+				supportsPromptCache: false,
+				description: "Test model 1",
+			},
+			"model-2": {
+				maxTokens: 8192,
+				contextWindow: 16384,
+				supportsPromptCache: false,
+				description: "Test model 2",
+			},
+		}
+
+		mockGetModels.mockResolvedValue(mockModels)
+
+		await webviewMessageHandler(mockClineProvider, {
+			type: "requestLmStudioModels",
+		})
+
+		expect(mockGetModels).toHaveBeenCalledWith({ provider: "lmstudio", baseUrl: "http://localhost:1234" })
+
+		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
+			type: "lmStudioModels",
+			lmStudioModels: mockModels,
+		})
+	})
+})
+
 describe("webviewMessageHandler - requestRouterModels", () => {
 	beforeEach(() => {
 		vi.clearAllMocks()

src/core/webview/webviewMessageHandler.ts

Lines changed: 2 additions & 2 deletions
@@ -584,7 +584,7 @@ export const webviewMessageHandler = async (
 				} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
 					provider.postMessageToWebview({
 						type: "lmStudioModels",
-						lmStudioModels: Object.keys(result.value.models),
+						lmStudioModels: result.value.models,
 					})
 				}
 			} else {
@@ -648,7 +648,7 @@ export const webviewMessageHandler = async (
 				if (Object.keys(lmStudioModels).length > 0) {
 					provider.postMessageToWebview({
 						type: "lmStudioModels",
-						lmStudioModels: Object.keys(lmStudioModels),
+						lmStudioModels: lmStudioModels,
 					})
 				}
 			} catch (error) {
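Both send sites now ship the full `ModelRecord` rather than just the model ids, which is what lets the webview (and `getModel()` above) see context windows at all. Illustrative payloads, with field values borrowed from the tests in this commit:

// Before: ids only — all per-model metadata was dropped at this boundary.
const before = {
	type: "lmStudioModels",
	lmStudioModels: ["model-1", "model-2"],
}

// After: the full record, so consumers can read context windows directly.
const after = {
	type: "lmStudioModels",
	lmStudioModels: {
		"model-1": { maxTokens: 4096, contextWindow: 8192, supportsPromptCache: false },
		"model-2": { maxTokens: 8192, contextWindow: 16384, supportsPromptCache: false },
	},
}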

src/shared/ExtensionMessage.ts

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@ import { GitCommit } from "../utils/git"

 import { McpServer } from "./mcp"
 import { Mode } from "./modes"
-import { RouterModels } from "./api"
+import { ModelRecord, RouterModels } from "./api"
 import type { MarketplaceItem } from "@roo-code/types"

 // Command interface for frontend/backend communication
@@ -146,7 +146,7 @@ export interface ExtensionMessage {
 	routerModels?: RouterModels
 	openAiModels?: string[]
 	ollamaModels?: string[]
-	lmStudioModels?: string[]
+	lmStudioModels?: ModelRecord
 	vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[]
 	huggingFaceModels?: Array<{
 		id: string

webview-ui/src/components/settings/providers/LMStudio.tsx

Lines changed: 12 additions & 11 deletions
@@ -12,6 +12,7 @@ import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
 import { vscode } from "@src/utils/vscode"

 import { inputEventTransform } from "../transforms"
+import { ModelRecord } from "@roo/api"

 type LMStudioProps = {
 	apiConfiguration: ProviderSettings
@@ -21,7 +22,7 @@ type LMStudioProps = {
 export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudioProps) => {
 	const { t } = useAppTranslation()

-	const [lmStudioModels, setLmStudioModels] = useState<string[]>([])
+	const [lmStudioModels, setLmStudioModels] = useState<ModelRecord>({})
 	const routerModels = useRouterModels()

 	const handleInputChange = useCallback(
@@ -41,7 +42,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 			switch (message.type) {
 				case "lmStudioModels":
 					{
-						const newModels = message.lmStudioModels ?? []
+						const newModels = message.lmStudioModels ?? {}
 						setLmStudioModels(newModels)
 					}
 					break
@@ -62,7 +63,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 		if (!selectedModel) return false

 		// Check if model exists in local LM Studio models
-		if (lmStudioModels.length > 0 && lmStudioModels.includes(selectedModel)) {
+		if (Object.keys(lmStudioModels).length > 0 && selectedModel in lmStudioModels) {
 			return false // Model is available locally
 		}

@@ -83,7 +84,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 		if (!draftModel) return false

 		// Check if model exists in local LM Studio models
-		if (lmStudioModels.length > 0 && lmStudioModels.includes(draftModel)) {
+		if (Object.keys(lmStudioModels).length > 0 && draftModel in lmStudioModels) {
 			return false // Model is available locally
 		}

@@ -125,15 +126,15 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 					</div>
 				</div>
 			)}
-			{lmStudioModels.length > 0 && (
+			{Object.keys(lmStudioModels).length > 0 && (
 				<VSCodeRadioGroup
 					value={
-						lmStudioModels.includes(apiConfiguration?.lmStudioModelId || "")
+						(apiConfiguration?.lmStudioModelId || "") in lmStudioModels
 							? apiConfiguration?.lmStudioModelId
 							: ""
 					}
 					onChange={handleInputChange("lmStudioModelId")}>
-					{lmStudioModels.map((model) => (
+					{Object.keys(lmStudioModels).map((model) => (
 						<VSCodeRadio key={model} value={model} checked={apiConfiguration?.lmStudioModelId === model}>
 							{model}
 						</VSCodeRadio>
@@ -175,23 +176,23 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
 					</div>
 				)}
 			</div>
-			{lmStudioModels.length > 0 && (
+			{Object.keys(lmStudioModels).length > 0 && (
 				<>
 					<div className="font-medium">{t("settings:providers.lmStudio.selectDraftModel")}</div>
 					<VSCodeRadioGroup
 						value={
-							lmStudioModels.includes(apiConfiguration?.lmStudioDraftModelId || "")
+							(apiConfiguration?.lmStudioDraftModelId || "") in lmStudioModels
 								? apiConfiguration?.lmStudioDraftModelId
 								: ""
 						}
 						onChange={handleInputChange("lmStudioDraftModelId")}>
-						{lmStudioModels.map((model) => (
+						{Object.keys(lmStudioModels).map((model) => (
 							<VSCodeRadio key={`draft-${model}`} value={model}>
 								{model}
 							</VSCodeRadio>
 						))}
 					</VSCodeRadioGroup>
-					{lmStudioModels.length === 0 && (
+					{Object.keys(lmStudioModels).length === 0 && (
 						<div
 							className="text-sm rounded-xs p-2"
 							style={{
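Since `lmStudioModels` is now a keyed record rather than an array of ids, every membership test swaps `.includes()` for the `in` operator and every render loop iterates `Object.keys()`. A standalone illustration of the pattern (not component code):

type ModelRecord = Record<string, { contextWindow: number }>

const lmStudioModels: ModelRecord = { "model-1": { contextWindow: 8192 } }
const selected = "model-1"

// Old (array): lmStudioModels.includes(selected)
// New (record): key membership via `in`, emptiness via Object.keys().length.
const isLocal = Object.keys(lmStudioModels).length > 0 && selected in lmStudioModels
console.log(isLocal) // true

// Rendering still maps over ids, now derived from the record's keys:
Object.keys(lmStudioModels).forEach((id) => console.log(id))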
