Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/api/providers/fetchers/__tests__/lmstudio.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ describe("LMStudio Fetcher", () => {
expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
expect(mockListDownloadedModels).toHaveBeenCalledWith("llm")
expect(mockListLoaded).not.toHaveBeenCalled()
expect(mockListLoaded).toHaveBeenCalled() // we now call it to get context data

const expectedParsedModel = parseLMStudioModel(mockLLMInfo)
expect(result).toEqual({ [mockLLMInfo.path]: expectedParsedModel })
Expand Down
50 changes: 42 additions & 8 deletions src/api/providers/fetchers/lmstudio.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,38 @@
import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types"
import { LLM, LLMInfo, LLMInstanceInfo, LMStudioClient } from "@lmstudio/sdk"
import axios from "axios"
import { flushModels, getModels } from "./modelCache"

// Model IDs whose full instance details (e.g. real context length) have
// already been fetched from LM Studio during this session.
const modelsWithLoadedDetails = new Set<string>()

/** Whether full details for `modelId` were already loaded into the cache. */
export const hasLoadedFullDetails = (modelId: string): boolean => modelsWithLoadedDetails.has(modelId)

/**
 * Forces LM Studio to load `modelId` so its full instance info (notably the
 * context length) becomes available, then refreshes the shared model cache.
 *
 * Errors are logged, never thrown: connection refusal is reported as a
 * warning, anything else as an error.
 *
 * @param baseUrl - LM Studio HTTP base URL, e.g. "http://localhost:1234".
 * @param modelId - The model key to load.
 */
export const forceFullModelDetailsLoad = async (baseUrl: string, modelId: string): Promise<void> => {
	try {
		// Probe the HTTP endpoint first so we fail fast (into the catch below)
		// when LM Studio is not running at all.
		await axios.get(`${baseUrl}/v1/models`)
		// The LM Studio SDK speaks WebSocket, so swap the URL scheme.
		const lmsUrl = baseUrl.replace(/^http:\/\//, "ws://").replace(/^https:\/\//, "wss://")

		const client = new LMStudioClient({ baseUrl: lmsUrl })
		// Loading the model is what populates its full instance info; the
		// returned handle itself is not needed here.
		await client.llm.model(modelId)
		// Drop the stale cache entry and repopulate immediately so callers
		// reading the cache right after this resolves see fresh data.
		await flushModels("lmstudio")
		await getModels({ provider: "lmstudio" }) // force cache update now

		// Mark this model as having full details loaded
		modelsWithLoadedDetails.add(modelId)
	} catch (error) {
		// `error` may be anything (null, string, Error); read `.code`
		// defensively instead of assuming a Node/axios error shape.
		const code = (error as { code?: string } | null | undefined)?.code
		if (code === "ECONNREFUSED") {
			console.warn(`Error connecting to LMStudio at ${baseUrl}`)
		} else {
			console.error(
				`Error refreshing LMStudio model details: ${JSON.stringify(error, error ? Object.getOwnPropertyNames(error) : undefined, 2)}`,
			)
		}
	}
}

export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelInfo => {
// Handle both LLMInstanceInfo (from loaded models) and LLMInfo (from downloaded models)
Expand All @@ -19,6 +51,8 @@ export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelIn
}

export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Promise<Record<string, ModelInfo>> {
// clear the set of models that have full details loaded
modelsWithLoadedDetails.clear()
// clearing the input can leave an empty string; use the default in that case
baseUrl = baseUrl === "" ? "http://localhost:1234" : baseUrl

Expand Down Expand Up @@ -46,15 +80,15 @@ export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Prom
}
} catch (error) {
console.warn("Failed to list downloaded models, falling back to loaded models only")
}
// We want to list loaded models *anyway* since they provide valuable extra info (context size)
const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
return Promise.all(models.map((m) => m.getModelInfo()))
})) as Array<LLMInstanceInfo>

// Fall back to listing only loaded models
const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
return Promise.all(models.map((m) => m.getModelInfo()))
})) as Array<LLMInstanceInfo>

for (const lmstudioModel of loadedModels) {
models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
}
for (const lmstudioModel of loadedModels) {
models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
modelsWithLoadedDetails.add(lmstudioModel.modelKey)
}
} catch (error) {
if (error.code === "ECONNREFUSED") {
Expand Down
6 changes: 5 additions & 1 deletion src/api/providers/fetchers/modelCache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async function readModels(router: RouterName): Promise<ModelRecord | undefined>
*/
export const getModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
const { provider } = options
let models = memoryCache.get<ModelRecord>(provider)
let models = getModelsFromCache(provider)
if (models) {
return models
}
Expand Down Expand Up @@ -113,3 +113,7 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
/**
 * Evicts the in-memory cache entry for the given router so that the next
 * getModels() call re-fetches fresh model data.
 */
export const flushModels = async (router: RouterName): Promise<void> => {
	memoryCache.del(router)
}

/**
 * Synchronous, non-fetching read of the model cache.
 *
 * Unlike getModels(), this never triggers a refresh — it returns whatever is
 * currently cached for `provider`, or undefined when the cache is cold.
 */
export function getModelsFromCache(provider: string): ModelRecord | undefined {
	return memoryCache.get<ModelRecord>(provider)
}
15 changes: 12 additions & 3 deletions src/api/providers/lm-studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { ApiStream } from "../transform/stream"

import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { getModels, getModelsFromCache } from "./fetchers/modelCache"

export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
Expand Down Expand Up @@ -131,9 +132,17 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
}

override getModel(): { id: string; info: ModelInfo } {
return {
id: this.options.lmStudioModelId || "",
info: openAiModelInfoSaneDefaults,
const models = getModelsFromCache("lmstudio")
if (models && this.options.lmStudioModelId && models[this.options.lmStudioModelId]) {
return {
id: this.options.lmStudioModelId,
info: models[this.options.lmStudioModelId],
}
} else {
return {
id: this.options.lmStudioModelId || "",
info: openAiModelInfoSaneDefaults,
}
}
}

Expand Down
24 changes: 22 additions & 2 deletions src/core/webview/ClineProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import {
type TerminalActionPromptType,
type HistoryItem,
type CloudUserInfo,
type MarketplaceItem,
requestyDefaultModelId,
openRouterDefaultModelId,
glamaDefaultModelId,
Expand All @@ -41,7 +40,7 @@ import { supportPrompt } from "../../shared/support-prompt"
import { GlobalFileNames } from "../../shared/globalFileNames"
import { ExtensionMessage, MarketplaceInstalledMetadata } from "../../shared/ExtensionMessage"
import { Mode, defaultModeSlug, getModeBySlug } from "../../shared/modes"
import { experimentDefault, experiments, EXPERIMENT_IDS } from "../../shared/experiments"
import { experimentDefault } from "../../shared/experiments"
import { formatLanguage } from "../../shared/language"
import { DEFAULT_WRITE_DELAY_MS } from "@roo-code/types"
import { Terminal } from "../../integrations/terminal/Terminal"
Expand Down Expand Up @@ -71,6 +70,7 @@ import { WebviewMessage } from "../../shared/WebviewMessage"
import { EMBEDDING_MODEL_PROFILES } from "../../shared/embeddingModels"
import { ProfileValidator } from "../../shared/ProfileValidator"
import { getWorkspaceGitInfo } from "../../utils/git"
import { forceFullModelDetailsLoad, hasLoadedFullDetails } from "../../api/providers/fetchers/lmstudio"

/**
* https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
Expand Down Expand Up @@ -170,6 +170,9 @@ export class ClineProvider
// Add this cline instance into the stack that represents the order of all the called tasks.
this.clineStack.push(cline)

// Perform special setup provider specific tasks
await this.performPreparationTasks(cline)

// Ensure getState() resolves correctly.
const state = await this.getState()

Expand All @@ -178,6 +181,23 @@ export class ClineProvider
}
}

async performPreparationTasks(cline: Task) {
// LMStudio: we need to force model loading in order to read its context size; we do it now since we're starting a task with that model selected
if (cline.apiConfiguration && cline.apiConfiguration.apiProvider === "lmstudio") {
try {
if (!hasLoadedFullDetails(cline.apiConfiguration.lmStudioModelId!)) {
await forceFullModelDetailsLoad(
cline.apiConfiguration.lmStudioBaseUrl ?? "http://localhost:1234",
cline.apiConfiguration.lmStudioModelId!,
)
}
} catch (error) {
this.log(`Failed to load full model details for LM Studio: ${error}`)
vscode.window.showErrorMessage(error.message)
}
}
}

// Removes and destroys the top Cline instance (the current finished task),
// activating the previous one (resuming the parent task).
async removeClineFromStack() {
Expand Down
28 changes: 28 additions & 0 deletions src/core/webview/__tests__/ClineProvider.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { Task, TaskOptions } from "../../task/Task"
import { safeWriteJson } from "../../../utils/safeWriteJson"

import { ClineProvider } from "../ClineProvider"
import { AsyncInvokeOutputDataConfig } from "@aws-sdk/client-bedrock-runtime"

// Mock setup must come before imports
vi.mock("../../prompts/sections/custom-instructions")
Expand Down Expand Up @@ -2840,6 +2841,33 @@ describe("ClineProvider - Router Models", () => {
},
})
})

test("handles requestLmStudioModels with proper response", async () => {
await provider.resolveWebviewView(mockWebviewView)
const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as any).mock.calls[0][0]

vi.spyOn(provider, "getState").mockResolvedValue({
apiConfiguration: {
lmStudioModelId: "model-1",
lmStudioBaseUrl: "http://localhost:1234",
},
} as any)

const mockModels = {
"model-1": { maxTokens: 4096, contextWindow: 8192, description: "Test model", supportsPromptCache: false },
}
const { getModels } = await import("../../../api/providers/fetchers/modelCache")
vi.mocked(getModels).mockResolvedValue(mockModels)

await messageHandler({
type: "requestLmStudioModels",
})

expect(getModels).toHaveBeenCalledWith({
provider: "lmstudio",
baseUrl: "http://localhost:1234",
})
})
})

describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => {
Expand Down
42 changes: 42 additions & 0 deletions src/core/webview/__tests__/webviewMessageHandler.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,48 @@ vi.mock("../../../utils/fs")
vi.mock("../../../utils/path")
vi.mock("../../../utils/globalContext")

describe("webviewMessageHandler - requestLmStudioModels", () => {
	beforeEach(() => {
		vi.clearAllMocks()
		// The handler reads the LM Studio configuration from provider state.
		mockClineProvider.getState = vi.fn().mockResolvedValue({
			apiConfiguration: {
				lmStudioModelId: "model-1",
				lmStudioBaseUrl: "http://localhost:1234",
			},
		})
	})

	it("successfully fetches models from LMStudio", async () => {
		const fixtureModels: ModelRecord = {
			"model-1": {
				maxTokens: 4096,
				contextWindow: 8192,
				supportsPromptCache: false,
				description: "Test model 1",
			},
			"model-2": {
				maxTokens: 8192,
				contextWindow: 16384,
				supportsPromptCache: false,
				description: "Test model 2",
			},
		}

		mockGetModels.mockResolvedValue(fixtureModels)

		await webviewMessageHandler(mockClineProvider, {
			type: "requestLmStudioModels",
		})

		// The fetch must target the lmstudio provider at the configured base URL…
		expect(mockGetModels).toHaveBeenCalledWith({ provider: "lmstudio", baseUrl: "http://localhost:1234" })

		// …and the full model record (not just the key list) is posted back.
		expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
			type: "lmStudioModels",
			lmStudioModels: fixtureModels,
		})
	})
})

describe("webviewMessageHandler - requestRouterModels", () => {
beforeEach(() => {
vi.clearAllMocks()
Expand Down
4 changes: 2 additions & 2 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ export const webviewMessageHandler = async (
} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
provider.postMessageToWebview({
type: "lmStudioModels",
lmStudioModels: Object.keys(result.value.models),
lmStudioModels: result.value.models,
})
}
} else {
Expand Down Expand Up @@ -648,7 +648,7 @@ export const webviewMessageHandler = async (
if (Object.keys(lmStudioModels).length > 0) {
provider.postMessageToWebview({
type: "lmStudioModels",
lmStudioModels: Object.keys(lmStudioModels),
lmStudioModels: lmStudioModels,
})
}
} catch (error) {
Expand Down
4 changes: 2 additions & 2 deletions src/shared/ExtensionMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { GitCommit } from "../utils/git"

import { McpServer } from "./mcp"
import { Mode } from "./modes"
import { RouterModels } from "./api"
import { ModelRecord, RouterModels } from "./api"
import type { MarketplaceItem } from "@roo-code/types"

// Command interface for frontend/backend communication
Expand Down Expand Up @@ -146,7 +146,7 @@ export interface ExtensionMessage {
routerModels?: RouterModels
openAiModels?: string[]
ollamaModels?: string[]
lmStudioModels?: string[]
lmStudioModels?: ModelRecord
vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[]
huggingFaceModels?: Array<{
id: string
Expand Down
Loading