Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/api/providers/fetchers/__tests__/lmstudio.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ describe("LMStudio Fetcher", () => {
expect(MockedLMStudioClientConstructor).toHaveBeenCalledWith({ baseUrl: lmsUrl })
expect(mockListDownloadedModels).toHaveBeenCalledTimes(1)
expect(mockListDownloadedModels).toHaveBeenCalledWith("llm")
expect(mockListLoaded).not.toHaveBeenCalled()
expect(mockListLoaded).toHaveBeenCalled() // we now call it to get context data

const expectedParsedModel = parseLMStudioModel(mockLLMInfo)
expect(result).toEqual({ [mockLLMInfo.path]: expectedParsedModel })
Expand Down
47 changes: 39 additions & 8 deletions src/api/providers/fetchers/lmstudio.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,35 @@
import { ModelInfo, lMStudioDefaultModelInfo } from "@roo-code/types"
import { LLM, LLMInfo, LLMInstanceInfo, LMStudioClient } from "@lmstudio/sdk"
import axios from "axios"
import { flushModels, getModels } from "./modelCache"

// Model IDs for which LM Studio has already supplied full instance details
// (e.g. context length) during the current fetch cycle.
const modelsWithLoadedDetails = new Set<string>()

/** True when full details for `modelId` have already been loaded. */
export const hasLoadedFullDetails = (modelId: string): boolean => modelsWithLoadedDetails.has(modelId)

/**
 * Force LM Studio to fully load a model so its complete details (notably
 * context length) become available, then flush and repopulate the
 * "lmstudio" entry of the model cache.
 *
 * Errors are handled internally (logged), so the returned promise never
 * rejects — callers may fire-and-forget.
 *
 * @param baseUrl http(s) URL of the LM Studio server
 * @param modelId identifier of the model to load
 */
export const forceFullModelDetailsLoad = async (baseUrl: string, modelId: string): Promise<void> => {
	try {
		// test the connection to LM Studio first
		// errors will be caught further down
		await axios.get(`${baseUrl}/v1/models`)
		// The LM Studio SDK speaks WebSocket; derive ws(s):// from http(s)://.
		const lmsUrl = baseUrl.replace(/^http:\/\//, "ws://").replace(/^https:\/\//, "wss://")

		const client = new LMStudioClient({ baseUrl: lmsUrl })
		// Requesting the model forces LM Studio to load it with full instance info.
		await client.llm.model(modelId)
		await flushModels("lmstudio")
		await getModels({ provider: "lmstudio" }) // force cache update now
	} catch (error) {
		// Narrow the unknown catch value before reading `.code` (strict TS /
		// useUnknownInCatchVariables would reject the bare property access).
		const code = (error as NodeJS.ErrnoException | null)?.code
		if (code === "ECONNREFUSED") {
			console.warn(`Error connecting to LMStudio at ${baseUrl}`)
		} else {
			console.error(
				`Error refreshing LMStudio model details: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`,
			)
		}
	}
}

export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelInfo => {
// Handle both LLMInstanceInfo (from loaded models) and LLMInfo (from downloaded models)
Expand All @@ -19,6 +48,8 @@ export const parseLMStudioModel = (rawModel: LLMInstanceInfo | LLMInfo): ModelIn
}

export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Promise<Record<string, ModelInfo>> {
// clear the set of models that have full details loaded
modelsWithLoadedDetails.clear()
// clearing the input can leave an empty string; use the default in that case
baseUrl = baseUrl === "" ? "http://localhost:1234" : baseUrl

Expand Down Expand Up @@ -46,15 +77,15 @@ export async function getLMStudioModels(baseUrl = "http://localhost:1234"): Prom
}
} catch (error) {
console.warn("Failed to list downloaded models, falling back to loaded models only")
}
// We want to list loaded models *anyway* since they provide valuable extra info (context size)
const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
return Promise.all(models.map((m) => m.getModelInfo()))
})) as Array<LLMInstanceInfo>

// Fall back to listing only loaded models
const loadedModels = (await client.llm.listLoaded().then((models: LLM[]) => {
return Promise.all(models.map((m) => m.getModelInfo()))
})) as Array<LLMInstanceInfo>

for (const lmstudioModel of loadedModels) {
models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
}
for (const lmstudioModel of loadedModels) {
models[lmstudioModel.modelKey] = parseLMStudioModel(lmstudioModel)
modelsWithLoadedDetails.add(lmstudioModel.modelKey)
}
} catch (error) {
if (error.code === "ECONNREFUSED") {
Expand Down
6 changes: 5 additions & 1 deletion src/api/providers/fetchers/modelCache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async function readModels(router: RouterName): Promise<ModelRecord | undefined>
*/
export const getModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
const { provider } = options
let models = memoryCache.get<ModelRecord>(provider)
let models = getModelsFromCache(provider)
if (models) {
return models
}
Expand Down Expand Up @@ -113,3 +113,7 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
/**
 * Evict a router's entry from the in-memory model cache so the next
 * getModels() call for that router refetches fresh data.
 */
export async function flushModels(router: RouterName): Promise<void> {
	memoryCache.del(router)
}

/**
 * Read a provider's models straight from the in-memory cache without
 * triggering a network fetch.
 *
 * @param provider cache key of the provider/router (e.g. "lmstudio")
 * @returns the cached record, or undefined when there is no (live) entry
 */
export function getModelsFromCache(provider: string): ModelRecord | undefined {
	return memoryCache.get<ModelRecord>(provider)
}
15 changes: 12 additions & 3 deletions src/api/providers/lm-studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { ApiStream } from "../transform/stream"

import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { getModels, getModelsFromCache } from "./fetchers/modelCache"

export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
Expand Down Expand Up @@ -131,9 +132,17 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
}

/**
 * Resolve the active LM Studio model.
 *
 * Prefers real model info from the "lmstudio" cache (which includes the
 * model's actual context window); falls back to generic OpenAI-style
 * defaults when no model is configured or the cache has no entry for it.
 */
override getModel(): { id: string; info: ModelInfo } {
	const models = getModelsFromCache("lmstudio")
	if (models && this.options.lmStudioModelId && models[this.options.lmStudioModelId]) {
		return {
			id: this.options.lmStudioModelId,
			info: models[this.options.lmStudioModelId],
		}
	}
	// Fallback: unknown model — return the configured id (or "") with defaults.
	return {
		id: this.options.lmStudioModelId || "",
		info: openAiModelInfoSaneDefaults,
	}
}

Expand Down
17 changes: 17 additions & 0 deletions src/core/webview/ClineProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ import { WebviewMessage } from "../../shared/WebviewMessage"
import { EMBEDDING_MODEL_PROFILES } from "../../shared/embeddingModels"
import { ProfileValidator } from "../../shared/ProfileValidator"
import { getWorkspaceGitInfo } from "../../utils/git"
import { LmStudioHandler } from "../../api/providers"
import { forceFullModelDetailsLoad, hasLoadedFullDetails } from "../../api/providers/fetchers/lmstudio"

/**
* https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
Expand Down Expand Up @@ -170,6 +172,9 @@ export class ClineProvider
// Add this cline instance into the stack that represents the order of all the called tasks.
this.clineStack.push(cline)

// Perform special setup provider specific tasks
await this.performPreparationTasks(cline)

// Ensure getState() resolves correctly.
const state = await this.getState()

Expand All @@ -178,6 +183,18 @@ export class ClineProvider
}
}

/**
 * Provider-specific setup run when a task is pushed onto the stack.
 *
 * LM Studio: we need to force model loading in order to read its context
 * size; we do it now since we're starting a task with that model selected.
 */
async performPreparationTasks(cline: Task) {
	if (cline.apiConfiguration?.apiProvider === "lmstudio") {
		const modelId = cline.apiConfiguration.lmStudioModelId
		// Guard instead of non-null assertions: skip when no model id is set.
		if (modelId && !hasLoadedFullDetails(modelId)) {
			// Intentionally not awaited so task startup is not blocked on a
			// model load; forceFullModelDetailsLoad handles its own errors.
			void forceFullModelDetailsLoad(
				cline.apiConfiguration.lmStudioBaseUrl ?? "http://localhost:1234",
				modelId,
			)
		}
	}
}

// Removes and destroys the top Cline instance (the current finished task),
// activating the previous one (resuming the parent task).
async removeClineFromStack() {
Expand Down
28 changes: 28 additions & 0 deletions src/core/webview/__tests__/ClineProvider.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { Task, TaskOptions } from "../../task/Task"
import { safeWriteJson } from "../../../utils/safeWriteJson"

import { ClineProvider } from "../ClineProvider"
import { AsyncInvokeOutputDataConfig } from "@aws-sdk/client-bedrock-runtime"

// Mock setup must come before imports
vi.mock("../../prompts/sections/custom-instructions")
Expand Down Expand Up @@ -2576,6 +2577,33 @@ describe("ClineProvider - Router Models", () => {
},
})
})

// Verifies that a "requestLmStudioModels" webview message triggers a model
// fetch for the "lmstudio" provider using the base URL from the current
// API configuration.
test("handles requestLmStudioModels with proper response", async () => {
await provider.resolveWebviewView(mockWebviewView)
// Grab the message handler that resolveWebviewView registered on the webview.
const messageHandler = (mockWebviewView.webview.onDidReceiveMessage as any).mock.calls[0][0]

// Pretend the user has an LM Studio model configured.
vi.spyOn(provider, "getState").mockResolvedValue({
apiConfiguration: {
lmStudioModelId: "model-1",
lmStudioBaseUrl: "http://localhost:1234",
},
} as any)

const mockModels = {
"model-1": { maxTokens: 4096, contextWindow: 8192, description: "Test model", supportsPromptCache: false },
}
// Stub the model fetcher so no real network/LM Studio instance is needed.
const { getModels } = await import("../../../api/providers/fetchers/modelCache")
vi.mocked(getModels).mockResolvedValue(mockModels)

await messageHandler({
type: "requestLmStudioModels",
})

// The fetcher must be asked for "lmstudio" models at the configured URL.
expect(getModels).toHaveBeenCalledWith({
provider: "lmstudio",
baseUrl: "http://localhost:1234",
})
})
})

describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => {
Expand Down
42 changes: 42 additions & 0 deletions src/core/webview/__tests__/webviewMessageHandler.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,48 @@ vi.mock("../../../utils/fs")
vi.mock("../../../utils/path")
vi.mock("../../../utils/globalContext")

// Tests the "requestLmStudioModels" branch of webviewMessageHandler: it
// should fetch models for the configured LM Studio instance and post the
// full ModelRecord (not just the ids) back to the webview.
describe("webviewMessageHandler - requestLmStudioModels", () => {
beforeEach(() => {
vi.clearAllMocks()
// Provide an API configuration with an LM Studio model selected.
mockClineProvider.getState = vi.fn().mockResolvedValue({
apiConfiguration: {
lmStudioModelId: "model-1",
lmStudioBaseUrl: "http://localhost:1234",
},
})
})

it("successfully fetches models from LMStudio", async () => {
// Two models with distinct limits so the payload shape is meaningful.
const mockModels: ModelRecord = {
"model-1": {
maxTokens: 4096,
contextWindow: 8192,
supportsPromptCache: false,
description: "Test model 1",
},
"model-2": {
maxTokens: 8192,
contextWindow: 16384,
supportsPromptCache: false,
description: "Test model 2",
},
}

mockGetModels.mockResolvedValue(mockModels)

await webviewMessageHandler(mockClineProvider, {
type: "requestLmStudioModels",
})

// Fetch must target the lmstudio provider at the configured base URL.
expect(mockGetModels).toHaveBeenCalledWith({ provider: "lmstudio", baseUrl: "http://localhost:1234" })

// The webview receives the full model record, keyed by model id.
expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
type: "lmStudioModels",
lmStudioModels: mockModels,
})
})
})

describe("webviewMessageHandler - requestRouterModels", () => {
beforeEach(() => {
vi.clearAllMocks()
Expand Down
4 changes: 2 additions & 2 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ export const webviewMessageHandler = async (
} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
provider.postMessageToWebview({
type: "lmStudioModels",
lmStudioModels: Object.keys(result.value.models),
lmStudioModels: result.value.models,
})
}
} else {
Expand Down Expand Up @@ -648,7 +648,7 @@ export const webviewMessageHandler = async (
if (Object.keys(lmStudioModels).length > 0) {
provider.postMessageToWebview({
type: "lmStudioModels",
lmStudioModels: Object.keys(lmStudioModels),
lmStudioModels: lmStudioModels,
})
}
} catch (error) {
Expand Down
4 changes: 2 additions & 2 deletions src/shared/ExtensionMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { GitCommit } from "../utils/git"

import { McpServer } from "./mcp"
import { Mode } from "./modes"
import { RouterModels } from "./api"
import { ModelRecord, RouterModels } from "./api"
import type { MarketplaceItem } from "@roo-code/types"

// Type for marketplace installed metadata
Expand Down Expand Up @@ -135,7 +135,7 @@ export interface ExtensionMessage {
routerModels?: RouterModels
openAiModels?: string[]
ollamaModels?: string[]
lmStudioModels?: string[]
lmStudioModels?: ModelRecord
vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[]
huggingFaceModels?: Array<{
_id: string
Expand Down
23 changes: 12 additions & 11 deletions webview-ui/src/components/settings/providers/LMStudio.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
import { vscode } from "@src/utils/vscode"

import { inputEventTransform } from "../transforms"
import { ModelRecord } from "@roo/api"

type LMStudioProps = {
apiConfiguration: ProviderSettings
Expand All @@ -21,7 +22,7 @@ type LMStudioProps = {
export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudioProps) => {
const { t } = useAppTranslation()

const [lmStudioModels, setLmStudioModels] = useState<string[]>([])
const [lmStudioModels, setLmStudioModels] = useState<ModelRecord>({})
const routerModels = useRouterModels()

const handleInputChange = useCallback(
Expand All @@ -41,7 +42,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
switch (message.type) {
case "lmStudioModels":
{
const newModels = message.lmStudioModels ?? []
const newModels = message.lmStudioModels ?? {}
setLmStudioModels(newModels)
}
break
Expand All @@ -62,7 +63,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
if (!selectedModel) return false

// Check if model exists in local LM Studio models
if (lmStudioModels.length > 0 && lmStudioModels.includes(selectedModel)) {
if (Object.keys(lmStudioModels).length > 0 && selectedModel in lmStudioModels) {
return false // Model is available locally
}

Expand All @@ -83,7 +84,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
if (!draftModel) return false

// Check if model exists in local LM Studio models
if (lmStudioModels.length > 0 && lmStudioModels.includes(draftModel)) {
if (Object.keys(lmStudioModels).length > 0 && draftModel in lmStudioModels) {
return false // Model is available locally
}

Expand Down Expand Up @@ -125,15 +126,15 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
</div>
</div>
)}
{lmStudioModels.length > 0 && (
{Object.keys(lmStudioModels).length > 0 && (
<VSCodeRadioGroup
value={
lmStudioModels.includes(apiConfiguration?.lmStudioModelId || "")
(apiConfiguration?.lmStudioModelId || "") in lmStudioModels
? apiConfiguration?.lmStudioModelId
: ""
}
onChange={handleInputChange("lmStudioModelId")}>
{lmStudioModels.map((model) => (
{Object.keys(lmStudioModels).map((model) => (
<VSCodeRadio key={model} value={model} checked={apiConfiguration?.lmStudioModelId === model}>
{model}
</VSCodeRadio>
Expand Down Expand Up @@ -175,23 +176,23 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi
</div>
)}
</div>
{lmStudioModels.length > 0 && (
{Object.keys(lmStudioModels).length > 0 && (
<>
<div className="font-medium">{t("settings:providers.lmStudio.selectDraftModel")}</div>
<VSCodeRadioGroup
value={
lmStudioModels.includes(apiConfiguration?.lmStudioDraftModelId || "")
(apiConfiguration?.lmStudioDraftModelId || "") in lmStudioModels
? apiConfiguration?.lmStudioDraftModelId
: ""
}
onChange={handleInputChange("lmStudioDraftModelId")}>
{lmStudioModels.map((model) => (
{Object.keys(lmStudioModels).map((model) => (
<VSCodeRadio key={`draft-${model}`} value={model}>
{model}
</VSCodeRadio>
))}
</VSCodeRadioGroup>
{lmStudioModels.length === 0 && (
{Object.keys(lmStudioModels).length === 0 && (
<div
className="text-sm rounded-xs p-2"
style={{
Expand Down
Loading