Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions packages/types/src/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,19 @@ export const tokenUsageSchema = z.object({
})

export type TokenUsage = z.infer<typeof tokenUsageSchema>

/**
* QueuedMessage
*/

/**
* Represents a message that is queued to be sent when sending is enabled
*/
export interface QueuedMessage {
/** Unique identifier for the queued message */
id: string
/** The text content of the message */
text: string
/** Array of image data URLs attached to the message */
images: string[]
}
2 changes: 2 additions & 0 deletions packages/types/src/provider-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ const lmStudioSchema = baseProviderSettingsSchema.extend({
const geminiSchema = apiModelIdProviderModelSchema.extend({
geminiApiKey: z.string().optional(),
googleGeminiBaseUrl: z.string().optional(),
enableUrlContext: z.boolean().optional(),
enableGrounding: z.boolean().optional(),
})

const geminiCliSchema = apiModelIdProviderModelSchema.extend({
Expand Down
6 changes: 5 additions & 1 deletion src/api/providers/__tests__/bedrock-reasoning.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,11 @@ describe("AwsBedrockHandler - Extended Thinking", () => {
expect(BedrockRuntimeClient).toHaveBeenCalledWith(
expect.objectContaining({
region: "us-east-1",
token: { token: "test-api-key-token" },
credentials: {
accessKeyId: "bedrock-user",
secretAccessKey: "bedrock-pwd",
sessionToken: "test-api-key-token",
},
authSchemePreference: ["httpBearerAuth"],
}),
)
Expand Down
137 changes: 137 additions & 0 deletions src/api/providers/__tests__/gemini-handler.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import { describe, it, expect, vi } from "vitest"
import { t } from "i18next"
import { GeminiHandler } from "../gemini"
import type { ApiHandlerOptions } from "../../../shared/api"

describe("GeminiHandler backend support", () => {
it("passes tools for URL context and grounding in config", async () => {
const options = {
apiProvider: "gemini",
enableUrlContext: true,
enableGrounding: true,
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockReturnValue((async function* () {})())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub
await handler.createMessage("instr", [] as any).next()
const config = stub.mock.calls[0][0].config
expect(config.tools).toEqual([{ urlContext: {} }, { googleSearch: {} }])
})

it("completePrompt passes config overrides without tools when URL context and grounding disabled", async () => {
const options = {
apiProvider: "gemini",
enableUrlContext: false,
enableGrounding: false,
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockResolvedValue({ text: "ok" })
// @ts-ignore access private client
handler["client"].models.generateContent = stub
const res = await handler.completePrompt("hi")
expect(res).toBe("ok")
const promptConfig = stub.mock.calls[0][0].config
expect(promptConfig.tools).toBeUndefined()
})

describe("error scenarios", () => {
it("should handle grounding metadata extraction failure gracefully", async () => {
const options = {
apiProvider: "gemini",
enableGrounding: true,
} as ApiHandlerOptions
const handler = new GeminiHandler(options)

const mockStream = async function* () {
yield {
candidates: [
{
groundingMetadata: {
// Invalid structure - missing groundingChunks
},
content: { parts: [{ text: "test response" }] },
},
],
usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 },
}
}

const stub = vi.fn().mockReturnValue(mockStream())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

const messages = []
for await (const chunk of handler.createMessage("test", [] as any)) {
messages.push(chunk)
}

// Should still return the main content without sources
expect(messages.some((m) => m.type === "text" && m.text === "test response")).toBe(true)
expect(messages.some((m) => m.type === "text" && m.text?.includes("Sources:"))).toBe(false)
})

it("should handle malformed grounding metadata", async () => {
const options = {
apiProvider: "gemini",
enableGrounding: true,
} as ApiHandlerOptions
const handler = new GeminiHandler(options)

const mockStream = async function* () {
yield {
candidates: [
{
groundingMetadata: {
groundingChunks: [
{ web: null }, // Missing URI
{ web: { uri: "https://example.com" } }, // Valid
{}, // Missing web property entirely
],
},
content: { parts: [{ text: "test response" }] },
},
],
usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 },
}
}

const stub = vi.fn().mockReturnValue(mockStream())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

const messages = []
for await (const chunk of handler.createMessage("test", [] as any)) {
messages.push(chunk)
}

// Should only include valid citations
const sourceMessage = messages.find((m) => m.type === "text" && m.text?.includes("[2]"))
expect(sourceMessage).toBeDefined()
if (sourceMessage && "text" in sourceMessage) {
expect(sourceMessage.text).toContain("https://example.com")
expect(sourceMessage.text).not.toContain("[1]")
expect(sourceMessage.text).not.toContain("[3]")
}
})

it("should handle API errors when tools are enabled", async () => {
const options = {
apiProvider: "gemini",
enableUrlContext: true,
enableGrounding: true,
} as ApiHandlerOptions
const handler = new GeminiHandler(options)

const mockError = new Error("API rate limit exceeded")
const stub = vi.fn().mockRejectedValue(mockError)
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

await expect(async () => {
const generator = handler.createMessage("test", [] as any)
await generator.next()
}).rejects.toThrow(t("common:errors.gemini.generate_stream", { error: "API rate limit exceeded" }))
})
})
})
3 changes: 2 additions & 1 deletion src/api/providers/__tests__/gemini.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { Anthropic } from "@anthropic-ai/sdk"

import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types"

import { t } from "i18next"
import { GeminiHandler } from "../gemini"

const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"
Expand Down Expand Up @@ -129,7 +130,7 @@ describe("GeminiHandler", () => {
;(handler["client"].models.generateContent as any).mockRejectedValue(mockError)

await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"Gemini completion error: Gemini API error",
t("common:errors.gemini.generate_complete_prompt", { error: "Gemini API error" }),
)
})

Expand Down
3 changes: 2 additions & 1 deletion src/api/providers/__tests__/vertex.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { Anthropic } from "@anthropic-ai/sdk"

import { ApiStreamChunk } from "../../transform/stream"

import { t } from "i18next"
import { VertexHandler } from "../vertex"

describe("VertexHandler", () => {
Expand Down Expand Up @@ -105,7 +106,7 @@ describe("VertexHandler", () => {
;(handler["client"].models.generateContent as any).mockRejectedValue(mockError)

await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
"Gemini completion error: Vertex API error",
t("common:errors.gemini.generate_complete_prompt", { error: "Vertex API error" }),
)
})

Expand Down
6 changes: 5 additions & 1 deletion src/api/providers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,11 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH

if (this.options.awsUseApiKey && this.options.awsApiKey) {
// Use API key/token-based authentication if enabled and API key is set
clientConfig.token = { token: this.options.awsApiKey }
clientConfig.credentials = {
accessKeyId: "bedrock-user",
secretAccessKey: "bedrock-pwd",
sessionToken: this.options.awsApiKey,
}
clientConfig.authSchemePreference = ["httpBearerAuth"] // Otherwise there's no end of credential problems.
} else if (this.options.awsUseProfile && this.options.awsProfile) {
// Use profile-based credentials if enabled and profile is set
Expand Down
Loading