diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index c22683117f..15addbb0bb 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -147,6 +147,7 @@ const vertexSchema = apiModelIdProviderModelSchema.extend({ vertexJsonCredentials: z.string().optional(), vertexProjectId: z.string().optional(), vertexRegion: z.string().optional(), + vertexCustomModelInfo: modelInfoSchema.nullish(), }) const openAiSchema = baseProviderSettingsSchema.extend({ @@ -248,6 +249,7 @@ const xaiSchema = apiModelIdProviderModelSchema.extend({ const groqSchema = apiModelIdProviderModelSchema.extend({ groqApiKey: z.string().optional(), + groqCustomModelInfo: modelInfoSchema.nullish(), }) const huggingFaceSchema = baseProviderSettingsSchema.extend({ diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts index f196b5f309..80d59fd71a 100644 --- a/src/api/providers/base-openai-compatible-provider.ts +++ b/src/api/providers/base-openai-compatible-provider.ts @@ -130,6 +130,28 @@ export abstract class BaseOpenAiCompatibleProvider ? (this.options.apiModelId as ModelName) : this.defaultProviderModelId - return { id, info: this.providerModels[id] } + const defaultInfo = this.providerModels[id] + + // Check if there's custom model info for this provider + // This allows Groq and other providers to override context window + const customModelInfo = this.getCustomModelInfo() + + const info: ModelInfo = customModelInfo + ? { + ...defaultInfo, + ...customModelInfo, + // Ensure required fields are present + maxTokens: customModelInfo.maxTokens ?? defaultInfo.maxTokens, + contextWindow: customModelInfo.contextWindow ?? defaultInfo.contextWindow, + supportsPromptCache: customModelInfo.supportsPromptCache ?? defaultInfo.supportsPromptCache, + } + : defaultInfo + + return { id, info } + } + + protected getCustomModelInfo(): ModelInfo | undefined { + // Override in subclasses to provide custom model info + return undefined } } diff --git a/src/api/providers/groq.ts b/src/api/providers/groq.ts index 7583edc51c..323e15e7a3 100644 --- a/src/api/providers/groq.ts +++ b/src/api/providers/groq.ts @@ -1,4 +1,4 @@ -import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types" +import { type GroqModelId, type ModelInfo, groqDefaultModelId, groqModels } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" @@ -16,4 +16,8 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider { defaultTemperature: 0.5, }) } + + protected override getCustomModelInfo(): ModelInfo | undefined { + return this.options.groqCustomModelInfo ?? undefined + } } diff --git a/src/api/providers/vertex.ts b/src/api/providers/vertex.ts index 2c077d97b7..05c61fff63 100644 --- a/src/api/providers/vertex.ts +++ b/src/api/providers/vertex.ts @@ -15,7 +15,21 @@ export class VertexHandler extends GeminiHandler implements SingleCompletionHand override getModel() { const modelId = this.options.apiModelId let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId - const info: ModelInfo = vertexModels[id] + const defaultInfo: ModelInfo = vertexModels[id] + + // Apply custom model info if provided + const customModelInfo = this.options.vertexCustomModelInfo + const info: ModelInfo = customModelInfo + ? { + ...defaultInfo, + ...customModelInfo, + // Ensure required fields are present + maxTokens: customModelInfo.maxTokens ?? defaultInfo.maxTokens, + contextWindow: customModelInfo.contextWindow ?? defaultInfo.contextWindow, + supportsPromptCache: customModelInfo.supportsPromptCache ?? defaultInfo.supportsPromptCache, + } + : defaultInfo + const params = getModelParams({ format: "gemini", modelId: id, model: info, settings: this.options }) // The `:thinking` suffix indicates that the model is a "Hybrid" diff --git a/webview-ui/src/components/common/ContextWindow.tsx b/webview-ui/src/components/common/ContextWindow.tsx new file mode 100644 index 0000000000..bdde76490a --- /dev/null +++ b/webview-ui/src/components/common/ContextWindow.tsx @@ -0,0 +1,61 @@ +import { useCallback } from "react" +import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react" + +import type { ModelInfo } from "@roo-code/types" + +type ContextWindowProps = { + customModelInfo?: ModelInfo | null + defaultContextWindow?: number + onContextWindowChange: (contextWindow: number | undefined) => void + label?: string + placeholder?: string + helperText?: string +} + +const inputEventTransform = (event: any) => (event as { target: HTMLInputElement })?.target?.value + +export const ContextWindow = ({ + customModelInfo, + defaultContextWindow, + onContextWindowChange, + label, + placeholder, + helperText, +}: ContextWindowProps) => { + const handleContextWindowChange = useCallback( + (event: any) => { + const value = inputEventTransform(event)?.trim() + + if (value === "") { + // Clear custom context window + onContextWindowChange(undefined) + } else { + const numValue = parseInt(value, 10) + if (!isNaN(numValue) && numValue > 0) { + onContextWindowChange(numValue) + } + } + }, + [onContextWindowChange], + ) + + const currentValue = customModelInfo?.contextWindow?.toString() || "" + const placeholderText = placeholder || defaultContextWindow?.toString() || "128000" + const labelText = label || "Context Window Size" + const helperTextContent = helperText || "Custom context window size in tokens (leave empty to use default)" + + return ( + <> + + + + {helperTextContent && ( +
{helperTextContent}
+ )} + + ) +} diff --git a/webview-ui/src/components/settings/providers/Groq.tsx b/webview-ui/src/components/settings/providers/Groq.tsx index a8a910d1ac..e9ac86047d 100644 --- a/webview-ui/src/components/settings/providers/Groq.tsx +++ b/webview-ui/src/components/settings/providers/Groq.tsx @@ -1,10 +1,11 @@ import { useCallback } from "react" import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react" -import type { ProviderSettings } from "@roo-code/types" +import type { ProviderSettings, ModelInfo } from "@roo-code/types" import { useAppTranslation } from "@src/i18n/TranslationContext" import { VSCodeButtonLink } from "@src/components/common/VSCodeButtonLink" +import { ContextWindow } from "@src/components/common/ContextWindow" import { inputEventTransform } from "../transforms" @@ -27,6 +28,42 @@ export const Groq = ({ apiConfiguration, setApiConfigurationField }: GroqProps) [setApiConfigurationField], ) + const handleContextWindowChange = useCallback( + (contextWindow: number | undefined) => { + const currentModelInfo = apiConfiguration?.groqCustomModelInfo + const updatedModelInfo: ModelInfo | undefined = contextWindow + ? { + maxTokens: currentModelInfo?.maxTokens ?? null, + contextWindow, + supportsPromptCache: currentModelInfo?.supportsPromptCache ?? false, + // Preserve other fields if they exist + ...(currentModelInfo && { + maxThinkingTokens: currentModelInfo.maxThinkingTokens, + supportsImages: currentModelInfo.supportsImages, + supportsComputerUse: currentModelInfo.supportsComputerUse, + supportsVerbosity: currentModelInfo.supportsVerbosity, + supportsReasoningBudget: currentModelInfo.supportsReasoningBudget, + requiredReasoningBudget: currentModelInfo.requiredReasoningBudget, + supportsReasoningEffort: currentModelInfo.supportsReasoningEffort, + supportedParameters: currentModelInfo.supportedParameters, + inputPrice: currentModelInfo.inputPrice, + outputPrice: currentModelInfo.outputPrice, + cacheWritesPrice: currentModelInfo.cacheWritesPrice, + cacheReadsPrice: currentModelInfo.cacheReadsPrice, + description: currentModelInfo.description, + reasoningEffort: currentModelInfo.reasoningEffort, + minTokensPerCachePoint: currentModelInfo.minTokensPerCachePoint, + maxCachePoints: currentModelInfo.maxCachePoints, + cachableFields: currentModelInfo.cachableFields, + tiers: currentModelInfo.tiers, + }), + } + : undefined + setApiConfigurationField("groqCustomModelInfo", updatedModelInfo) + }, + [apiConfiguration?.groqCustomModelInfo, setApiConfigurationField], + ) + return ( <> )} + ) } diff --git a/webview-ui/src/components/settings/providers/Vertex.tsx b/webview-ui/src/components/settings/providers/Vertex.tsx index 19a136927a..ae9cde7d76 100644 --- a/webview-ui/src/components/settings/providers/Vertex.tsx +++ b/webview-ui/src/components/settings/providers/Vertex.tsx @@ -1,10 +1,11 @@ import { useCallback } from "react" import { VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react" -import { type ProviderSettings, VERTEX_REGIONS } from "@roo-code/types" +import { type ProviderSettings, type ModelInfo, VERTEX_REGIONS } from "@roo-code/types" import { useAppTranslation } from "@src/i18n/TranslationContext" import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@src/components/ui" +import { ContextWindow } from "@src/components/common/ContextWindow" import { inputEventTransform } from "../transforms" @@ -27,6 +28,42 @@ export const Vertex = ({ apiConfiguration, setApiConfigurationField }: VertexPro [setApiConfigurationField], ) + const handleContextWindowChange = useCallback( + (contextWindow: number | undefined) => { + const currentModelInfo = apiConfiguration?.vertexCustomModelInfo + const updatedModelInfo: ModelInfo | undefined = contextWindow + ? { + maxTokens: currentModelInfo?.maxTokens ?? null, + contextWindow, + supportsPromptCache: currentModelInfo?.supportsPromptCache ?? false, + // Preserve other fields if they exist + ...(currentModelInfo && { + maxThinkingTokens: currentModelInfo.maxThinkingTokens, + supportsImages: currentModelInfo.supportsImages, + supportsComputerUse: currentModelInfo.supportsComputerUse, + supportsVerbosity: currentModelInfo.supportsVerbosity, + supportsReasoningBudget: currentModelInfo.supportsReasoningBudget, + requiredReasoningBudget: currentModelInfo.requiredReasoningBudget, + supportsReasoningEffort: currentModelInfo.supportsReasoningEffort, + supportedParameters: currentModelInfo.supportedParameters, + inputPrice: currentModelInfo.inputPrice, + outputPrice: currentModelInfo.outputPrice, + cacheWritesPrice: currentModelInfo.cacheWritesPrice, + cacheReadsPrice: currentModelInfo.cacheReadsPrice, + description: currentModelInfo.description, + reasoningEffort: currentModelInfo.reasoningEffort, + minTokensPerCachePoint: currentModelInfo.minTokensPerCachePoint, + maxCachePoints: currentModelInfo.maxCachePoints, + cachableFields: currentModelInfo.cachableFields, + tiers: currentModelInfo.tiers, + }), + } + : undefined + setApiConfigurationField("vertexCustomModelInfo", updatedModelInfo) + }, + [apiConfiguration?.vertexCustomModelInfo, setApiConfigurationField], + ) + return ( <>
@@ -91,6 +128,11 @@ export const Vertex = ({ apiConfiguration, setApiConfigurationField }: VertexPro
+ ) }