Skip to content

Commit 4a230ad

Browse files
frostbournesbmattappersonellipsis-dev[bot]dcbartlett
authored
Add Fireworks API Provider (RooCodeInc#3496)
* initial
* finishing touches
* Update webview-ui/src/utils/validate.ts
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* Update webview-ui/src/components/settings/ApiOptions.tsx
  Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
* requested changes
* fix url
* fix vars
* Update webview-ui/src/components/chat/ChatTextArea.tsx
  Co-authored-by: Dennis Bartlett <[email protected]>
* Update fireworks API link
* Improve margins

---------

Co-authored-by: Matt Apperson <[email protected]>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Co-authored-by: Dennis Bartlett <[email protected]>
1 parent 94c432f commit 4a230ad

File tree

10 files changed

+291
-0
lines changed

10 files changed

+291
-0
lines changed

.changeset/fuzzy-ducks-flow.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"claude-dev": minor
3+
---
4+
5+
Add Fireworks API Provider

src/api/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import { DoubaoHandler } from "./providers/doubao"
1919
import { VsCodeLmHandler } from "./providers/vscode-lm"
2020
import { ClineHandler } from "./providers/cline"
2121
import { LiteLlmHandler } from "./providers/litellm"
22+
import { FireworksHandler } from "./providers/fireworks"
2223
import { AskSageHandler } from "./providers/asksage"
2324
import { XAIHandler } from "./providers/xai"
2425
import { SambanovaHandler } from "./providers/sambanova"
@@ -58,6 +59,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
5859
return new DeepSeekHandler(options)
5960
case "requesty":
6061
return new RequestyHandler(options)
62+
case "fireworks":
63+
return new FireworksHandler(options)
6164
case "together":
6265
return new TogetherHandler(options)
6366
case "qwen":

src/api/providers/fireworks.ts

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import { Anthropic } from "@anthropic-ai/sdk"
2+
import OpenAI from "openai"
3+
import { withRetry } from "../retry"
4+
import { ApiHandler } from ".."
5+
import {
6+
ApiHandlerOptions,
7+
DeepSeekModelId,
8+
ModelInfo,
9+
deepSeekDefaultModelId,
10+
deepSeekModels,
11+
openAiModelInfoSaneDefaults,
12+
} from "../../shared/api"
13+
import { convertToOpenAiMessages } from "../transform/openai-format"
14+
import { ApiStream } from "../transform/stream"
15+
16+
export class FireworksHandler implements ApiHandler {
17+
private options: ApiHandlerOptions
18+
private client: OpenAI
19+
20+
constructor(options: ApiHandlerOptions) {
21+
this.options = options
22+
this.client = new OpenAI({
23+
baseURL: "https://api.fireworks.ai/inference/v1",
24+
apiKey: this.options.fireworksApiKey,
25+
})
26+
}
27+
28+
@withRetry()
29+
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
30+
const modelId = this.options.fireworksModelId ?? ""
31+
32+
const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
33+
{ role: "system", content: systemPrompt },
34+
...convertToOpenAiMessages(messages),
35+
]
36+
37+
const stream = await this.client.chat.completions.create({
38+
model: modelId,
39+
...(this.options.fireworksModelMaxCompletionTokens
40+
? { max_completion_tokens: this.options.fireworksModelMaxCompletionTokens }
41+
: {}),
42+
...(this.options.fireworksModelMaxTokens ? { max_tokens: this.options.fireworksModelMaxTokens } : {}),
43+
messages: openAiMessages,
44+
stream: true,
45+
stream_options: { include_usage: true },
46+
temperature: 0,
47+
})
48+
49+
let reasoning: string | null = null
50+
for await (const chunk of stream) {
51+
const delta = chunk.choices[0]?.delta
52+
if (reasoning || delta?.content?.includes("<think>")) {
53+
reasoning = (reasoning || "") + (delta.content ?? "")
54+
}
55+
56+
if (delta?.content && !reasoning) {
57+
yield {
58+
type: "text",
59+
text: delta.content,
60+
}
61+
}
62+
63+
if (reasoning || ("reasoning_content" in delta && delta.reasoning_content)) {
64+
yield {
65+
type: "reasoning",
66+
reasoning: delta.content || ((delta as any).reasoning_content as string | undefined) || "",
67+
}
68+
if (reasoning?.includes("</think>")) {
69+
// Reset so the next chunk is regular content
70+
reasoning = null
71+
}
72+
}
73+
74+
if (chunk.usage) {
75+
yield {
76+
type: "usage",
77+
inputTokens: chunk.usage.prompt_tokens || 0, // (deepseek reports total input AND cache reads/writes, see context caching: https://api-docs.deepseek.com/guides/kv_cache) where the input tokens is the sum of the cache hits/misses, while anthropic reports them as separate tokens. This is important to know for 1) context management truncation algorithm, and 2) cost calculation (NOTE: we report both input and cache stats but for now set input price to 0 since all the cost calculation will be done using cache hits/misses)
78+
outputTokens: chunk.usage.completion_tokens || 0,
79+
// @ts-ignore-next-line
80+
cacheReadTokens: chunk.usage.prompt_cache_hit_tokens || 0,
81+
// @ts-ignore-next-line
82+
cacheWriteTokens: chunk.usage.prompt_cache_miss_tokens || 0,
83+
}
84+
}
85+
}
86+
}
87+
88+
getModel(): { id: string; info: ModelInfo } {
89+
return {
90+
id: this.options.fireworksModelId ?? "",
91+
info: openAiModelInfoSaneDefaults,
92+
}
93+
}
94+
}

src/core/storage/state-keys.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ export type SecretKey =
1111
| "deepSeekApiKey"
1212
| "requestyApiKey"
1313
| "togetherApiKey"
14+
| "fireworksApiKey"
1415
| "qwenApiKey"
1516
| "doubaoApiKey"
1617
| "mistralApiKey"
@@ -69,6 +70,9 @@ export type GlobalStateKey =
6970
| "liteLlmModelId"
7071
| "liteLlmModelInfo"
7172
| "liteLlmUsePromptCache"
73+
| "fireworksModelId"
74+
| "fireworksModelMaxCompletionTokens"
75+
| "fireworksModelMaxTokens"
7276
| "qwenApiLine"
7377
| "requestyModelId"
7478
| "requestyModelInfo"

src/core/storage/state.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
109109
liteLlmModelId,
110110
liteLlmModelInfo,
111111
liteLlmUsePromptCache,
112+
fireworksApiKey,
113+
fireworksModelId,
114+
fireworksModelMaxCompletionTokens,
115+
fireworksModelMaxTokens,
112116
userInfo,
113117
previousModeApiProvider,
114118
previousModeModelId,
@@ -189,6 +193,10 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
189193
getGlobalState(context, "liteLlmModelId") as Promise<string | undefined>,
190194
getGlobalState(context, "liteLlmModelInfo") as Promise<ModelInfo | undefined>,
191195
getGlobalState(context, "liteLlmUsePromptCache") as Promise<boolean | undefined>,
196+
getSecret(context, "fireworksApiKey") as Promise<string | undefined>,
197+
getGlobalState(context, "fireworksModelId") as Promise<string | undefined>,
198+
getGlobalState(context, "fireworksModelMaxCompletionTokens") as Promise<number | undefined>,
199+
getGlobalState(context, "fireworksModelMaxTokens") as Promise<number | undefined>,
192200
getGlobalState(context, "userInfo") as Promise<UserInfo | undefined>,
193201
getGlobalState(context, "previousModeApiProvider") as Promise<ApiProvider | undefined>,
194202
getGlobalState(context, "previousModeModelId") as Promise<string | undefined>,
@@ -309,6 +317,10 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
309317
liteLlmModelInfo,
310318
liteLlmApiKey,
311319
liteLlmUsePromptCache,
320+
fireworksApiKey,
321+
fireworksModelId,
322+
fireworksModelMaxCompletionTokens,
323+
fireworksModelMaxTokens,
312324
asksageApiKey,
313325
asksageApiUrl,
314326
xaiApiKey,
@@ -485,6 +497,7 @@ export async function resetExtensionState(context: vscode.ExtensionContext) {
485497
"mistralApiKey",
486498
"clineApiKey",
487499
"liteLlmApiKey",
500+
"fireworksApiKey",
488501
"asksageApiKey",
489502
"xaiApiKey",
490503
"sambanovaApiKey",

src/shared/api.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ export type ApiProvider =
1919
| "vscode-lm"
2020
| "cline"
2121
| "litellm"
22+
| "fireworks"
2223
| "asksage"
2324
| "xai"
2425
| "sambanova"
@@ -70,6 +71,10 @@ export interface ApiHandlerOptions {
7071
requestyModelInfo?: ModelInfo
7172
togetherApiKey?: string
7273
togetherModelId?: string
74+
fireworksApiKey?: string
75+
fireworksModelId?: string
76+
fireworksModelMaxCompletionTokens?: number
77+
fireworksModelMaxTokens?: number
7378
qwenApiKey?: string
7479
doubaoApiKey?: string
7580
mistralApiKey?: string

webview-ui/src/components/chat/ChatTextArea.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,8 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
10361036
return `vscode-lm:${apiConfiguration.vsCodeLmModelSelector ? `${apiConfiguration.vsCodeLmModelSelector.vendor ?? ""}/${apiConfiguration.vsCodeLmModelSelector.family ?? ""}` : unknownModel}`
10371037
case "together":
10381038
return `${selectedProvider}:${apiConfiguration.togetherModelId}`
1039+
case "fireworks":
1040+
return `fireworks:${apiConfiguration.fireworksModelId}`
10391041
case "lmstudio":
10401042
return `${selectedProvider}:${apiConfiguration.lmStudioModelId}`
10411043
case "ollama":

webview-ui/src/components/settings/ApiOptions.tsx

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ const ApiOptions = ({
316316
<VSCodeOption value="openai-native">OpenAI</VSCodeOption>
317317
<VSCodeOption value="vscode-lm">VS Code LM API</VSCodeOption>
318318
<VSCodeOption value="requesty">Requesty</VSCodeOption>
319+
<VSCodeOption value="fireworks">Fireworks</VSCodeOption>
319320
<VSCodeOption value="together">Together</VSCodeOption>
320321
<VSCodeOption value="qwen">Alibaba Qwen</VSCodeOption>
321322
<VSCodeOption value="doubao">Bytedance Doubao</VSCodeOption>
@@ -1370,6 +1371,97 @@ const ApiOptions = ({
13701371
</div>
13711372
)}
13721373

1374+
{selectedProvider === "fireworks" && (
1375+
<div>
1376+
<VSCodeTextField
1377+
value={apiConfiguration?.fireworksApiKey || ""}
1378+
style={{ width: "100%" }}
1379+
type="password"
1380+
onInput={handleInputChange("fireworksApiKey")}
1381+
placeholder="Enter API Key...">
1382+
<span style={{ fontWeight: 500 }}>Fireworks API Key</span>
1383+
</VSCodeTextField>
1384+
<p
1385+
style={{
1386+
fontSize: "12px",
1387+
marginTop: 3,
1388+
color: "var(--vscode-descriptionForeground)",
1389+
}}>
1390+
This key is stored locally and only used to make API requests from this extension.
1391+
{!apiConfiguration?.fireworksApiKey && (
1392+
<VSCodeLink
1393+
href="https://fireworks.ai/settings/users/api-keys"
1394+
style={{
1395+
display: "inline",
1396+
fontSize: "inherit",
1397+
}}>
1398+
You can get a Fireworks API key by signing up here.
1399+
</VSCodeLink>
1400+
)}
1401+
</p>
1402+
<VSCodeTextField
1403+
value={apiConfiguration?.fireworksModelId || ""}
1404+
style={{ width: "100%" }}
1405+
onInput={handleInputChange("fireworksModelId")}
1406+
placeholder={"Enter Model ID..."}>
1407+
<span style={{ fontWeight: 500 }}>Model ID</span>
1408+
</VSCodeTextField>
1409+
<p
1410+
style={{
1411+
fontSize: "12px",
1412+
marginTop: 3,
1413+
color: "var(--vscode-descriptionForeground)",
1414+
}}>
1415+
<span style={{ color: "var(--vscode-errorForeground)" }}>
1416+
(<span style={{ fontWeight: 500 }}>Note:</span> Cline uses complex prompts and works best with Claude
1417+
models. Less capable models may not work as expected.)
1418+
</span>
1419+
</p>
1420+
<VSCodeTextField
1421+
value={apiConfiguration?.fireworksModelMaxCompletionTokens?.toString() || ""}
1422+
style={{ width: "100%", marginBottom: 8 }}
1423+
onInput={(e) => {
1424+
const value = (e.target as HTMLInputElement).value
1425+
if (!value) {
1426+
return
1427+
}
1428+
const num = parseInt(value, 10)
1429+
if (isNaN(num)) {
1430+
return
1431+
}
1432+
handleInputChange("fireworksModelMaxCompletionTokens")({
1433+
target: {
1434+
value: num,
1435+
},
1436+
})
1437+
}}
1438+
placeholder={"2000"}>
1439+
<span style={{ fontWeight: 500 }}>Max Completion Tokens</span>
1440+
</VSCodeTextField>
1441+
<VSCodeTextField
1442+
value={apiConfiguration?.fireworksModelMaxTokens?.toString() || ""}
1443+
style={{ width: "100%", marginBottom: 8 }}
1444+
onInput={(e) => {
1445+
const value = (e.target as HTMLInputElement).value
1446+
if (!value) {
1447+
return
1448+
}
1449+
const num = parseInt(value)
1450+
if (isNaN(num)) {
1451+
return
1452+
}
1453+
handleInputChange("fireworksModelMaxTokens")({
1454+
target: {
1455+
value: num,
1456+
},
1457+
})
1458+
}}
1459+
placeholder={"4000"}>
1460+
<span style={{ fontWeight: 500 }}>Max Context Tokens</span>
1461+
</VSCodeTextField>
1462+
</div>
1463+
)}
1464+
13731465
{selectedProvider === "together" && (
13741466
<div>
13751467
<VSCodeTextField

webview-ui/src/components/settings/__tests__/APIOptions.spec.tsx

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,74 @@ vi.mock("../../../context/ExtensionStateContext", async (importOriginal) => {
101101
const actual = await importOriginal()
102102
return {
103103
...(actual || {}),
104+
useExtensionState: vi.fn(() => ({
105+
apiConfiguration: {
106+
apiProvider: "fireworks",
107+
fireworksApiKey: "",
108+
fireworksModelId: "",
109+
fireworksModelMaxCompletionTokens: 2000,
110+
fireworksModelMaxTokens: 4000,
111+
},
112+
setApiConfiguration: vi.fn(),
113+
uriScheme: "vscode",
114+
})),
115+
}
116+
})
117+
118+
// Smoke tests: with the Fireworks provider selected, each Fireworks settings input is rendered.
// NOTE(review): this file registers vi.mock for "../../../context/ExtensionStateContext" more than
// once; confirm which factory is actually in effect when these tests run.
describe("ApiOptions Component", () => {
	const mockPostMessage = vi.fn()

	// Renders ApiOptions inside the extension state provider, as every test below needs.
	const renderApiOptions = () =>
		render(
			<ExtensionStateContextProvider>
				<ApiOptions showModelOptions={true} />
			</ExtensionStateContextProvider>,
		)

	beforeEach(() => {
		// Clear recorded mock calls before every test so call history cannot leak between tests.
		vi.clearAllMocks()
		global.vscode = { postMessage: mockPostMessage } as any
	})

	it("renders Fireworks API Key input", () => {
		renderApiOptions()
		const apiKeyInput = screen.getByPlaceholderText("Enter API Key...")
		expect(apiKeyInput).toBeInTheDocument()
	})

	it("renders Fireworks Model ID input", () => {
		renderApiOptions()
		const modelIdInput = screen.getByPlaceholderText("Enter Model ID...")
		expect(modelIdInput).toBeInTheDocument()
	})

	it("renders Fireworks Max Completion Tokens input", () => {
		renderApiOptions()
		const maxCompletionTokensInput = screen.getByPlaceholderText("2000")
		expect(maxCompletionTokensInput).toBeInTheDocument()
	})

	it("renders Fireworks Max Tokens input", () => {
		renderApiOptions()
		const maxTokensInput = screen.getByPlaceholderText("4000")
		expect(maxTokensInput).toBeInTheDocument()
	})
})
166+
167+
vi.mock("../../../context/ExtensionStateContext", async (importOriginal) => {
168+
const actual = await importOriginal()
169+
return {
170+
...actual,
171+
// your mocked methods
104172
useExtensionState: vi.fn(() => ({
105173
apiConfiguration: {
106174
apiProvider: "openai",

webview-ui/src/utils/validate.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
7373
return "You must provide a valid API key or choose a different provider."
7474
}
7575
break
76+
case "fireworks":
77+
if (!apiConfiguration.fireworksApiKey || !apiConfiguration.fireworksModelId) {
78+
return "You must provide a valid API key or choose a different provider."
79+
}
80+
break
7681
case "together":
7782
if (!apiConfiguration.togetherApiKey || !apiConfiguration.togetherModelId) {
7883
return "You must provide a valid API key or choose a different provider."

0 commit comments

Comments
 (0)