Skip to content

Commit d4918c3

Browse files
committed
Switched GetModelsOptions to a discriminated union based on api provider. Adjusted related code and tests to match
1 parent bd1f93a commit d4918c3

File tree

12 files changed

+210
-128
lines changed

12 files changed

+210
-128
lines changed

src/api/providers/__tests__/litellm.test.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,11 @@ describe("LiteLLMHandler", () => {
7777
it("returns correct model info when modelId is provided and found in getModels", async () => {
7878
const handler = new LiteLLMHandler(defaultMockOptions)
7979
const result = await handler.fetchModel()
80-
expect(mockGetModels).toHaveBeenCalledWith(
81-
"litellm",
82-
defaultMockOptions.litellmApiKey,
83-
defaultMockOptions.litellmBaseUrl,
84-
)
80+
expect(mockGetModels).toHaveBeenCalledWith({
81+
provider: "litellm",
82+
apiKey: defaultMockOptions.litellmApiKey,
83+
baseUrl: defaultMockOptions.litellmBaseUrl,
84+
})
8585
expect(result).toEqual({ id: defaultMockOptions.litellmModelId, info: mockModelInfo })
8686
})
8787

src/api/providers/fetchers/modelCache.ts

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import NodeCache from "node-cache"
55

66
import { ContextProxy } from "../../../core/config/ContextProxy"
77
import { getCacheDirectoryPath } from "../../../shared/storagePathManager"
8-
import { RouterName, ModelRecord } from "../../../shared/api"
8+
import { RouterName, ModelRecord, GetModelsOptions } from "../../../shared/api"
99
import { fileExistsAtPath } from "../../../utils/fs"
1010

1111
import { getOpenRouterModels } from "./openrouter"
@@ -30,18 +30,6 @@ async function readModels(router: RouterName): Promise<ModelRecord | undefined>
3030
return exists ? JSON.parse(await fs.readFile(filePath, "utf8")) : undefined
3131
}
3232

33-
/**
34-
* Options for fetching models from different routers.
35-
* This is a discriminated union type where the router property determines
36-
* which other properties are required.
37-
*/
38-
export type GetModelsOptions =
39-
| { router: "openrouter" }
40-
| { router: "glama" }
41-
| { router: "requesty"; apiKey?: string }
42-
| { router: "unbound"; apiKey?: string }
43-
| { router: "litellm"; apiKey: string; baseUrl: string }
44-
4533
/**
4634
* Get models from the cache or fetch them from the provider and cache them.
4735
* There are two caches:
@@ -52,14 +40,14 @@ export type GetModelsOptions =
5240
* @returns The models from the cache or the fetched models.
5341
*/
5442
export const getModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
55-
const { router } = options
56-
let models = memoryCache.get<ModelRecord>(router)
43+
const { provider } = options
44+
let models = memoryCache.get<ModelRecord>(provider)
5745
if (models) {
5846
return models
5947
}
6048

6149
try {
62-
switch (router) {
50+
switch (provider) {
6351
case "openrouter":
6452
models = await getOpenRouterModels()
6553
break
@@ -80,26 +68,26 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
8068
break
8169
default:
8270
// Ensures router is exhaustively checked if RouterName is a strict union
83-
const exhaustiveCheck: never = router
71+
const exhaustiveCheck: never = provider
8472
throw new Error(`Unknown router: ${exhaustiveCheck}`)
8573
}
8674

8775
// Cache the fetched models (even if empty, to signify a successful fetch with no models)
88-
memoryCache.set(router, models)
89-
await writeModels(router, models).catch((err) =>
90-
console.error(`[getModels] Error writing ${router} models to file cache:`, err),
76+
memoryCache.set(provider, models)
77+
await writeModels(provider, models).catch((err) =>
78+
console.error(`[getModels] Error writing ${provider} models to file cache:`, err),
9179
)
9280

9381
try {
94-
models = await readModels(router)
82+
models = await readModels(provider)
9583
// console.log(`[getModels] read ${router} models from file cache`)
9684
} catch (error) {
97-
console.error(`[getModels] error reading ${router} models from file cache`, error)
85+
console.error(`[getModels] error reading ${provider} models from file cache`, error)
9886
}
9987
return models || {}
10088
} catch (error) {
10189
// Log the error and re-throw it so the caller can handle it (e.g., show a UI message).
102-
console.error(`[getModels] Failed to fetch models for ${router}:`, error)
90+
console.error(`[getModels] Failed to fetch models in modelCache for ${provider}:`, error)
10391

10492
throw error // Re-throw the original error to be handled by the caller.
10593
}

src/api/providers/litellm.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
1919
options,
2020
name: "litellm",
2121
baseURL: `${options.litellmBaseUrl || "http://localhost:4000"}`,
22-
apiKey: options.litellmApiKey || "dummy-key",
22+
apiKey: options.litellmApiKey || "sk-1234",
2323
modelId: options.litellmModelId,
2424
defaultModelId: litellmDefaultModelId,
2525
defaultModelInfo: litellmDefaultModelInfo,

src/api/providers/openrouter.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
171171

172172
public async fetchModel() {
173173
const [models, endpoints] = await Promise.all([
174-
getModels("openrouter"),
174+
getModels({ provider: "openrouter" }),
175175
getModelEndpoints({
176176
router: "openrouter",
177177
modelId: this.options.openRouterModelId,

src/api/providers/requesty.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan
4545
}
4646

4747
public async fetchModel() {
48-
this.models = await getModels("requesty")
48+
this.models = await getModels({ provider: "requesty", apiKey: this.options.requestyApiKey })
4949
return this.getModel()
5050
}
5151

src/api/providers/router-provider.ts

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import OpenAI from "openai"
22

3-
import { ApiHandlerOptions, RouterName, ModelRecord, ModelInfo } from "../../shared/api"
3+
import { ApiHandlerOptions, RouterName, ModelRecord, ModelInfo, GetModelsOptions } from "../../shared/api"
44
import { BaseProvider } from "./base-provider"
55
import { getModels } from "./fetchers/modelCache"
66

@@ -51,7 +51,31 @@ export abstract class RouterProvider extends BaseProvider {
5151
}
5252

5353
public async fetchModel() {
54-
this.models = await getModels(this.name, this.apiKey, this.baseURL)
54+
// Create the appropriate options based on router type
55+
let options: GetModelsOptions
56+
57+
switch (this.name) {
58+
case "openrouter":
59+
options = { provider: "openrouter" }
60+
break
61+
case "glama":
62+
options = { provider: "glama" }
63+
break
64+
case "requesty":
65+
options = { provider: "requesty", apiKey: this.apiKey }
66+
break
67+
case "unbound":
68+
options = { provider: "unbound", apiKey: this.apiKey }
69+
break
70+
case "litellm":
71+
options = { provider: "litellm", apiKey: this.apiKey, baseUrl: this.baseURL }
72+
break
73+
default:
74+
const exhaustiveCheck: never = this.name
75+
throw new Error(`Unknown provider: ${exhaustiveCheck}`)
76+
}
77+
78+
this.models = await getModels(options)
5579
return this.getModel()
5680
}
5781

src/core/webview/__tests__/webviewMessageHandler.test.ts

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ describe("webviewMessageHandler", () => {
3333
describe("requestRouterModels", () => {
3434
test("handles all successful model fetches correctly", async () => {
3535
// Mock all getModels calls to succeed with different data
36-
;(getModels as jest.Mock).mockImplementation((router) => {
36+
;(getModels as jest.Mock).mockImplementation((options) => {
37+
const provider = options.provider
3738
return Promise.resolve({
38-
[`${router}-model-1`]: { name: `${router} Model 1` },
39-
[`${router}-model-2`]: { name: `${router} Model 2` },
39+
[`${provider}-model-1`]: { name: `${provider} Model 1` },
40+
[`${provider}-model-2`]: { name: `${provider} Model 2` },
4041
})
4142
})
4243

@@ -75,14 +76,15 @@ describe("webviewMessageHandler", () => {
7576

7677
test("handles some failed model fetches correctly", async () => {
7778
// Mock some getModels calls to succeed and others to fail
78-
;(getModels as jest.Mock).mockImplementation((router) => {
79-
if (router === "openrouter" || router === "litellm") {
79+
;(getModels as jest.Mock).mockImplementation((options) => {
80+
const provider = options.provider
81+
if (provider === "openrouter" || provider === "litellm") {
8082
return Promise.resolve({
81-
[`${router}-model-1`]: { name: `${router} Model 1` },
83+
[`${provider}-model-1`]: { name: `${provider} Model 1` },
8284
})
8385
}
84-
// For other routers, throw an error
85-
return Promise.reject(new Error(`Failed to fetch ${router} models`))
86+
// For other providers, throw an error
87+
return Promise.reject(new Error(`Failed to fetch ${provider} models`))
8688
})
8789

8890
// Call the handler
@@ -129,4 +131,46 @@ describe("webviewMessageHandler", () => {
129131
})
130132
})
131133
})
134+
135+
describe("requestProviderModels", () => {
136+
test("when getModels succeeds, it posts a providerModelsResponse with models", async () => {
137+
const mockLiteLLMModels = { "litellm-model-1": { name: "LiteLLM Model 1" } }
138+
;(getModels as jest.Mock).mockResolvedValueOnce(mockLiteLLMModels)
139+
140+
await webviewMessageHandler(mockProvider as any, {
141+
type: "requestProviderModels",
142+
payload: { provider: "litellm", apiKey: "test-key", baseUrl: "test-url" },
143+
})
144+
145+
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({
146+
type: "providerModelsResponse",
147+
payload: {
148+
provider: "litellm",
149+
models: mockLiteLLMModels,
150+
error: undefined, // Explicitly check error is undefined on success
151+
},
152+
})
153+
expect(getModels).toHaveBeenCalledWith({ provider: "litellm", apiKey: "test-key", baseUrl: "test-url" })
154+
})
155+
156+
test("when getModels fails, it posts a providerModelsResponse with an error and empty models", async () => {
157+
const errorMessage = "Failed to fetch LiteLLM models: No response from server."
158+
;(getModels as jest.Mock).mockRejectedValueOnce(new Error(errorMessage))
159+
160+
await webviewMessageHandler(mockProvider as any, {
161+
type: "requestProviderModels",
162+
payload: { provider: "litellm", apiKey: "test-key", baseUrl: "test-url" },
163+
})
164+
165+
expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith({
166+
type: "providerModelsResponse",
167+
payload: {
168+
provider: "litellm",
169+
models: {},
170+
error: errorMessage,
171+
},
172+
})
173+
expect(getModels).toHaveBeenCalledWith({ provider: "litellm", apiKey: "test-key", baseUrl: "test-url" })
174+
})
175+
})
132176
})

0 commit comments

Comments
 (0)