Skip to content

Commit 9c16b8a

Browse files
committed
feat: add custom context window support for AI providers
- Add customModelInfo fields to Groq and Vertex provider schemas - Create reusable ContextWindow component for UI - Update Groq and Vertex provider UIs to include context window input - Update API handlers to use custom context window when provided - Allow users to override default context window size for better control Fixes #7209
1 parent fd3535c commit 9c16b8a

File tree

7 files changed

+192
-5
lines changed

7 files changed

+192
-5
lines changed

packages/types/src/provider-settings.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ const vertexSchema = apiModelIdProviderModelSchema.extend({
147147
vertexJsonCredentials: z.string().optional(),
148148
vertexProjectId: z.string().optional(),
149149
vertexRegion: z.string().optional(),
150+
vertexCustomModelInfo: modelInfoSchema.nullish(),
150151
})
151152

152153
const openAiSchema = baseProviderSettingsSchema.extend({
@@ -248,6 +249,7 @@ const xaiSchema = apiModelIdProviderModelSchema.extend({
248249

249250
const groqSchema = apiModelIdProviderModelSchema.extend({
250251
groqApiKey: z.string().optional(),
252+
groqCustomModelInfo: modelInfoSchema.nullish(),
251253
})
252254

253255
const huggingFaceSchema = baseProviderSettingsSchema.extend({

src/api/providers/base-openai-compatible-provider.ts

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,28 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
130130
? (this.options.apiModelId as ModelName)
131131
: this.defaultProviderModelId
132132

133-
return { id, info: this.providerModels[id] }
133+
const defaultInfo = this.providerModels[id]
134+
135+
// Check if there's custom model info for this provider
136+
// This allows Groq and other providers to override context window
137+
const customModelInfo = this.getCustomModelInfo()
138+
139+
const info: ModelInfo = customModelInfo
140+
? {
141+
...defaultInfo,
142+
...customModelInfo,
143+
// Ensure required fields are present
144+
maxTokens: customModelInfo.maxTokens ?? defaultInfo.maxTokens,
145+
contextWindow: customModelInfo.contextWindow ?? defaultInfo.contextWindow,
146+
supportsPromptCache: customModelInfo.supportsPromptCache ?? defaultInfo.supportsPromptCache,
147+
}
148+
: defaultInfo
149+
150+
return { id, info }
151+
}
152+
153+
protected getCustomModelInfo(): ModelInfo | undefined {
154+
// Override in subclasses to provide custom model info
155+
return undefined
134156
}
135157
}

src/api/providers/groq.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types"
1+
import { type GroqModelId, type ModelInfo, groqDefaultModelId, groqModels } from "@roo-code/types"
22

33
import type { ApiHandlerOptions } from "../../shared/api"
44

@@ -16,4 +16,8 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
1616
defaultTemperature: 0.5,
1717
})
1818
}
19+
20+
protected override getCustomModelInfo(): ModelInfo | undefined {
21+
return this.options.groqCustomModelInfo ?? undefined
22+
}
1923
}

src/api/providers/vertex.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,21 @@ export class VertexHandler extends GeminiHandler implements SingleCompletionHand
1515
override getModel() {
1616
const modelId = this.options.apiModelId
1717
let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId
18-
const info: ModelInfo = vertexModels[id]
18+
const defaultInfo: ModelInfo = vertexModels[id]
19+
20+
// Apply custom model info if provided
21+
const customModelInfo = this.options.vertexCustomModelInfo
22+
const info: ModelInfo = customModelInfo
23+
? {
24+
...defaultInfo,
25+
...customModelInfo,
26+
// Ensure required fields are present
27+
maxTokens: customModelInfo.maxTokens ?? defaultInfo.maxTokens,
28+
contextWindow: customModelInfo.contextWindow ?? defaultInfo.contextWindow,
29+
supportsPromptCache: customModelInfo.supportsPromptCache ?? defaultInfo.supportsPromptCache,
30+
}
31+
: defaultInfo
32+
1933
const params = getModelParams({ format: "gemini", modelId: id, model: info, settings: this.options })
2034

2135
// The `:thinking` suffix indicates that the model is a "Hybrid"
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import { useCallback } from "react"
2+
import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
3+
4+
import type { ModelInfo } from "@roo-code/types"
5+
6+
type ContextWindowProps = {
7+
customModelInfo?: ModelInfo | null
8+
defaultContextWindow?: number
9+
onContextWindowChange: (contextWindow: number | undefined) => void
10+
label?: string
11+
placeholder?: string
12+
helperText?: string
13+
}
14+
15+
const inputEventTransform = (event: any) => (event as { target: HTMLInputElement })?.target?.value
16+
17+
export const ContextWindow = ({
18+
customModelInfo,
19+
defaultContextWindow,
20+
onContextWindowChange,
21+
label,
22+
placeholder,
23+
helperText,
24+
}: ContextWindowProps) => {
25+
const handleContextWindowChange = useCallback(
26+
(event: any) => {
27+
const value = inputEventTransform(event)?.trim()
28+
29+
if (value === "") {
30+
// Clear custom context window
31+
onContextWindowChange(undefined)
32+
} else {
33+
const numValue = parseInt(value, 10)
34+
if (!isNaN(numValue) && numValue > 0) {
35+
onContextWindowChange(numValue)
36+
}
37+
}
38+
},
39+
[onContextWindowChange],
40+
)
41+
42+
const currentValue = customModelInfo?.contextWindow?.toString() || ""
43+
const placeholderText = placeholder || defaultContextWindow?.toString() || "128000"
44+
const labelText = label || "Context Window Size"
45+
const helperTextContent = helperText || "Custom context window size in tokens (leave empty to use default)"
46+
47+
return (
48+
<>
49+
<VSCodeTextField
50+
value={currentValue}
51+
onInput={handleContextWindowChange}
52+
placeholder={placeholderText}
53+
className="w-full">
54+
<label className="block font-medium mb-1">{labelText}</label>
55+
</VSCodeTextField>
56+
{helperTextContent && (
57+
<div className="text-sm text-vscode-descriptionForeground -mt-2">{helperTextContent}</div>
58+
)}
59+
</>
60+
)
61+
}

webview-ui/src/components/settings/providers/Groq.tsx

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import { useCallback } from "react"
22
import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
33

4-
import type { ProviderSettings } from "@roo-code/types"
4+
import type { ProviderSettings, ModelInfo } from "@roo-code/types"
55

66
import { useAppTranslation } from "@src/i18n/TranslationContext"
77
import { VSCodeButtonLink } from "@src/components/common/VSCodeButtonLink"
8+
import { ContextWindow } from "@src/components/common/ContextWindow"
89

910
import { inputEventTransform } from "../transforms"
1011

@@ -27,6 +28,42 @@ export const Groq = ({ apiConfiguration, setApiConfigurationField }: GroqProps)
2728
[setApiConfigurationField],
2829
)
2930

31+
const handleContextWindowChange = useCallback(
32+
(contextWindow: number | undefined) => {
33+
const currentModelInfo = apiConfiguration?.groqCustomModelInfo
34+
const updatedModelInfo: ModelInfo | undefined = contextWindow
35+
? {
36+
maxTokens: currentModelInfo?.maxTokens ?? null,
37+
contextWindow,
38+
supportsPromptCache: currentModelInfo?.supportsPromptCache ?? false,
39+
// Preserve other fields if they exist
40+
...(currentModelInfo && {
41+
maxThinkingTokens: currentModelInfo.maxThinkingTokens,
42+
supportsImages: currentModelInfo.supportsImages,
43+
supportsComputerUse: currentModelInfo.supportsComputerUse,
44+
supportsVerbosity: currentModelInfo.supportsVerbosity,
45+
supportsReasoningBudget: currentModelInfo.supportsReasoningBudget,
46+
requiredReasoningBudget: currentModelInfo.requiredReasoningBudget,
47+
supportsReasoningEffort: currentModelInfo.supportsReasoningEffort,
48+
supportedParameters: currentModelInfo.supportedParameters,
49+
inputPrice: currentModelInfo.inputPrice,
50+
outputPrice: currentModelInfo.outputPrice,
51+
cacheWritesPrice: currentModelInfo.cacheWritesPrice,
52+
cacheReadsPrice: currentModelInfo.cacheReadsPrice,
53+
description: currentModelInfo.description,
54+
reasoningEffort: currentModelInfo.reasoningEffort,
55+
minTokensPerCachePoint: currentModelInfo.minTokensPerCachePoint,
56+
maxCachePoints: currentModelInfo.maxCachePoints,
57+
cachableFields: currentModelInfo.cachableFields,
58+
tiers: currentModelInfo.tiers,
59+
}),
60+
}
61+
: undefined
62+
setApiConfigurationField("groqCustomModelInfo", updatedModelInfo)
63+
},
64+
[apiConfiguration?.groqCustomModelInfo, setApiConfigurationField],
65+
)
66+
3067
return (
3168
<>
3269
<VSCodeTextField
@@ -45,6 +82,11 @@ export const Groq = ({ apiConfiguration, setApiConfigurationField }: GroqProps)
4582
{t("settings:providers.getGroqApiKey")}
4683
</VSCodeButtonLink>
4784
)}
85+
<ContextWindow
86+
customModelInfo={apiConfiguration?.groqCustomModelInfo}
87+
defaultContextWindow={128000}
88+
onContextWindowChange={handleContextWindowChange}
89+
/>
4890
</>
4991
)
5092
}

webview-ui/src/components/settings/providers/Vertex.tsx

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import { useCallback } from "react"
22
import { VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
33

4-
import { type ProviderSettings, VERTEX_REGIONS } from "@roo-code/types"
4+
import { type ProviderSettings, type ModelInfo, VERTEX_REGIONS } from "@roo-code/types"
55

66
import { useAppTranslation } from "@src/i18n/TranslationContext"
77
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@src/components/ui"
8+
import { ContextWindow } from "@src/components/common/ContextWindow"
89

910
import { inputEventTransform } from "../transforms"
1011

@@ -27,6 +28,42 @@ export const Vertex = ({ apiConfiguration, setApiConfigurationField }: VertexPro
2728
[setApiConfigurationField],
2829
)
2930

31+
const handleContextWindowChange = useCallback(
32+
(contextWindow: number | undefined) => {
33+
const currentModelInfo = apiConfiguration?.vertexCustomModelInfo
34+
const updatedModelInfo: ModelInfo | undefined = contextWindow
35+
? {
36+
maxTokens: currentModelInfo?.maxTokens ?? null,
37+
contextWindow,
38+
supportsPromptCache: currentModelInfo?.supportsPromptCache ?? false,
39+
// Preserve other fields if they exist
40+
...(currentModelInfo && {
41+
maxThinkingTokens: currentModelInfo.maxThinkingTokens,
42+
supportsImages: currentModelInfo.supportsImages,
43+
supportsComputerUse: currentModelInfo.supportsComputerUse,
44+
supportsVerbosity: currentModelInfo.supportsVerbosity,
45+
supportsReasoningBudget: currentModelInfo.supportsReasoningBudget,
46+
requiredReasoningBudget: currentModelInfo.requiredReasoningBudget,
47+
supportsReasoningEffort: currentModelInfo.supportsReasoningEffort,
48+
supportedParameters: currentModelInfo.supportedParameters,
49+
inputPrice: currentModelInfo.inputPrice,
50+
outputPrice: currentModelInfo.outputPrice,
51+
cacheWritesPrice: currentModelInfo.cacheWritesPrice,
52+
cacheReadsPrice: currentModelInfo.cacheReadsPrice,
53+
description: currentModelInfo.description,
54+
reasoningEffort: currentModelInfo.reasoningEffort,
55+
minTokensPerCachePoint: currentModelInfo.minTokensPerCachePoint,
56+
maxCachePoints: currentModelInfo.maxCachePoints,
57+
cachableFields: currentModelInfo.cachableFields,
58+
tiers: currentModelInfo.tiers,
59+
}),
60+
}
61+
: undefined
62+
setApiConfigurationField("vertexCustomModelInfo", updatedModelInfo)
63+
},
64+
[apiConfiguration?.vertexCustomModelInfo, setApiConfigurationField],
65+
)
66+
3067
return (
3168
<>
3269
<div className="text-sm text-vscode-descriptionForeground">
@@ -91,6 +128,11 @@ export const Vertex = ({ apiConfiguration, setApiConfigurationField }: VertexPro
91128
</SelectContent>
92129
</Select>
93130
</div>
131+
<ContextWindow
132+
customModelInfo={apiConfiguration?.vertexCustomModelInfo}
133+
defaultContextWindow={128000}
134+
onContextWindowChange={handleContextWindowChange}
135+
/>
94136
</>
95137
)
96138
}

0 commit comments

Comments
 (0)