Feat: Adding Gemini tools - URL Context and Grounding with Google Search #5959
Changes from 57 commits
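For reference, the two tools this PR adds map straight onto the `@google/genai` request config that the handler builds in the diff below. A minimal standalone sketch of the same call outside the handler — the model id, API-key handling, and prompt are placeholders, not taken from this PR:

import { GoogleGenAI } from "@google/genai"

async function main() {
	// Placeholder client setup; the handler in this PR constructs its own client.
	const ai = new GoogleGenAI({ apiKey: process.env.GEMINI_API_KEY })

	const response = await ai.models.generateContent({
		model: "gemini-2.5-flash", // placeholder model id
		contents: [{ role: "user", parts: [{ text: "Summarize https://example.com and cite sources." }] }],
		config: {
			// Same tool objects the handler pushes when enableUrlContext / enableGrounding are set.
			tools: [{ urlContext: {} }, { googleSearch: {} }],
		},
	})

	console.log(response.text)
}

main()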
New test file:

@@ -0,0 +1,36 @@
import { describe, it, expect, vi } from "vitest"
import { GeminiHandler } from "../gemini"
import type { ApiHandlerOptions } from "../../../shared/api"

describe("GeminiHandler backend support", () => {
	it("passes tools for URL context and grounding in config", async () => {
		const options = {
			apiProvider: "gemini",
			enableUrlContext: true,
			enableGrounding: true,
		} as ApiHandlerOptions
		const handler = new GeminiHandler(options)
		const stub = vi.fn().mockReturnValue((async function* () {})())
		// @ts-ignore access private client
		handler["client"].models.generateContentStream = stub
		await handler.createMessage("instr", [] as any).next()
		const config = stub.mock.calls[0][0].config
		expect(config.tools).toEqual([{ urlContext: {} }, { googleSearch: {} }])
	})

	it("completePrompt passes config overrides without tools when URL context and grounding disabled", async () => {
		const options = {
			apiProvider: "gemini",
			enableUrlContext: false,
			enableGrounding: false,
		} as ApiHandlerOptions
		const handler = new GeminiHandler(options)
		const stub = vi.fn().mockResolvedValue({ text: "ok" })
		// @ts-ignore access private client
		handler["client"].models.generateContent = stub
		const res = await handler.completePrompt("hi")
		expect(res).toBe("ok")
		const promptConfig = stub.mock.calls[0][0].config
		expect(promptConfig.tools).toBeUndefined()
	})
})
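The empty async generator used to stub generateContentStream works because createMessage only iterates the returned value with for await, so any async iterable is a valid stand-in. A quick sketch of the same idea with fake chunks actually yielded — the chunk shape here is trimmed to the fields the handler reads and is an assumption, not the full SDK type:

import { it, vi } from "vitest"

it("consumes a fake streaming response", async () => {
	// Any async iterable can stand in for the SDK's streaming response.
	async function* fakeStream() {
		yield { candidates: [{ content: { parts: [{ text: "hello" }] } }] }
		yield { usageMetadata: { promptTokenCount: 3, candidatesTokenCount: 1 } }
	}

	const streamStub = vi.fn().mockReturnValue(fakeStream())

	// Consuming the stub the same way createMessage does internally.
	for await (const chunk of streamStub()) {
		console.log(chunk)
	}
})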
Gemini provider (gemini.ts), updated code:

@@ -4,6 +4,7 @@ import {
	type GenerateContentResponseUsageMetadata,
	type GenerateContentParameters,
	type GenerateContentConfig,
	type GroundingMetadata,
} from "@google/genai"
import type { JWTInput } from "google-auth-library"

@@ -13,6 +14,7 @@ import type { ApiHandlerOptions } from "../../shared/api"
import { safeJsonParse } from "../../shared/safeJsonParse"

import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
import { t } from "i18next"
import type { ApiStream } from "../transform/stream"
import { getModelParams } from "../transform/model-params"

@@ -67,72 +69,103 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
		const contents = messages.map(convertAnthropicMessageToGemini)

		const tools: GenerateContentConfig["tools"] = []
		if (this.options.enableUrlContext) {
			tools.push({ urlContext: {} })
		}

		if (this.options.enableGrounding) {
			tools.push({ googleSearch: {} })
		}

		const config: GenerateContentConfig = {
			systemInstruction,
			httpOptions: this.options.googleGeminiBaseUrl ? { baseUrl: this.options.googleGeminiBaseUrl } : undefined,
			thinkingConfig,
			maxOutputTokens: this.options.modelMaxTokens ?? maxTokens ?? undefined,
			temperature: this.options.modelTemperature ?? 0,
			...(tools.length > 0 ? { tools } : {}),
		}

		const params: GenerateContentParameters = { model, contents, config }

		try {
			const result = await this.client.models.generateContentStream(params)

			let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined
			let pendingGroundingMetadata: GroundingMetadata | undefined

			for await (const chunk of result) {
				// Process candidates and their parts to separate thoughts from content
				if (chunk.candidates && chunk.candidates.length > 0) {
					const candidate = chunk.candidates[0]

					if (candidate.groundingMetadata) {
						pendingGroundingMetadata = candidate.groundingMetadata
					}

					if (candidate.content && candidate.content.parts) {
						for (const part of candidate.content.parts) {
							if (part.thought) {
								// This is a thinking/reasoning part
								if (part.text) {
									yield { type: "reasoning", text: part.text }
								}
							} else {
								// This is regular content
								if (part.text) {
									yield { type: "text", text: part.text }
								}
							}
						}
					}
				}
				// Fallback to the original text property if no candidates structure
				else if (chunk.text) {
					yield { type: "text", text: chunk.text }
				}

				if (chunk.usageMetadata) {
					lastUsageMetadata = chunk.usageMetadata
				}
			}

			if (pendingGroundingMetadata) {
				const citations = this.extractCitationsOnly(pendingGroundingMetadata)
				if (citations) {
					yield { type: "text", text: `\n\nSources: ${citations}` }
				}
			}

			if (lastUsageMetadata) {
				const inputTokens = lastUsageMetadata.promptTokenCount ?? 0
				const outputTokens = lastUsageMetadata.candidatesTokenCount ?? 0
				const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount
				const reasoningTokens = lastUsageMetadata.thoughtsTokenCount

				yield {
					type: "usage",
					inputTokens,
					outputTokens,
					cacheReadTokens,
					reasoningTokens,
					totalCost: this.calculateCost({ info, inputTokens, outputTokens, cacheReadTokens }),
				}
			}
		} catch (error) {
			if (error instanceof Error) {
				throw new Error(t("common:errors.gemini.generate_stream", { error: error.message }))
			}

			throw error
		}
	}

	override getModel() {
		const modelId = this.options.apiModelId
		let id = modelId && modelId in geminiModels ? (modelId as GeminiModelId) : geminiDefaultModelId
		let info: ModelInfo = geminiModels[id]
		const params = getModelParams({ format: "gemini", modelId: id, model: info, settings: this.options })

		// The `:thinking` suffix indicates that the model is a "Hybrid"

@@ -142,25 +175,69 @@
		return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params }
	}

	private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null {
		const chunks = groundingMetadata?.groundingChunks

		if (!chunks) {
			return null
		}

		const citationLinks = chunks
			.map((chunk, i) => {
				const uri = chunk.web?.uri
				if (uri) {
					return `[${i + 1}](${uri})`
				}
				return null
			})
			.filter((link): link is string => link !== null)

		if (citationLinks.length > 0) {
			return citationLinks.join(", ")
		}

		return null
	}

	async completePrompt(prompt: string): Promise<string> {
		try {
			const { id: model } = this.getModel()

			const tools: GenerateContentConfig["tools"] = []
			if (this.options.enableUrlContext) {
				tools.push({ urlContext: {} })
			}
			if (this.options.enableGrounding) {
				tools.push({ googleSearch: {} })
			}
			const promptConfig: GenerateContentConfig = {
				httpOptions: this.options.googleGeminiBaseUrl
					? { baseUrl: this.options.googleGeminiBaseUrl }
					: undefined,
				temperature: this.options.modelTemperature ?? 0,
				...(tools.length > 0 ? { tools } : {}),
			}

			const result = await this.client.models.generateContent({
				model,
				contents: [{ role: "user", parts: [{ text: prompt }] }],
				config: promptConfig,
			})

			let text = result.text ?? ""

			const candidate = result.candidates?.[0]
			if (candidate?.groundingMetadata) {
				const citations = this.extractCitationsOnly(candidate.groundingMetadata)
				if (citations) {
					text += `\n\nSources: ${citations}`
				}
			}

			return text
		} catch (error) {
			if (error instanceof Error) {
				throw new Error(t("common:errors.gemini.generate_complete_prompt", { error: error.message }))
			}

			throw error
		}
	}
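To make the citation formatting concrete, here is a small example of the grounding metadata shape returned when the googleSearch tool is used and the string the new extractCitationsOnly helper would produce from it — the URIs and titles are placeholders, and the formatting is re-implemented inline for illustration only:

import type { GroundingMetadata } from "@google/genai"

// Placeholder grounding metadata, shaped like what the API returns for grounded answers.
const metadata: GroundingMetadata = {
	groundingChunks: [
		{ web: { uri: "https://example.com/a", title: "Example A" } },
		{ web: { uri: "https://example.com/b", title: "Example B" } },
	],
}

// Same numbered-markdown-link formatting as extractCitationsOnly.
const citations = (metadata.groundingChunks ?? [])
	.map((chunk, i) => (chunk.web?.uri ? `[${i + 1}](${chunk.web.uri})` : null))
	.filter((link): link is string => link !== null)
	.join(", ")

console.log(`Sources: ${citations}`)
// Sources: [1](https://example.com/a), [2](https://example.com/b)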
Some generated files are not rendered by default.