Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/types/src/providers/roo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export const rooModels = {
maxTokens: 16_384,
contextWindow: 262_144,
supportsImages: false,
supportsPromptCache: true,
supportsPromptCache: false, // Disabled to prevent context mixing between sessions
inputPrice: 0,
outputPrice: 0,
description:
Expand Down
144 changes: 143 additions & 1 deletion src/api/providers/__tests__/roo.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ describe("RooHandler", () => {
expect.objectContaining({ role: "user", content: "Second message" }),
]),
}),
expect.any(Object), // Headers object
)
})
})
Expand Down Expand Up @@ -331,7 +332,7 @@ describe("RooHandler", () => {
expect(modelInfo.info.maxTokens).toBe(16_384)
expect(modelInfo.info.contextWindow).toBe(262_144)
expect(modelInfo.info.supportsImages).toBe(false)
expect(modelInfo.info.supportsPromptCache).toBe(true)
expect(modelInfo.info.supportsPromptCache).toBe(false) // Should be false now to prevent context mixing
expect(modelInfo.info.inputPrice).toBe(0)
expect(modelInfo.info.outputPrice).toBe(0)
})
Expand Down Expand Up @@ -361,6 +362,7 @@ describe("RooHandler", () => {
expect.not.objectContaining({
temperature: expect.anything(),
}),
expect.any(Object), // Headers object
)
})

Expand All @@ -378,6 +380,7 @@ describe("RooHandler", () => {
expect.objectContaining({
temperature: 0.9,
}),
expect.any(Object), // Headers object
)
})

Expand Down Expand Up @@ -433,4 +436,143 @@ describe("RooHandler", () => {
}).toThrow("Authentication required for Roo Code Cloud")
})
})

describe("session isolation", () => {
beforeEach(() => {
mockHasInstanceFn.mockReturnValue(true)
mockGetSessionTokenFn.mockReturnValue("test-session-token")
mockCreate.mockClear()
})

it("should include session isolation headers in requests", async () => {
handler = new RooHandler(mockOptions)
const stream = handler.createMessage(systemPrompt, messages)

// Consume the stream
for await (const _chunk of stream) {
// Just consume
}

// Verify that create was called with session isolation headers
expect(mockCreate).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({
headers: expect.objectContaining({
"X-Session-Id": expect.any(String),
"X-Request-Id": expect.any(String),
"X-No-Cache": "true",
"Cache-Control": "no-store, no-cache, must-revalidate",
Pragma: "no-cache",
}),
}),
)
})

it("should generate unique session IDs for different handler instances", async () => {
const handler1 = new RooHandler(mockOptions)
const handler2 = new RooHandler(mockOptions)

// Create messages with both handlers
const stream1 = handler1.createMessage(systemPrompt, messages)
for await (const _chunk of stream1) {
// Consume
}

const stream2 = handler2.createMessage(systemPrompt, messages)
for await (const _chunk of stream2) {
// Consume
}

// Get the session IDs from the calls
const call1Headers = mockCreate.mock.calls[0][1].headers
const call2Headers = mockCreate.mock.calls[1][1].headers

// Session IDs should be different for different handler instances
expect(call1Headers["X-Session-Id"]).toBeDefined()
expect(call2Headers["X-Session-Id"]).toBeDefined()
expect(call1Headers["X-Session-Id"]).not.toBe(call2Headers["X-Session-Id"])
})

it("should generate unique request IDs for each request", async () => {
handler = new RooHandler(mockOptions)

// Make two requests with the same handler
const stream1 = handler.createMessage(systemPrompt, messages)
for await (const _chunk of stream1) {
// Consume
}

const stream2 = handler.createMessage(systemPrompt, messages)
for await (const _chunk of stream2) {
// Consume
}

// Get the request IDs from the calls
const call1Headers = mockCreate.mock.calls[0][1].headers
const call2Headers = mockCreate.mock.calls[1][1].headers

// Request IDs should be different for each request
expect(call1Headers["X-Request-Id"]).toBeDefined()
expect(call2Headers["X-Request-Id"]).toBeDefined()
expect(call1Headers["X-Request-Id"]).not.toBe(call2Headers["X-Request-Id"])

// But session IDs should be the same for the same handler
expect(call1Headers["X-Session-Id"]).toBe(call2Headers["X-Session-Id"])
})

it("should include metadata in request params", async () => {
handler = new RooHandler(mockOptions)
const stream = handler.createMessage(systemPrompt, messages)

for await (const _chunk of stream) {
// Consume
}

// Verify metadata is included in the request
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
metadata: expect.objectContaining({
session_id: expect.any(String),
request_id: expect.any(String),
timestamp: expect.any(String),
}),
}),
expect.any(Object),
)
})

it("should have prompt caching disabled for roo/sonic model", () => {
handler = new RooHandler(mockOptions)
const modelInfo = handler.getModel()

// Verify that prompt caching is disabled
expect(modelInfo.info.supportsPromptCache).toBe(false)
})

it("should maintain session ID consistency across multiple requests", async () => {
handler = new RooHandler(mockOptions)

// Make multiple requests
const requests = []
for (let i = 0; i < 3; i++) {
const stream = handler.createMessage(systemPrompt, messages)
for await (const _chunk of stream) {
// Consume
}
requests.push(i)
}

// All requests should have the same session ID
const sessionIds = mockCreate.mock.calls.map((call) => call[1].headers["X-Session-Id"])
const firstSessionId = sessionIds[0]

expect(sessionIds.every((id) => id === firstSessionId)).toBe(true)

// But all request IDs should be unique
const requestIds = mockCreate.mock.calls.map((call) => call[1].headers["X-Request-Id"])
const uniqueRequestIds = new Set(requestIds)

expect(uniqueRequestIds.size).toBe(requestIds.length)
})
})
})
58 changes: 57 additions & 1 deletion src/api/providers/roo.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { rooDefaultModelId, rooModels, type RooModelId } from "@roo-code/types"
import { CloudService } from "@roo-code/cloud"
import { randomUUID } from "crypto"
import OpenAI from "openai"

import type { ApiHandlerOptions } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { t } from "../../i18n"
import { convertToOpenAiMessages } from "../transform/openai-format"

import type { ApiHandlerCreateMessageMetadata } from "../index"
import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

export class RooHandler extends BaseOpenAiCompatibleProvider<RooModelId> {
private sessionId: string

constructor(options: ApiHandlerOptions) {
// Check if CloudService is available and get the session token.
if (!CloudService.hasInstance()) {
Expand All @@ -22,6 +27,9 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<RooModelId> {
throw new Error(t("common:errors.roo.authenticationRequired"))
}

// Generate a unique session ID for this handler instance to ensure request isolation
const sessionId = randomUUID()

super({
...options,
providerName: "Roo Code Cloud",
Expand All @@ -31,6 +39,53 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<RooModelId> {
providerModels: rooModels,
defaultTemperature: 0.7,
})

this.sessionId = sessionId
}

protected override createStream(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
) {
const {
id: model,
info: { maxTokens: max_tokens },
} = this.getModel()

// Generate unique request ID for this specific request
const requestId = randomUUID()

// Create the request with session isolation metadata
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model,
max_tokens,
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
// Add session isolation metadata to prevent context mixing
metadata: {
session_id: this.sessionId,
request_id: requestId,
timestamp: new Date().toISOString(),
} as any,
}

// Only include temperature if explicitly set
if (this.options.modelTemperature !== undefined) {
params.temperature = this.options.modelTemperature
}

// Create the stream with additional headers for session isolation
return this.client.chat.completions.create(params, {
headers: {
"X-Session-Id": this.sessionId,
"X-Request-Id": requestId,
"X-No-Cache": "true", // Prevent any server-side caching
"Cache-Control": "no-store, no-cache, must-revalidate",
Pragma: "no-cache",
},
})
}

override async *createMessage(
Expand Down Expand Up @@ -78,13 +133,14 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<RooModelId> {
}

// Return the requested model ID even if not found, with fallback info.
// Note: supportsPromptCache is now false to prevent context mixing
return {
id: modelId as RooModelId,
info: {
maxTokens: 16_384,
contextWindow: 262_144,
supportsImages: false,
supportsPromptCache: true,
supportsPromptCache: false, // Disabled to prevent context mixing
inputPrice: 0,
outputPrice: 0,
},
Expand Down
Loading