diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index 207c60a524..1519b6a66b 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -36,6 +36,7 @@ export const providerNames = [ "huggingface", "cerebras", "sambanova", + "fireworks", ] as const export const providerNamesSchema = z.enum(providerNames) @@ -257,6 +258,10 @@ const sambaNovaSchema = apiModelIdProviderModelSchema.extend({ sambaNovaApiKey: z.string().optional(), }) +const fireworksSchema = apiModelIdProviderModelSchema.extend({ + fireworksApiKey: z.string().optional(), +}) + const defaultSchema = z.object({ apiProvider: z.undefined(), }) @@ -290,6 +295,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })), cerebrasSchema.merge(z.object({ apiProvider: z.literal("cerebras") })), sambaNovaSchema.merge(z.object({ apiProvider: z.literal("sambanova") })), + fireworksSchema.merge(z.object({ apiProvider: z.literal("fireworks") })), defaultSchema, ]) @@ -323,6 +329,7 @@ export const providerSettingsSchema = z.object({ ...litellmSchema.shape, ...cerebrasSchema.shape, ...sambaNovaSchema.shape, + ...fireworksSchema.shape, ...codebaseIndexProviderSchema.shape, }) diff --git a/packages/types/src/providers/fireworks.ts b/packages/types/src/providers/fireworks.ts new file mode 100644 index 0000000000..28765fa079 --- /dev/null +++ b/packages/types/src/providers/fireworks.ts @@ -0,0 +1,162 @@ +import type { ModelInfo } from "../model.js" + +// https://docs.fireworks.ai/models/overview +export type FireworksModelId = + | "accounts/fireworks/models/llama-v3p3-70b-instruct" + | "accounts/fireworks/models/llama-v3p2-11b-vision-instruct" + | "accounts/fireworks/models/llama-v3p2-90b-vision-instruct" + | "accounts/fireworks/models/llama-v3p1-405b-instruct" + | "accounts/fireworks/models/llama-v3p1-70b-instruct" + | "accounts/fireworks/models/llama-v3p1-8b-instruct" + | "accounts/fireworks/models/qwen2p5-72b-instruct" + | "accounts/fireworks/models/qwen2p5-32b-instruct" + | "accounts/fireworks/models/qwen2p5-14b-instruct" + | "accounts/fireworks/models/qwen2p5-7b-instruct" + | "accounts/fireworks/models/qwen2p5-3b-instruct" + | "accounts/fireworks/models/qwen2p5-1p5b-instruct" + | "accounts/fireworks/models/qwen2p5-0p5b-instruct" + | "accounts/fireworks/models/qwen2p5-coder-32b-instruct" + | "accounts/moonshot/models/moonshot-v1-auto" + +export const fireworksDefaultModelId: FireworksModelId = "accounts/fireworks/models/llama-v3p3-70b-instruct" + +export const fireworksModels = { + // Llama models + "accounts/fireworks/models/llama-v3p3-70b-instruct": { + maxTokens: 16384, + contextWindow: 131072, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.9, + outputPrice: 0.9, + description: "Meta Llama 3.3 70B Instruct model with 128K context window", + }, + "accounts/fireworks/models/llama-v3p2-11b-vision-instruct": { + maxTokens: 16384, + contextWindow: 131072, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0.2, + outputPrice: 0.2, + description: "Meta Llama 3.2 11B Vision Instruct model with multimodal capabilities", + }, + "accounts/fireworks/models/llama-v3p2-90b-vision-instruct": { + maxTokens: 16384, + contextWindow: 131072, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 3.0, + outputPrice: 3.0, + description: "Meta Llama 3.2 90B Vision Instruct model with multimodal capabilities", + }, + "accounts/fireworks/models/llama-v3p1-405b-instruct": { + maxTokens: 16384, + contextWindow: 131072, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 3.0, + outputPrice: 3.0, + description: "Meta Llama 3.1 405B Instruct model, largest Llama model", + }, + "accounts/fireworks/models/llama-v3p1-70b-instruct": { + maxTokens: 16384, + contextWindow: 131072, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.9, + outputPrice: 0.9, + description: "Meta Llama 3.1 70B Instruct model with 128K context window", + }, + "accounts/fireworks/models/llama-v3p1-8b-instruct": { + maxTokens: 16384, + contextWindow: 131072, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.2, + outputPrice: 0.2, + description: "Meta Llama 3.1 8B Instruct model, efficient and fast", + }, + // Qwen models + "accounts/fireworks/models/qwen2p5-72b-instruct": { + maxTokens: 32768, + contextWindow: 32768, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.9, + outputPrice: 0.9, + description: "Alibaba Qwen 2.5 72B Instruct model", + }, + "accounts/fireworks/models/qwen2p5-32b-instruct": { + maxTokens: 32768, + contextWindow: 32768, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.9, + outputPrice: 0.9, + description: "Alibaba Qwen 2.5 32B Instruct model", + }, + "accounts/fireworks/models/qwen2p5-14b-instruct": { + maxTokens: 32768, + contextWindow: 32768, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.2, + outputPrice: 0.2, + description: "Alibaba Qwen 2.5 14B Instruct model", + }, + "accounts/fireworks/models/qwen2p5-7b-instruct": { + maxTokens: 32768, + contextWindow: 32768, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.2, + outputPrice: 0.2, + description: "Alibaba Qwen 2.5 7B Instruct model", + }, + "accounts/fireworks/models/qwen2p5-3b-instruct": { + maxTokens: 32768, + contextWindow: 32768, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.1, + outputPrice: 0.1, + description: "Alibaba Qwen 2.5 3B Instruct model", + }, + "accounts/fireworks/models/qwen2p5-1p5b-instruct": { + maxTokens: 32768, + contextWindow: 32768, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.1, + outputPrice: 0.1, + description: "Alibaba Qwen 2.5 1.5B Instruct model", + }, + "accounts/fireworks/models/qwen2p5-0p5b-instruct": { + maxTokens: 32768, + contextWindow: 32768, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.1, + outputPrice: 0.1, + description: "Alibaba Qwen 2.5 0.5B Instruct model, smallest Qwen model", + }, + "accounts/fireworks/models/qwen2p5-coder-32b-instruct": { + maxTokens: 32768, + contextWindow: 32768, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.9, + outputPrice: 0.9, + description: "Alibaba Qwen 2.5 Coder 32B Instruct model, optimized for code generation", + }, + // Moonshot models + "accounts/moonshot/models/moonshot-v1-auto": { + maxTokens: 65536, + contextWindow: 1000000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 1.0, + outputPrice: 1.0, + description: "Moonshot Kimi model with up to 1M context window", + }, +} as const satisfies Record diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts index d6584e70ec..6ae371a957 100644 --- a/packages/types/src/providers/index.ts +++ b/packages/types/src/providers/index.ts @@ -4,6 +4,8 @@ export * from "./cerebras.js" export * from "./chutes.js" export * from "./claude-code.js" export * from "./deepseek.js" +export * from "./doubao.js" +export * from "./fireworks.js" export * from "./gemini.js" export * from "./glama.js" export * from "./groq.js" @@ -21,4 +23,3 @@ export * from "./unbound.js" export * from "./vertex.js" export * from "./vscode-llm.js" export * from "./xai.js" -export * from "./doubao.js" diff --git a/src/api/index.ts b/src/api/index.ts index 5daa53396f..1fc12c5042 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -33,6 +33,7 @@ import { ClaudeCodeHandler, SambaNovaHandler, DoubaoHandler, + FireworksHandler, } from "./providers" export interface SingleCompletionHandler { @@ -124,6 +125,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler { return new CerebrasHandler(options) case "sambanova": return new SambaNovaHandler(options) + case "fireworks": + return new FireworksHandler(options) default: apiProvider satisfies "gemini-cli" | undefined return new AnthropicHandler(options) diff --git a/src/api/providers/__tests__/fireworks.spec.ts b/src/api/providers/__tests__/fireworks.spec.ts new file mode 100644 index 0000000000..b62431ebfe --- /dev/null +++ b/src/api/providers/__tests__/fireworks.spec.ts @@ -0,0 +1,204 @@ +// npx vitest run src/api/providers/__tests__/fireworks.spec.ts + +// Mock vscode first to avoid import errors +vitest.mock("vscode", () => ({})) + +import OpenAI from "openai" +import { Anthropic } from "@anthropic-ai/sdk" + +import { type FireworksModelId, fireworksDefaultModelId, fireworksModels } from "@roo-code/types" + +import { FireworksHandler } from "../fireworks" + +vitest.mock("openai", () => { + const createMock = vitest.fn() + return { + default: vitest.fn(() => ({ chat: { completions: { create: createMock } } })), + } +}) + +describe("FireworksHandler", () => { + let handler: FireworksHandler + let mockCreate: any + + beforeEach(() => { + vitest.clearAllMocks() + mockCreate = (OpenAI as unknown as any)().chat.completions.create + handler = new FireworksHandler({ fireworksApiKey: "test-fireworks-api-key" }) + }) + + it("should use the correct Fireworks base URL", () => { + new FireworksHandler({ fireworksApiKey: "test-fireworks-api-key" }) + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ baseURL: "https://api.fireworks.ai/inference/v1" }), + ) + }) + + it("should use the provided API key", () => { + const fireworksApiKey = "test-fireworks-api-key" + new FireworksHandler({ fireworksApiKey }) + expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: fireworksApiKey })) + }) + + it("should return default model when no model is specified", () => { + const model = handler.getModel() + expect(model.id).toBe(fireworksDefaultModelId) + expect(model.info).toEqual(fireworksModels[fireworksDefaultModelId]) + }) + + it("should return specified model when valid model is provided", () => { + const testModelId: FireworksModelId = "accounts/fireworks/models/llama-v3p1-70b-instruct" + const handlerWithModel = new FireworksHandler({ + apiModelId: testModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual(fireworksModels[testModelId]) + }) + + it("completePrompt method should return text from Fireworks API", async () => { + const expectedResponse = "This is a test response from Fireworks" + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe(expectedResponse) + }) + + it("should handle errors in completePrompt", async () => { + const errorMessage = "Fireworks API error" + mockCreate.mockRejectedValueOnce(new Error(errorMessage)) + await expect(handler.completePrompt("test prompt")).rejects.toThrow( + `Fireworks completion error: ${errorMessage}`, + ) + }) + + it("createMessage should yield text content from stream", async () => { + const testContent = "This is test content from Fireworks stream" + + mockCreate.mockImplementationOnce(() => { + return { + [Symbol.asyncIterator]: () => ({ + next: vitest + .fn() + .mockResolvedValueOnce({ + done: false, + value: { choices: [{ delta: { content: testContent } }] }, + }) + .mockResolvedValueOnce({ done: true }), + }), + } + }) + + const stream = handler.createMessage("system prompt", []) + const firstChunk = await stream.next() + + expect(firstChunk.done).toBe(false) + expect(firstChunk.value).toEqual({ type: "text", text: testContent }) + }) + + it("createMessage should yield usage data from stream", async () => { + mockCreate.mockImplementationOnce(() => { + return { + [Symbol.asyncIterator]: () => ({ + next: vitest + .fn() + .mockResolvedValueOnce({ + done: false, + value: { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } }, + }) + .mockResolvedValueOnce({ done: true }), + }), + } + }) + + const stream = handler.createMessage("system prompt", []) + const firstChunk = await stream.next() + + expect(firstChunk.done).toBe(false) + expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 }) + }) + + it("createMessage should pass correct parameters to Fireworks client", async () => { + const modelId: FireworksModelId = "accounts/fireworks/models/llama-v3p1-8b-instruct" + const modelInfo = fireworksModels[modelId] + const handlerWithModel = new FireworksHandler({ + apiModelId: modelId, + fireworksApiKey: "test-fireworks-api-key", + }) + + mockCreate.mockImplementationOnce(() => { + return { + [Symbol.asyncIterator]: () => ({ + async next() { + return { done: true } + }, + }), + } + }) + + const systemPrompt = "Test system prompt for Fireworks" + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Fireworks" }] + + const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages) + await messageGenerator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: modelId, + max_tokens: modelInfo.maxTokens, + temperature: 0.7, + messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]), + stream: true, + stream_options: { include_usage: true }, + }), + ) + }) + + it("should support vision models with image content", async () => { + const visionModelId: FireworksModelId = "accounts/fireworks/models/llama-v3p2-11b-vision-instruct" + const handlerWithVisionModel = new FireworksHandler({ + apiModelId: visionModelId, + fireworksApiKey: "test-fireworks-api-key", + }) + + mockCreate.mockImplementationOnce(() => { + return { + [Symbol.asyncIterator]: () => ({ + async next() { + return { done: true } + }, + }), + } + }) + + const systemPrompt = "Test system prompt for vision model" + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { type: "text", text: "What's in this image?" }, + { type: "image", source: { type: "base64", media_type: "image/jpeg", data: "base64data" } }, + ], + }, + ] + + const messageGenerator = handlerWithVisionModel.createMessage(systemPrompt, messages) + await messageGenerator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: visionModelId, + messages: expect.arrayContaining([ + { role: "system", content: systemPrompt }, + { + role: "user", + content: expect.arrayContaining([ + { type: "text", text: "What's in this image?" }, + { type: "image_url", image_url: { url: "" } }, + ]), + }, + ]), + }), + ) + }) +}) diff --git a/src/api/providers/fireworks.ts b/src/api/providers/fireworks.ts new file mode 100644 index 0000000000..891437d803 --- /dev/null +++ b/src/api/providers/fireworks.ts @@ -0,0 +1,19 @@ +import { type FireworksModelId, fireworksDefaultModelId, fireworksModels } from "@roo-code/types" + +import type { ApiHandlerOptions } from "../../shared/api" + +import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider" + +export class FireworksHandler extends BaseOpenAiCompatibleProvider { + constructor(options: ApiHandlerOptions) { + super({ + ...options, + providerName: "Fireworks", + baseURL: "https://api.fireworks.ai/inference/v1", + apiKey: options.fireworksApiKey, + defaultProviderModelId: fireworksDefaultModelId, + providerModels: fireworksModels, + defaultTemperature: 0.7, + }) + } +} diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts index a1b8f25536..88bbde43f2 100644 --- a/src/api/providers/index.ts +++ b/src/api/providers/index.ts @@ -8,6 +8,7 @@ export { DeepSeekHandler } from "./deepseek" export { DoubaoHandler } from "./doubao" export { MoonshotHandler } from "./moonshot" export { FakeAIHandler } from "./fake-ai" +export { FireworksHandler } from "./fireworks" export { GeminiHandler } from "./gemini" export { GlamaHandler } from "./glama" export { GroqHandler } from "./groq" diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index d70ca553ac..30248709fa 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -28,6 +28,7 @@ import { bedrockDefaultModelId, vertexDefaultModelId, sambaNovaDefaultModelId, + fireworksDefaultModelId, } from "@roo-code/types" import { vscode } from "@src/utils/vscode" @@ -61,6 +62,7 @@ import { ClaudeCode, DeepSeek, Doubao, + Fireworks, Gemini, Glama, Groq, @@ -306,6 +308,7 @@ const ApiOptions = ({ bedrock: { field: "apiModelId", default: bedrockDefaultModelId }, vertex: { field: "apiModelId", default: vertexDefaultModelId }, sambanova: { field: "apiModelId", default: sambaNovaDefaultModelId }, + fireworks: { field: "apiModelId", default: fireworksDefaultModelId }, openai: { field: "openAiModelId" }, ollama: { field: "ollamaModelId" }, lmstudio: { field: "lmStudioModelId" }, @@ -530,6 +533,10 @@ const ApiOptions = ({ )} + {selectedProvider === "fireworks" && ( + + )} + {selectedProvider === "human-relay" && ( <>
diff --git a/webview-ui/src/components/settings/constants.ts b/webview-ui/src/components/settings/constants.ts index fae35b1693..7e9ff8a478 100644 --- a/webview-ui/src/components/settings/constants.ts +++ b/webview-ui/src/components/settings/constants.ts @@ -16,6 +16,7 @@ import { chutesModels, sambaNovaModels, doubaoModels, + fireworksModels, } from "@roo-code/types" export const MODELS_BY_PROVIDER: Partial>> = { @@ -34,6 +35,7 @@ export const MODELS_BY_PROVIDER: Partial a.label.localeCompare(b.label)) diff --git a/webview-ui/src/components/settings/providers/Fireworks.tsx b/webview-ui/src/components/settings/providers/Fireworks.tsx new file mode 100644 index 0000000000..834ce9583a --- /dev/null +++ b/webview-ui/src/components/settings/providers/Fireworks.tsx @@ -0,0 +1,50 @@ +import { useCallback } from "react" +import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react" + +import type { ProviderSettings } from "@roo-code/types" + +import { useAppTranslation } from "@src/i18n/TranslationContext" +import { VSCodeButtonLink } from "@src/components/common/VSCodeButtonLink" + +import { inputEventTransform } from "../transforms" + +type FireworksProps = { + apiConfiguration: ProviderSettings + setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void +} + +export const Fireworks = ({ apiConfiguration, setApiConfigurationField }: FireworksProps) => { + const { t } = useAppTranslation() + + const handleInputChange = useCallback( + ( + field: K, + transform: (event: E) => ProviderSettings[K] = inputEventTransform, + ) => + (event: E | Event) => { + setApiConfigurationField(field, transform(event as E)) + }, + [setApiConfigurationField], + ) + + return ( + <> + + + +
+ {t("settings:providers.apiKeyStorageNotice")} +
+ {!apiConfiguration?.fireworksApiKey && ( + + {t("settings:providers.getFireworksApiKey")} + + )} + + ) +} diff --git a/webview-ui/src/components/settings/providers/index.ts b/webview-ui/src/components/settings/providers/index.ts index 47430a0cc8..73b16fb8fd 100644 --- a/webview-ui/src/components/settings/providers/index.ts +++ b/webview-ui/src/components/settings/providers/index.ts @@ -5,6 +5,7 @@ export { Chutes } from "./Chutes" export { ClaudeCode } from "./ClaudeCode" export { DeepSeek } from "./DeepSeek" export { Doubao } from "./Doubao" +export { Fireworks } from "./Fireworks" export { Gemini } from "./Gemini" export { Glama } from "./Glama" export { Groq } from "./Groq" diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 0c6a84a65e..a041890dec 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -40,6 +40,8 @@ import { sambaNovaDefaultModelId, doubaoModels, doubaoDefaultModelId, + fireworksModels, + fireworksDefaultModelId, } from "@roo-code/types" import type { ModelRecord, RouterModels } from "@roo/api" @@ -258,6 +260,11 @@ function getSelectedModel({ const info = sambaNovaModels[id as keyof typeof sambaNovaModels] return { id, info } } + case "fireworks": { + const id = apiConfiguration.apiModelId ?? fireworksDefaultModelId + const info = fireworksModels[id as keyof typeof fireworksModels] + return { id, info } + } // case "anthropic": // case "human-relay": // case "fake-ai": diff --git a/webview-ui/src/i18n/locales/en/settings.json b/webview-ui/src/i18n/locales/en/settings.json index 46c15556c8..0c29486a36 100644 --- a/webview-ui/src/i18n/locales/en/settings.json +++ b/webview-ui/src/i18n/locales/en/settings.json @@ -272,6 +272,8 @@ "groqApiKey": "Groq API Key", "getSambaNovaApiKey": "Get SambaNova API Key", "sambaNovaApiKey": "SambaNova API Key", + "getFireworksApiKey": "Get Fireworks API Key", + "fireworksApiKey": "Fireworks API Key", "getHuggingFaceApiKey": "Get Hugging Face API Key", "huggingFaceApiKey": "Hugging Face API Key", "huggingFaceModelId": "Model ID",