-
Notifications
You must be signed in to change notification settings - Fork 2.6k
feat: add custom model context window override for all providers #8398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,126 @@ | ||||||
| import { describe, it, expect, beforeEach } from "vitest" | ||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [P2] `beforeEach` is imported but never used; remove it from the import to keep the tests clean.
Suggested change
|
||||||
| import { AnthropicHandler } from "../anthropic" | ||||||
| import { OpenRouterHandler } from "../openrouter" | ||||||
| import { OpenAiHandler } from "../openai" | ||||||
| import { GeminiHandler } from "../gemini" | ||||||
| import type { ApiHandlerOptions } from "../../../shared/api" | ||||||
|
|
||||||
// Exercises the `modelContextWindow` option across several providers: a
// positive value must replace the context window reported by getModel(), while
// an absent, zero, or negative value must leave the provider's figure in place.
describe("Context Window Override", () => {
	describe("AnthropicHandler", () => {
		it("should apply modelContextWindow override", () => {
			const settings: ApiHandlerOptions = {
				apiKey: "test-key",
				apiModelId: "claude-3-5-sonnet-20241022",
				modelContextWindow: 50000, // user-chosen window
			}

			const { info } = new AnthropicHandler(settings).getModel()

			expect(info.contextWindow).toBe(50000)
		})

		it("should use default context window when no override is provided", () => {
			// No modelContextWindow key at all: the handler has nothing to override with.
			const settings: ApiHandlerOptions = {
				apiKey: "test-key",
				apiModelId: "claude-3-5-sonnet-20241022",
			}

			const { info } = new AnthropicHandler(settings).getModel()

			// The test expects this model to report a 200000-token window by default.
			expect(info.contextWindow).toBe(200000)
		})
	})

	describe("OpenRouterHandler", () => {
		it("should apply modelContextWindow override", async () => {
			const settings: ApiHandlerOptions = {
				openRouterApiKey: "test-key",
				openRouterModelId: "anthropic/claude-3.5-sonnet",
				modelContextWindow: 75000, // user-chosen window
			}

			const provider = new OpenRouterHandler(settings)

			// Seed the handler's model table by hand so no real API call is needed.
			Object.assign(provider, {
				models: {
					"anthropic/claude-3.5-sonnet": {
						contextWindow: 200000,
						maxTokens: 8192,
						supportsPromptCache: true,
						supportsImages: true,
					},
				},
			})

			const { info } = provider.getModel()

			expect(info.contextWindow).toBe(75000)
		})
	})

	describe("OpenAiHandler", () => {
		it("should apply modelContextWindow override to custom model info", () => {
			// The custom model info carries its own window (128000); the override is expected to win.
			const settings: ApiHandlerOptions = {
				openAiApiKey: "test-key",
				openAiModelId: "gpt-4",
				openAiCustomModelInfo: {
					contextWindow: 128000,
					maxTokens: 4096,
					supportsPromptCache: false,
					supportsImages: true,
				},
				modelContextWindow: 60000, // user-chosen window
			}

			const { info } = new OpenAiHandler(settings).getModel()

			expect(info.contextWindow).toBe(60000)
		})
	})

	describe("GeminiHandler", () => {
		it("should apply modelContextWindow override", () => {
			const settings: ApiHandlerOptions = {
				geminiApiKey: "test-key",
				apiModelId: "gemini-1.5-pro-latest",
				modelContextWindow: 100000, // user-chosen window
			}

			const { info } = new GeminiHandler(settings).getModel()

			expect(info.contextWindow).toBe(100000)
		})
	})

	describe("Edge cases", () => {
		it("should not apply override when modelContextWindow is 0", () => {
			const settings: ApiHandlerOptions = {
				apiKey: "test-key",
				apiModelId: "claude-3-5-sonnet-20241022",
				modelContextWindow: 0, // zero is treated as "no override"
			}

			const { info } = new AnthropicHandler(settings).getModel()

			// Falls back to the same default as the no-override case.
			expect(info.contextWindow).toBe(200000)
		})

		it("should not apply override when modelContextWindow is negative", () => {
			const settings: ApiHandlerOptions = {
				apiKey: "test-key",
				apiModelId: "claude-3-5-sonnet-20241022",
				modelContextWindow: -1000, // negative is treated as "no override"
			}

			const { info } = new AnthropicHandler(settings).getModel()

			// Falls back to the same default as the no-override case.
			expect(info.contextWindow).toBe(200000)
		})
	})
})
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [P2] Consider adding test cases for AwsBedrockHandler and RouterProvider to cover the new override paths added in those files (and in any other provider base classes that use them). This will help catch regressions if the provider wiring changes. |
||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,6 +3,7 @@ import { Anthropic } from "@anthropic-ai/sdk" | |||||||||||||||||||||||||||||||||||||||
| import type { ModelInfo } from "@roo-code/types" | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| import type { ApiHandler, ApiHandlerCreateMessageMetadata } from "../index" | ||||||||||||||||||||||||||||||||||||||||
| import type { ApiHandlerOptions } from "../../shared/api" | ||||||||||||||||||||||||||||||||||||||||
| import { ApiStream } from "../transform/stream" | ||||||||||||||||||||||||||||||||||||||||
| import { countTokens } from "../../utils/countTokens" | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
@@ -18,6 +19,26 @@ export abstract class BaseProvider implements ApiHandler { | |||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| abstract getModel(): { id: string; info: ModelInfo } | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
/**
 * Applies user-configured overrides to model info.
 *
 * This allows users to customize model parameters like context window size
 * to work around corporate restrictions or other limitations.
 *
 * The input `info` is never mutated; a shallow copy is returned. The only
 * override handled here is `options.modelContextWindow`:
 *
 * - It is applied only when it is a finite number that is at least 1 after
 *   dropping any fractional part (a token count must be a whole number).
 * - Missing, zero, negative, `NaN` and infinite values are ignored and the
 *   copy is returned with the original `contextWindow`.
 * - When the override is smaller than the model's numeric `maxTokens`,
 *   `maxTokens` is lowered to the new window so that
 *   `maxTokens <= contextWindow` still holds.
 *
 * @param info The original model info
 * @param options The API handler options containing user overrides
 * @returns A copy of the model info with overrides applied
 */
protected applyModelOverrides(info: ModelInfo, options: ApiHandlerOptions): ModelInfo {
	const overriddenInfo: ModelInfo = { ...info }
	const requested = options.modelContextWindow

	// Ignore absent or non-finite values (undefined, NaN, Infinity).
	if (typeof requested !== "number" || !Number.isFinite(requested)) {
		return overriddenInfo
	}

	// Token counts are whole numbers; anything that floors below 1 is not a usable window.
	const contextWindow = Math.floor(requested)
	if (contextWindow < 1) {
		return overriddenInfo
	}

	overriddenInfo.contextWindow = contextWindow

	// A shrunken window must not leave maxTokens larger than the window itself,
	// otherwise a later request could ask for more output tokens than fit.
	if (typeof overriddenInfo.maxTokens === "number" && overriddenInfo.maxTokens > contextWindow) {
		overriddenInfo.maxTokens = contextWindow
	}

	return overriddenInfo
}
|
|
||||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||||
| * Default token counting implementation using tiktoken. | ||||||||||||||||||||||||||||||||||||||||
| * Providers can override this to use their native token counting endpoints. | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[P1] Schema allows floats/negatives; tighten to non-negative integers to reflect expected semantics and avoid silent float inputs.