diff --git a/.changeset/swift-carrots-doubt.md b/.changeset/swift-carrots-doubt.md new file mode 100644 index 0000000000..567ed0e109 --- /dev/null +++ b/.changeset/swift-carrots-doubt.md @@ -0,0 +1,5 @@ +--- +"roo-cline": patch +--- + +Roo Code Cloud diff --git a/.env.sample b/.env.sample index 4d6c24ac72..d89ef72792 100644 --- a/.env.sample +++ b/.env.sample @@ -1 +1,5 @@ POSTHOG_API_KEY=key-goes-here + +# Roo Code Cloud / Local Development +CLERK_BASE_URL=https://epic-chamois-85.clerk.accounts.dev +ROO_CODE_API_URL=http://localhost:3000 diff --git a/.github/workflows/code-qa.yml b/.github/workflows/code-qa.yml index 271ecc1f28..1ca5d8151a 100644 --- a/.github/workflows/code-qa.yml +++ b/.github/workflows/code-qa.yml @@ -133,10 +133,3 @@ jobs: - name: Run integration tests working-directory: apps/vscode-e2e run: xvfb-run -a pnpm test:ci - - qa: - needs: [check-translations, knip, compile, platform-unit-test, integration-test] - runs-on: ubuntu-latest - steps: - - name: NO-OP - run: echo "All tests passed." diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 71e7fb27e4..0784c8cbad 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -1,4 +1,4 @@ -name: "CodeQL Advanced" +name: CodeQL Advanced on: push: diff --git a/.github/workflows/marketplace-publish.yml b/.github/workflows/marketplace-publish.yml index 6e2dc09a01..bea1ecfa6d 100644 --- a/.github/workflows/marketplace-publish.yml +++ b/.github/workflows/marketplace-publish.yml @@ -1,4 +1,5 @@ name: Publish Extension + on: pull_request: types: [closed] diff --git a/.github/workflows/nightly-publish.yml b/.github/workflows/nightly-publish.yml index 5c28052426..14bb0212b1 100644 --- a/.github/workflows/nightly-publish.yml +++ b/.github/workflows/nightly-publish.yml @@ -1,8 +1,6 @@ name: Nightly Publish on: - # push: - # branches: [main] workflow_run: workflows: ["Code QA Roo Code"] types: diff --git a/packages/cloud/eslint.config.mjs b/packages/cloud/eslint.config.mjs new file mode 100644 index 0000000000..694bf73664 --- /dev/null +++ b/packages/cloud/eslint.config.mjs @@ -0,0 +1,4 @@ +import { config } from "@roo-code/config-eslint/base" + +/** @type {import("eslint").Linter.Config} */ +export default [...config] diff --git a/packages/cloud/package.json b/packages/cloud/package.json new file mode 100644 index 0000000000..62eff7b694 --- /dev/null +++ b/packages/cloud/package.json @@ -0,0 +1,25 @@ +{ + "name": "@roo-code/cloud", + "description": "Roo Code Cloud VSCode integration.", + "private": true, + "type": "module", + "exports": "./src/index.ts", + "scripts": { + "lint": "eslint src --ext=ts --max-warnings=0", + "check-types": "tsc --noEmit", + "test": "vitest --globals --run", + "clean": "rimraf dist .turbo" + }, + "dependencies": { + "@roo-code/telemetry": "workspace:^", + "@roo-code/types": "workspace:^", + "axios": "^1.7.4" + }, + "devDependencies": { + "@roo-code/config-eslint": "workspace:^", + "@roo-code/config-typescript": "workspace:^", + "@types/node": "^22.15.20", + "@types/vscode": "^1.84.0", + "vitest": "^3.1.3" + } +} diff --git a/packages/cloud/src/AuthService.ts b/packages/cloud/src/AuthService.ts new file mode 100644 index 0000000000..85954a68aa --- /dev/null +++ b/packages/cloud/src/AuthService.ts @@ -0,0 +1,395 @@ +import crypto from "crypto" +import EventEmitter from "events" + +import axios from "axios" +import * as vscode from "vscode" + +import type { CloudUserInfo } from "@roo-code/types" + +import { CloudServiceCallbacks } from "./types" +import { getClerkBaseUrl, getRooCodeApiUrl } from "./Config" +import { RefreshTimer } from "./RefreshTimer" + +export interface AuthServiceEvents { + "active-session": [data: { previousState: AuthState }] + "logged-out": [data: { previousState: AuthState }] +} + +const CLIENT_TOKEN_KEY = "clerk-client-token" +const SESSION_ID_KEY = "clerk-session-id" +const AUTH_STATE_KEY = "clerk-auth-state" + +type AuthState = "initializing" | "logged-out" | "active-session" | "inactive-session" + +export class AuthService extends EventEmitter { + private context: vscode.ExtensionContext + private userChanged: CloudServiceCallbacks["userChanged"] + private timer: RefreshTimer + private state: AuthState = "initializing" + + private clientToken: string | null = null + private sessionToken: string | null = null + private sessionId: string | null = null + + constructor(context: vscode.ExtensionContext, userChanged: CloudServiceCallbacks["userChanged"]) { + super() + + this.context = context + this.userChanged = userChanged + + this.timer = new RefreshTimer({ + callback: async () => { + await this.refreshSession() + return true + }, + successInterval: 50_000, + initialBackoffMs: 1_000, + maxBackoffMs: 300_000, + }) + } + + /** + * Initialize the auth state + * + * This method loads tokens from storage and determines the current auth state. + * It also starts the refresh timer if we have an active session. + */ + public async initialize(): Promise { + if (this.state !== "initializing") { + console.log("[auth] initialize() called after already initialized") + return + } + + try { + this.clientToken = (await this.context.secrets.get(CLIENT_TOKEN_KEY)) || null + this.sessionId = this.context.globalState.get(SESSION_ID_KEY) || null + + // Determine initial state. + if (!this.clientToken || !this.sessionId) { + // TODO: it may be possible to get a new session with the client, + // but the obvious Clerk endpoints don't support that. + const previousState = this.state + this.state = "logged-out" + this.emit("logged-out", { previousState }) + } else { + this.state = "inactive-session" + this.timer.start() + } + + console.log(`[auth] Initialized with state: ${this.state}`) + } catch (error) { + console.error(`[auth] Error initializing AuthService: ${error}`) + this.state = "logged-out" + } + } + + /** + * Start the login process + * + * This method initiates the authentication flow by generating a state parameter + * and opening the browser to the authorization URL. + */ + public async login(): Promise { + try { + // Generate a cryptographically random state parameter. + const state = crypto.randomBytes(16).toString("hex") + await this.context.globalState.update(AUTH_STATE_KEY, state) + const uri = vscode.Uri.parse(`${getRooCodeApiUrl()}/extension/sign-in?state=${state}`) + await vscode.env.openExternal(uri) + } catch (error) { + console.error(`[auth] Error initiating Roo Code Cloud auth: ${error}`) + throw new Error(`Failed to initiate Roo Code Cloud authentication: ${error}`) + } + } + + /** + * Handle the callback from Roo Code Cloud + * + * This method is called when the user is redirected back to the extension + * after authenticating with Roo Code Cloud. + * + * @param code The authorization code from the callback + * @param state The state parameter from the callback + */ + public async handleCallback(code: string | null, state: string | null): Promise { + if (!code || !state) { + vscode.window.showInformationMessage("Invalid Roo Code Cloud sign in url") + return + } + + try { + // Validate state parameter to prevent CSRF attacks. + const storedState = this.context.globalState.get(AUTH_STATE_KEY) + + if (state !== storedState) { + console.log("[auth] State mismatch in callback") + throw new Error("Invalid state parameter. Authentication request may have been tampered with.") + } + + const { clientToken, sessionToken, sessionId } = await this.clerkSignIn(code) + + await this.context.secrets.store(CLIENT_TOKEN_KEY, clientToken) + await this.context.globalState.update(SESSION_ID_KEY, sessionId) + + this.clientToken = clientToken + this.sessionId = sessionId + this.sessionToken = sessionToken + + const previousState = this.state + this.state = "active-session" + this.emit("active-session", { previousState }) + this.timer.start() + + if (this.userChanged) { + this.getUserInfo().then(this.userChanged) + } + + vscode.window.showInformationMessage("Successfully authenticated with Roo Code Cloud") + console.log("[auth] Successfully authenticated with Roo Code Cloud") + } catch (error) { + console.log(`[auth] Error handling Roo Code Cloud callback: ${error}`) + const previousState = this.state + this.state = "logged-out" + this.emit("logged-out", { previousState }) + throw new Error(`Failed to handle Roo Code Cloud callback: ${error}`) + } + } + + /** + * Log out + * + * This method removes all stored tokens and stops the refresh timer. + */ + public async logout(): Promise { + try { + this.timer.stop() + + await this.context.secrets.delete(CLIENT_TOKEN_KEY) + await this.context.globalState.update(SESSION_ID_KEY, undefined) + await this.context.globalState.update(AUTH_STATE_KEY, undefined) + + const oldClientToken = this.clientToken + const oldSessionId = this.sessionId + + this.clientToken = null + this.sessionToken = null + this.sessionId = null + const previousState = this.state + this.state = "logged-out" + this.emit("logged-out", { previousState }) + + if (oldClientToken && oldSessionId) { + await this.clerkLogout(oldClientToken, oldSessionId) + } + + if (this.userChanged) { + this.getUserInfo().then(this.userChanged) + } + + vscode.window.showInformationMessage("Logged out from Roo Code Cloud") + console.log("[auth] Logged out from Roo Code Cloud") + } catch (error) { + console.log(`[auth] Error logging out from Roo Code Cloud: ${error}`) + throw new Error(`Failed to log out from Roo Code Cloud: ${error}`) + } + } + + public getState(): AuthState { + return this.state + } + + public getSessionToken(): string | undefined { + if (this.state === "active-session" && this.sessionToken) { + return this.sessionToken + } + + return + } + + /** + * Check if the user is authenticated + * + * @returns True if the user is authenticated (has an active or inactive session) + */ + public isAuthenticated(): boolean { + return this.state === "active-session" || this.state === "inactive-session" + } + + public hasActiveSession(): boolean { + return this.state === "active-session" + } + + /** + * Refresh the session + * + * This method refreshes the session token using the client token. + */ + private async refreshSession() { + if (!this.sessionId || !this.clientToken) { + console.log("[auth] Cannot refresh session: missing session ID or token") + this.state = "inactive-session" + return + } + + const previousState = this.state + this.sessionToken = await this.clerkCreateSessionToken() + this.state = "active-session" + + if (previousState !== "active-session") { + this.emit("active-session", { previousState }) + + if (this.userChanged) { + this.getUserInfo().then(this.userChanged) + } + } + } + + /** + * Extract user information from the ID token + * + * @returns User information from ID token claims or null if no ID token available + */ + public async getUserInfo(): Promise { + if (!this.clientToken) { + return undefined + } + + return await this.clerkMe() + } + + private async clerkSignIn( + ticket: string, + ): Promise<{ clientToken: string; sessionToken: string; sessionId: string }> { + const formData = new URLSearchParams() + formData.append("strategy", "ticket") + formData.append("ticket", ticket) + + const response = await axios.post(`${getClerkBaseUrl()}/v1/client/sign_ins`, formData, { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": this.userAgent(), + }, + }) + + // 3. Extract the client token from the Authorization header. + const clientToken = response.headers.authorization + + if (!clientToken) { + throw new Error("No authorization header found in the response") + } + + // 4. Find the session using created_session_id and extract the JWT. + const createdSessionId = response.data?.response?.created_session_id + + if (!createdSessionId) { + throw new Error("No session ID found in the response") + } + + // Find the session in the client sessions array. + const session = response.data?.client?.sessions?.find((s: { id: string }) => s.id === createdSessionId) + + if (!session) { + throw new Error("Session not found in the response") + } + + // Extract the session token (JWT) and store it. + const sessionToken = session.last_active_token?.jwt + + if (!sessionToken) { + throw new Error("Session does not have a token") + } + + return { clientToken, sessionToken, sessionId: session.id } + } + + private async clerkCreateSessionToken(): Promise { + const formData = new URLSearchParams() + formData.append("_is_native", "1") + + const response = await axios.post( + `${getClerkBaseUrl()}/v1/client/sessions/${this.sessionId}/tokens`, + formData, + { + headers: { + "Content-Type": "application/x-www-form-urlencoded", + Authorization: `Bearer ${this.clientToken}`, + "User-Agent": this.userAgent(), + }, + }, + ) + + const sessionToken = response.data?.jwt + + if (!sessionToken) { + throw new Error("No JWT found in refresh response") + } + + return sessionToken + } + + private async clerkMe(): Promise { + const response = await axios.get(`${getClerkBaseUrl()}/v1/me`, { + headers: { + Authorization: `Bearer ${this.clientToken}`, + "User-Agent": this.userAgent(), + }, + }) + + const userData = response.data?.response + + if (!userData) { + throw new Error("No response user data") + } + + const userInfo: CloudUserInfo = {} + + userInfo.name = `${userData?.first_name} ${userData?.last_name}` + const primaryEmailAddressId = userData?.primary_email_address_id + const emailAddresses = userData?.email_addresses + + if (primaryEmailAddressId && emailAddresses) { + userInfo.email = emailAddresses.find( + (email: { id: string }) => primaryEmailAddressId === email?.id, + )?.email_address + } + + userInfo.picture = userData?.image_url + return userInfo + } + + private async clerkLogout(clientToken: string, sessionId: string): Promise { + const formData = new URLSearchParams() + formData.append("_is_native", "1") + + await axios.post(`${getClerkBaseUrl()}/v1/client/sessions/${sessionId}/remove`, formData, { + headers: { + Authorization: `Bearer ${clientToken}`, + "User-Agent": this.userAgent(), + }, + }) + } + + private userAgent(): string { + return `Roo-Code ${this.context.extension?.packageJSON?.version}` + } + + private static _instance: AuthService | null = null + + static get instance() { + if (!this._instance) { + throw new Error("AuthService not initialized") + } + + return this._instance + } + + static async createInstance(context: vscode.ExtensionContext, userChanged: CloudServiceCallbacks["userChanged"]) { + if (this._instance) { + throw new Error("AuthService instance already created") + } + + this._instance = new AuthService(context, userChanged) + await this._instance.initialize() + return this._instance + } +} diff --git a/packages/cloud/src/CloudService.ts b/packages/cloud/src/CloudService.ts new file mode 100644 index 0000000000..72cbb70b22 --- /dev/null +++ b/packages/cloud/src/CloudService.ts @@ -0,0 +1,157 @@ +import * as vscode from "vscode" + +import type { CloudUserInfo, TelemetryEvent, OrganizationAllowList } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" + +import { CloudServiceCallbacks } from "./types" +import { AuthService } from "./AuthService" +import { SettingsService } from "./SettingsService" +import { TelemetryClient } from "./TelemetryClient" + +export class CloudService { + private static _instance: CloudService | null = null + + private context: vscode.ExtensionContext + private callbacks: CloudServiceCallbacks + private authService: AuthService | null = null + private settingsService: SettingsService | null = null + private telemetryClient: TelemetryClient | null = null + private isInitialized = false + + private constructor(context: vscode.ExtensionContext, callbacks: CloudServiceCallbacks) { + this.context = context + this.callbacks = callbacks + } + + public async initialize(): Promise { + if (this.isInitialized) { + return + } + + try { + this.authService = await AuthService.createInstance(this.context, (userInfo) => { + this.callbacks.userChanged?.(userInfo) + }) + + this.settingsService = await SettingsService.createInstance(this.context, () => + this.callbacks.settingsChanged?.(), + ) + + this.telemetryClient = new TelemetryClient(this.authService) + + try { + TelemetryService.instance.register(this.telemetryClient) + } catch (error) { + console.warn("[CloudService] Failed to register TelemetryClient:", error) + } + + this.isInitialized = true + } catch (error) { + console.error("[CloudService] Failed to initialize:", error) + throw new Error(`Failed to initialize CloudService: ${error}`) + } + } + + // AuthService + + public async login(): Promise { + this.ensureInitialized() + return this.authService!.login() + } + + public async logout(): Promise { + this.ensureInitialized() + return this.authService!.logout() + } + + public isAuthenticated(): boolean { + this.ensureInitialized() + return this.authService!.isAuthenticated() + } + + public hasActiveSession(): boolean { + this.ensureInitialized() + return this.authService!.hasActiveSession() + } + + public async getUserInfo(): Promise { + this.ensureInitialized() + return this.authService!.getUserInfo() + } + + public getAuthState(): string { + this.ensureInitialized() + return this.authService!.getState() + } + + public async handleAuthCallback(code: string | null, state: string | null): Promise { + this.ensureInitialized() + return this.authService!.handleCallback(code, state) + } + + // SettingsService + + public getAllowList(): OrganizationAllowList { + this.ensureInitialized() + return this.settingsService!.getAllowList() + } + + // TelemetryClient + + public captureEvent(event: TelemetryEvent): void { + this.ensureInitialized() + this.telemetryClient!.capture(event) + } + + // Lifecycle + + public dispose(): void { + if (this.settingsService) { + this.settingsService.dispose() + } + + this.isInitialized = false + } + + private ensureInitialized(): void { + if (!this.isInitialized || !this.authService || !this.settingsService || !this.telemetryClient) { + throw new Error("CloudService not initialized.") + } + } + + static get instance(): CloudService { + if (!this._instance) { + throw new Error("CloudService not initialized") + } + + return this._instance + } + + static async createInstance( + context: vscode.ExtensionContext, + callbacks: CloudServiceCallbacks = {}, + ): Promise { + if (this._instance) { + throw new Error("CloudService instance already created") + } + + this._instance = new CloudService(context, callbacks) + await this._instance.initialize() + return this._instance + } + + static hasInstance(): boolean { + return this._instance !== null && this._instance.isInitialized + } + + static resetInstance(): void { + if (this._instance) { + this._instance.dispose() + this._instance = null + } + } + + static isEnabled(): boolean { + return !!this._instance?.isAuthenticated() + } +} diff --git a/packages/cloud/src/Config.ts b/packages/cloud/src/Config.ts new file mode 100644 index 0000000000..0205e5b0e3 --- /dev/null +++ b/packages/cloud/src/Config.ts @@ -0,0 +1,2 @@ +export const getClerkBaseUrl = () => process.env.CLERK_BASE_URL || "https://clerk.roocode.com" +export const getRooCodeApiUrl = () => process.env.ROO_CODE_API_URL || "https://app.roocode.com" diff --git a/src/utils/refresh-timer.ts b/packages/cloud/src/RefreshTimer.ts similarity index 99% rename from src/utils/refresh-timer.ts rename to packages/cloud/src/RefreshTimer.ts index 3138031665..e7294222d7 100644 --- a/src/utils/refresh-timer.ts +++ b/packages/cloud/src/RefreshTimer.ts @@ -146,7 +146,7 @@ export class RefreshTimer { const result = await this.callback() this.scheduleNextAttempt(result) - } catch (error) { + } catch (_error) { // Treat errors as failed attempts this.scheduleNextAttempt(false) } diff --git a/packages/cloud/src/RooCodeTelemetryClient.ts b/packages/cloud/src/RooCodeTelemetryClient.ts new file mode 100644 index 0000000000..661c5179b3 --- /dev/null +++ b/packages/cloud/src/RooCodeTelemetryClient.ts @@ -0,0 +1,87 @@ +import { TelemetryEventName, type TelemetryEvent, rooCodeTelemetryEventSchema } from "@roo-code/types" +import { BaseTelemetryClient } from "@roo-code/telemetry" + +import { getRooCodeApiUrl } from "./Config" +import { AuthService } from "./AuthService" + +export class RooCodeTelemetryClient extends BaseTelemetryClient { + constructor( + private authService: AuthService, + debug = false, + ) { + super( + { + type: "exclude", + events: [TelemetryEventName.TASK_CONVERSATION_MESSAGE], + }, + debug, + ) + } + + private async fetch(path: string, options: RequestInit) { + if (!this.authService.isAuthenticated()) { + return + } + + const token = this.authService.getSessionToken() + + if (!token) { + console.error(`[RooCodeTelemetryClient#fetch] Unauthorized: No session token available.`) + return + } + + const response = await fetch(`${getRooCodeApiUrl()}/api/${path}`, { + ...options, + headers: { Authorization: `Bearer ${token}`, "Content-Type": "application/json" }, + }) + + if (!response.ok) { + console.error( + `[RooCodeTelemetryClient#fetch] ${options.method} ${path} -> ${response.status} ${response.statusText}`, + ) + } + } + + public override async capture(event: TelemetryEvent) { + if (!this.isTelemetryEnabled() || !this.isEventCapturable(event.event)) { + if (this.debug) { + console.info(`[RooCodeTelemetryClient#capture] Skipping event: ${event.event}`) + } + + return + } + + const payload = { + type: event.event, + properties: await this.getEventProperties(event), + } + + if (this.debug) { + console.info(`[RooCodeTelemetryClient#capture] ${JSON.stringify(payload)}`) + } + + const result = rooCodeTelemetryEventSchema.safeParse(payload) + + if (!result.success) { + console.error( + `[RooCodeTelemetryClient#capture] Invalid telemetry event: ${result.error.message} - ${JSON.stringify(payload)}`, + ) + + return + } + + try { + await this.fetch(`events`, { method: "POST", body: JSON.stringify(result.data) }) + } catch (error) { + console.error(`[RooCodeTelemetryClient#capture] Error sending telemetry event: ${error}`) + } + } + + public override updateTelemetryState(_didUserOptIn: boolean) {} + + public override isTelemetryEnabled(): boolean { + return true + } + + public override async shutdown() {} +} diff --git a/packages/cloud/src/SettingsService.ts b/packages/cloud/src/SettingsService.ts new file mode 100644 index 0000000000..516654e19d --- /dev/null +++ b/packages/cloud/src/SettingsService.ts @@ -0,0 +1,137 @@ +import * as vscode from "vscode" + +import { + ORGANIZATION_ALLOW_ALL, + OrganizationAllowList, + OrganizationSettings, + organizationSettingsSchema, +} from "@roo-code/types" + +import { getRooCodeApiUrl } from "./Config" +import { AuthService } from "./AuthService" +import { RefreshTimer } from "./RefreshTimer" + +const ORGANIZATION_SETTINGS_CACHE_KEY = "organization-settings" + +export class SettingsService { + private static _instance: SettingsService | null = null + + private context: vscode.ExtensionContext + private authService: AuthService + private settings: OrganizationSettings | undefined = undefined + private timer: RefreshTimer + + private constructor(context: vscode.ExtensionContext, authService: AuthService, callback: () => void) { + this.context = context + this.authService = authService + + this.timer = new RefreshTimer({ + callback: async () => { + await this.fetchSettings(callback) + return true + }, + successInterval: 30000, + initialBackoffMs: 1000, + maxBackoffMs: 30000, + }) + } + + public initialize(): void { + this.loadCachedSettings() + + this.authService.on("active-session", () => { + this.timer.start() + }) + + this.authService.on("logged-out", () => { + this.timer.stop() + this.removeSettings() + }) + + if (this.authService.hasActiveSession()) { + this.timer.start() + } + } + + private async fetchSettings(callback: () => void): Promise { + const token = this.authService.getSessionToken() + + if (!token) { + return + } + + try { + const response = await fetch(`${getRooCodeApiUrl()}/api/organization-settings`, { + headers: { + Authorization: `Bearer ${token}`, + }, + }) + + if (!response.ok) { + console.error(`Failed to fetch organization settings: ${response.status} ${response.statusText}`) + return + } + + const data = await response.json() + const result = organizationSettingsSchema.safeParse(data) + + if (!result.success) { + console.error("Invalid organization settings format:", result.error) + return + } + + const newSettings = result.data + + if (!this.settings || this.settings.version !== newSettings.version) { + this.settings = newSettings + await this.cacheSettings() + callback() + } + } catch (error) { + console.error("Error fetching organization settings:", error) + } + } + + private async cacheSettings(): Promise { + await this.context.globalState.update(ORGANIZATION_SETTINGS_CACHE_KEY, this.settings) + } + + private loadCachedSettings(): void { + this.settings = this.context.globalState.get(ORGANIZATION_SETTINGS_CACHE_KEY) + } + + public getAllowList(): OrganizationAllowList { + return this.settings?.allowList || ORGANIZATION_ALLOW_ALL + } + + public getSettings(): OrganizationSettings | undefined { + return this.settings + } + + public async removeSettings(): Promise { + this.settings = undefined + await this.cacheSettings() + } + + public dispose(): void { + this.timer.stop() + } + + static get instance() { + if (!this._instance) { + throw new Error("SettingsService not initialized") + } + + return this._instance + } + + static async createInstance(context: vscode.ExtensionContext, callback: () => void) { + if (this._instance) { + throw new Error("SettingsService instance already created") + } + + this._instance = new SettingsService(context, AuthService.instance, callback) + this._instance.initialize() + return this._instance + } +} diff --git a/packages/cloud/src/TelemetryClient.ts b/packages/cloud/src/TelemetryClient.ts new file mode 100644 index 0000000000..6db8b1096d --- /dev/null +++ b/packages/cloud/src/TelemetryClient.ts @@ -0,0 +1,87 @@ +import { TelemetryEventName, type TelemetryEvent, rooCodeTelemetryEventSchema } from "@roo-code/types" +import { BaseTelemetryClient } from "@roo-code/telemetry" + +import { getRooCodeApiUrl } from "./Config" +import { AuthService } from "./AuthService" + +export class TelemetryClient extends BaseTelemetryClient { + constructor( + private authService: AuthService, + debug = false, + ) { + super( + { + type: "exclude", + events: [TelemetryEventName.TASK_CONVERSATION_MESSAGE], + }, + debug, + ) + } + + private async fetch(path: string, options: RequestInit) { + if (!this.authService.isAuthenticated()) { + return + } + + const token = this.authService.getSessionToken() + + if (!token) { + console.error(`[TelemetryClient#fetch] Unauthorized: No session token available.`) + return + } + + const response = await fetch(`${getRooCodeApiUrl()}/api/${path}`, { + ...options, + headers: { Authorization: `Bearer ${token}`, "Content-Type": "application/json" }, + }) + + if (!response.ok) { + console.error( + `[TelemetryClient#fetch] ${options.method} ${path} -> ${response.status} ${response.statusText}`, + ) + } + } + + public override async capture(event: TelemetryEvent) { + if (!this.isTelemetryEnabled() || !this.isEventCapturable(event.event)) { + if (this.debug) { + console.info(`[TelemetryClient#capture] Skipping event: ${event.event}`) + } + + return + } + + const payload = { + type: event.event, + properties: await this.getEventProperties(event), + } + + if (this.debug) { + console.info(`[TelemetryClient#capture] ${JSON.stringify(payload)}`) + } + + const result = rooCodeTelemetryEventSchema.safeParse(payload) + + if (!result.success) { + console.error( + `[TelemetryClient#capture] Invalid telemetry event: ${result.error.message} - ${JSON.stringify(payload)}`, + ) + + return + } + + try { + await this.fetch(`events`, { method: "POST", body: JSON.stringify(result.data) }) + } catch (error) { + console.error(`[TelemetryClient#capture] Error sending telemetry event: ${error}`) + } + } + + public override updateTelemetryState(_didUserOptIn: boolean) {} + + public override isTelemetryEnabled(): boolean { + return true + } + + public override async shutdown() {} +} diff --git a/packages/cloud/src/__mocks__/vscode.ts b/packages/cloud/src/__mocks__/vscode.ts new file mode 100644 index 0000000000..df636967a1 --- /dev/null +++ b/packages/cloud/src/__mocks__/vscode.ts @@ -0,0 +1,50 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { vi } from "vitest" + +export const window = { + showInformationMessage: vi.fn(), + showErrorMessage: vi.fn(), +} + +export const env = { + openExternal: vi.fn(), +} + +export const Uri = { + parse: vi.fn((uri: string) => ({ toString: () => uri })), +} + +export interface ExtensionContext { + secrets: { + get: (key: string) => Promise + store: (key: string, value: string) => Promise + delete: (key: string) => Promise + } + globalState: { + get: (key: string) => T | undefined + update: (key: string, value: any) => Promise + } + extension?: { + packageJSON?: { + version?: string + } + } +} + +// Mock implementation for tests +export const mockExtensionContext: ExtensionContext = { + secrets: { + get: vi.fn().mockResolvedValue(undefined), + store: vi.fn().mockResolvedValue(undefined), + delete: vi.fn().mockResolvedValue(undefined), + }, + globalState: { + get: vi.fn().mockReturnValue(undefined), + update: vi.fn().mockResolvedValue(undefined), + }, + extension: { + packageJSON: { + version: "1.0.0", + }, + }, +} diff --git a/packages/cloud/src/__tests__/CloudService.test.ts b/packages/cloud/src/__tests__/CloudService.test.ts new file mode 100644 index 0000000000..8b34ee21c1 --- /dev/null +++ b/packages/cloud/src/__tests__/CloudService.test.ts @@ -0,0 +1,238 @@ +// npx vitest run src/__tests__/CloudService.test.ts + +import * as vscode from "vscode" + +import { CloudService } from "../CloudService" +import { AuthService } from "../AuthService" +import { SettingsService } from "../SettingsService" +import { TelemetryService } from "@roo-code/telemetry" +import { CloudServiceCallbacks } from "../types" + +vi.mock("vscode", () => ({ + ExtensionContext: vi.fn(), + window: { + showInformationMessage: vi.fn(), + showErrorMessage: vi.fn(), + }, + env: { + openExternal: vi.fn(), + }, + Uri: { + parse: vi.fn(), + }, +})) + +vi.mock("@roo-code/telemetry") + +vi.mock("../AuthService") + +vi.mock("../SettingsService") + +describe("CloudService", () => { + let mockContext: vscode.ExtensionContext + let mockAuthService: { + initialize: ReturnType + login: ReturnType + logout: ReturnType + isAuthenticated: ReturnType + hasActiveSession: ReturnType + getUserInfo: ReturnType + getState: ReturnType + getSessionToken: ReturnType + handleCallback: ReturnType + on: ReturnType + off: ReturnType + once: ReturnType + emit: ReturnType + } + let mockSettingsService: { + initialize: ReturnType + getSettings: ReturnType + getAllowList: ReturnType + dispose: ReturnType + } + let mockTelemetryService: { + hasInstance: ReturnType + instance: { + register: ReturnType + } + } + + beforeEach(() => { + CloudService.resetInstance() + + mockContext = { + secrets: { + get: vi.fn(), + store: vi.fn(), + delete: vi.fn(), + }, + globalState: { + get: vi.fn(), + update: vi.fn(), + }, + extension: { + packageJSON: { + version: "1.0.0", + }, + }, + } as unknown as vscode.ExtensionContext + + mockAuthService = { + initialize: vi.fn(), + login: vi.fn(), + logout: vi.fn(), + isAuthenticated: vi.fn().mockReturnValue(false), + hasActiveSession: vi.fn().mockReturnValue(false), + getUserInfo: vi.fn(), + getState: vi.fn().mockReturnValue("logged-out"), + getSessionToken: vi.fn(), + handleCallback: vi.fn(), + on: vi.fn(), + off: vi.fn(), + once: vi.fn(), + emit: vi.fn(), + } + + mockSettingsService = { + initialize: vi.fn(), + getSettings: vi.fn(), + getAllowList: vi.fn(), + dispose: vi.fn(), + } + + mockTelemetryService = { + hasInstance: vi.fn().mockReturnValue(true), + instance: { + register: vi.fn(), + }, + } + + vi.mocked(AuthService.createInstance).mockResolvedValue(mockAuthService as unknown as AuthService) + Object.defineProperty(AuthService, "instance", { get: () => mockAuthService, configurable: true }) + + vi.mocked(SettingsService.createInstance).mockResolvedValue(mockSettingsService as unknown as SettingsService) + Object.defineProperty(SettingsService, "instance", { get: () => mockSettingsService, configurable: true }) + + vi.mocked(TelemetryService.hasInstance).mockReturnValue(true) + Object.defineProperty(TelemetryService, "instance", { + get: () => mockTelemetryService.instance, + configurable: true, + }) + }) + + afterEach(() => { + vi.clearAllMocks() + CloudService.resetInstance() + }) + + describe("createInstance", () => { + it("should create and initialize CloudService instance", async () => { + const callbacks = { userChanged: vi.fn(), settingsChanged: vi.fn() } + const cloudService = await CloudService.createInstance(mockContext, callbacks) + + expect(cloudService).toBeInstanceOf(CloudService) + expect(AuthService.createInstance).toHaveBeenCalledWith(mockContext, expect.any(Function)) + expect(SettingsService.createInstance).toHaveBeenCalledWith(mockContext, expect.any(Function)) + }) + + it("should throw error if instance already exists", async () => { + await CloudService.createInstance(mockContext) + + await expect(CloudService.createInstance(mockContext)).rejects.toThrow( + "CloudService instance already created", + ) + }) + }) + + describe("authentication methods", () => { + let cloudService: CloudService + let callbacks: CloudServiceCallbacks + + beforeEach(async () => { + callbacks = { userChanged: vi.fn(), settingsChanged: vi.fn() } + cloudService = await CloudService.createInstance(mockContext, callbacks) + }) + + it("should delegate login to AuthService", async () => { + await cloudService.login() + expect(mockAuthService.login).toHaveBeenCalled() + }) + + it("should delegate logout to AuthService", async () => { + await cloudService.logout() + expect(mockAuthService.logout).toHaveBeenCalled() + }) + + it("should delegate isAuthenticated to AuthService", () => { + const result = cloudService.isAuthenticated() + expect(mockAuthService.isAuthenticated).toHaveBeenCalled() + expect(result).toBe(false) + }) + + it("should delegate hasActiveSession to AuthService", () => { + const result = cloudService.hasActiveSession() + expect(mockAuthService.hasActiveSession).toHaveBeenCalled() + expect(result).toBe(false) + }) + + it("should delegate getUserInfo to AuthService", async () => { + await cloudService.getUserInfo() + expect(mockAuthService.getUserInfo).toHaveBeenCalled() + }) + + it("should delegate getAuthState to AuthService", () => { + const result = cloudService.getAuthState() + expect(mockAuthService.getState).toHaveBeenCalled() + expect(result).toBe("logged-out") + }) + + it("should delegate handleAuthCallback to AuthService", async () => { + await cloudService.handleAuthCallback("code", "state") + expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state") + }) + }) + + describe("organization settings methods", () => { + let cloudService: CloudService + + beforeEach(async () => { + cloudService = await CloudService.createInstance(mockContext) + }) + + it("should delegate getAllowList to SettingsService", () => { + cloudService.getAllowList() + expect(mockSettingsService.getAllowList).toHaveBeenCalled() + }) + }) + + describe("error handling", () => { + it("should throw error when accessing methods before initialization", () => { + expect(() => CloudService.instance.login()).toThrow("CloudService not initialized") + }) + + it("should throw error when accessing instance before creation", () => { + expect(() => CloudService.instance).toThrow("CloudService not initialized") + }) + }) + + describe("hasInstance", () => { + it("should return false when no instance exists", () => { + expect(CloudService.hasInstance()).toBe(false) + }) + + it("should return true when instance exists and is initialized", async () => { + await CloudService.createInstance(mockContext) + expect(CloudService.hasInstance()).toBe(true) + }) + }) + + describe("dispose", () => { + it("should dispose of all services and clean up", async () => { + const cloudService = await CloudService.createInstance(mockContext) + cloudService.dispose() + + expect(mockSettingsService.dispose).toHaveBeenCalled() + }) + }) +}) diff --git a/src/utils/__tests__/refresh-timer.test.ts b/packages/cloud/src/__tests__/RefreshTimer.test.ts similarity index 86% rename from src/utils/__tests__/refresh-timer.test.ts rename to packages/cloud/src/__tests__/RefreshTimer.test.ts index 11911494f6..f8b306b716 100644 --- a/src/utils/__tests__/refresh-timer.test.ts +++ b/packages/cloud/src/__tests__/RefreshTimer.test.ts @@ -1,27 +1,27 @@ -import { RefreshTimer } from "../refresh-timer" +// npx vitest run --globals src/__tests__/RefreshTimer.test.ts -// Mock timers -jest.useFakeTimers() +import { Mock } from "vitest" + +import { RefreshTimer } from "../RefreshTimer" + +vi.useFakeTimers() describe("RefreshTimer", () => { - let mockCallback: jest.Mock + let mockCallback: Mock let refreshTimer: RefreshTimer beforeEach(() => { - // Reset mocks before each test - mockCallback = jest.fn() - - // Default mock implementation returns success + mockCallback = vi.fn() mockCallback.mockResolvedValue(true) }) afterEach(() => { - // Clean up after each test if (refreshTimer) { refreshTimer.stop() } - jest.clearAllTimers() - jest.clearAllMocks() + + vi.clearAllTimers() + vi.clearAllMocks() }) it("should execute callback immediately when started", () => { @@ -50,7 +50,7 @@ describe("RefreshTimer", () => { expect(mockCallback).toHaveBeenCalledTimes(1) // Fast-forward 50 seconds - jest.advanceTimersByTime(50000) + vi.advanceTimersByTime(50000) // Callback should be called again expect(mockCallback).toHaveBeenCalledTimes(2) @@ -72,7 +72,7 @@ describe("RefreshTimer", () => { expect(mockCallback).toHaveBeenCalledTimes(1) // Fast-forward 1 second - jest.advanceTimersByTime(1000) + vi.advanceTimersByTime(1000) // Callback should be called again expect(mockCallback).toHaveBeenCalledTimes(2) @@ -81,7 +81,7 @@ describe("RefreshTimer", () => { await Promise.resolve() // Fast-forward 2 seconds - jest.advanceTimersByTime(2000) + vi.advanceTimersByTime(2000) // Callback should be called again expect(mockCallback).toHaveBeenCalledTimes(3) @@ -103,13 +103,13 @@ describe("RefreshTimer", () => { // Fast-forward through multiple failures to reach max backoff await Promise.resolve() // First attempt - jest.advanceTimersByTime(1000) + vi.advanceTimersByTime(1000) await Promise.resolve() // Second attempt (backoff = 2000ms) - jest.advanceTimersByTime(2000) + vi.advanceTimersByTime(2000) await Promise.resolve() // Third attempt (backoff = 4000ms) - jest.advanceTimersByTime(4000) + vi.advanceTimersByTime(4000) await Promise.resolve() // Fourth attempt (backoff would be 8000ms but max is 5000ms) @@ -132,13 +132,13 @@ describe("RefreshTimer", () => { await Promise.resolve() // Fast-forward 1 second - jest.advanceTimersByTime(1000) + vi.advanceTimersByTime(1000) // Second attempt (succeeds) await Promise.resolve() // Fast-forward 5 seconds - jest.advanceTimersByTime(5000) + vi.advanceTimersByTime(5000) // Third attempt (fails) await Promise.resolve() @@ -173,7 +173,7 @@ describe("RefreshTimer", () => { refreshTimer.stop() // Fast-forward a long time - jest.advanceTimersByTime(1000000) + vi.advanceTimersByTime(1000000) // Callback should only have been called once (the initial call) expect(mockCallback).toHaveBeenCalledTimes(1) @@ -191,10 +191,10 @@ describe("RefreshTimer", () => { // Fast-forward through a few failures await Promise.resolve() - jest.advanceTimersByTime(1000) + vi.advanceTimersByTime(1000) await Promise.resolve() - jest.advanceTimersByTime(2000) + vi.advanceTimersByTime(2000) // Reset the timer refreshTimer.reset() diff --git a/packages/cloud/src/__tests__/RooCodeTelemetryClient.test.ts b/packages/cloud/src/__tests__/RooCodeTelemetryClient.test.ts new file mode 100644 index 0000000000..da8915af9f --- /dev/null +++ b/packages/cloud/src/__tests__/RooCodeTelemetryClient.test.ts @@ -0,0 +1,250 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +// npx vitest run src/__tests__/RooCodeTelemetryClient.test.ts + +import { describe, it, expect, beforeEach, afterEach, vi } from "vitest" + +import { type TelemetryPropertiesProvider, TelemetryEventName } from "@roo-code/types" + +import { RooCodeTelemetryClient } from "../RooCodeTelemetryClient" + +const mockFetch = vi.fn() +global.fetch = mockFetch as any + +describe("RooCodeTelemetryClient", () => { + const getPrivateProperty = (instance: any, propertyName: string): T => { + return instance[propertyName] + } + + let mockAuthService: any + + beforeEach(() => { + vi.clearAllMocks() + + // Create a mock AuthService instead of using the singleton + mockAuthService = { + getSessionToken: vi.fn().mockReturnValue("mock-token"), + getState: vi.fn().mockReturnValue("active-session"), + isAuthenticated: vi.fn().mockReturnValue(true), + hasActiveSession: vi.fn().mockReturnValue(true), + } + + mockFetch.mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({}), + }) + + vi.spyOn(console, "info").mockImplementation(() => {}) + vi.spyOn(console, "error").mockImplementation(() => {}) + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + describe("isEventCapturable", () => { + it("should return true for events not in exclude list", () => { + const client = new RooCodeTelemetryClient(mockAuthService) + + const isEventCapturable = getPrivateProperty<(eventName: TelemetryEventName) => boolean>( + client, + "isEventCapturable", + ).bind(client) + + expect(isEventCapturable(TelemetryEventName.TASK_CREATED)).toBe(true) + expect(isEventCapturable(TelemetryEventName.LLM_COMPLETION)).toBe(true) + expect(isEventCapturable(TelemetryEventName.MODE_SWITCH)).toBe(true) + expect(isEventCapturable(TelemetryEventName.TOOL_USED)).toBe(true) + }) + + it("should return false for events in exclude list", () => { + const client = new RooCodeTelemetryClient(mockAuthService) + + const isEventCapturable = getPrivateProperty<(eventName: TelemetryEventName) => boolean>( + client, + "isEventCapturable", + ).bind(client) + + expect(isEventCapturable(TelemetryEventName.TASK_CONVERSATION_MESSAGE)).toBe(false) + }) + }) + + describe("getEventProperties", () => { + it("should merge provider properties with event properties", async () => { + const client = new RooCodeTelemetryClient(mockAuthService) + + const mockProvider: TelemetryPropertiesProvider = { + getTelemetryProperties: vi.fn().mockResolvedValue({ + appVersion: "1.0.0", + vscodeVersion: "1.60.0", + platform: "darwin", + editorName: "vscode", + language: "en", + mode: "code", + }), + } + + client.setProvider(mockProvider) + + const getEventProperties = getPrivateProperty< + (event: { event: TelemetryEventName; properties?: Record }) => Promise> + >(client, "getEventProperties").bind(client) + + const result = await getEventProperties({ + event: TelemetryEventName.TASK_CREATED, + properties: { + customProp: "value", + mode: "override", // This should override the provider's mode. + }, + }) + + expect(result).toEqual({ + appVersion: "1.0.0", + vscodeVersion: "1.60.0", + platform: "darwin", + editorName: "vscode", + language: "en", + mode: "override", // Event property takes precedence. + customProp: "value", + }) + + expect(mockProvider.getTelemetryProperties).toHaveBeenCalledTimes(1) + }) + + it("should handle errors from provider gracefully", async () => { + const client = new RooCodeTelemetryClient(mockAuthService) + + const mockProvider: TelemetryPropertiesProvider = { + getTelemetryProperties: vi.fn().mockRejectedValue(new Error("Provider error")), + } + + const consoleErrorSpy = vi.spyOn(console, "error") + + client.setProvider(mockProvider) + + const getEventProperties = getPrivateProperty< + (event: { event: TelemetryEventName; properties?: Record }) => Promise> + >(client, "getEventProperties").bind(client) + + const result = await getEventProperties({ + event: TelemetryEventName.TASK_CREATED, + properties: { customProp: "value" }, + }) + + expect(result).toEqual({ customProp: "value" }) + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringContaining("Error getting telemetry properties: Provider error"), + ) + }) + + it("should return event properties when no provider is set", async () => { + const client = new RooCodeTelemetryClient(mockAuthService) + + const getEventProperties = getPrivateProperty< + (event: { event: TelemetryEventName; properties?: Record }) => Promise> + >(client, "getEventProperties").bind(client) + + const result = await getEventProperties({ + event: TelemetryEventName.TASK_CREATED, + properties: { customProp: "value" }, + }) + + expect(result).toEqual({ customProp: "value" }) + }) + }) + + describe("capture", () => { + it("should not capture events that are not capturable", async () => { + const client = new RooCodeTelemetryClient(mockAuthService) + + await client.capture({ + event: TelemetryEventName.TASK_CONVERSATION_MESSAGE, // In exclude list. + properties: { test: "value" }, + }) + + expect(mockFetch).not.toHaveBeenCalled() + }) + + it("should not send request when schema validation fails", async () => { + const client = new RooCodeTelemetryClient(mockAuthService) + + await client.capture({ + event: TelemetryEventName.TASK_CREATED, + properties: { test: "value" }, + }) + + expect(mockFetch).not.toHaveBeenCalled() + expect(console.error).toHaveBeenCalledWith(expect.stringContaining("Invalid telemetry event")) + }) + + it("should send request when event is capturable and validation passes", async () => { + const client = new RooCodeTelemetryClient(mockAuthService) + + const providerProperties = { + appVersion: "1.0.0", + vscodeVersion: "1.60.0", + platform: "darwin", + editorName: "vscode", + language: "en", + mode: "code", + } + + const eventProperties = { + taskId: "test-task-id", + } + + const mockValidatedData = { + type: TelemetryEventName.TASK_CREATED, + properties: { + ...providerProperties, + taskId: "test-task-id", + }, + } + + const mockProvider: TelemetryPropertiesProvider = { + getTelemetryProperties: vi.fn().mockResolvedValue(providerProperties), + } + + client.setProvider(mockProvider) + + await client.capture({ + event: TelemetryEventName.TASK_CREATED, + properties: eventProperties, + }) + + expect(mockFetch).toHaveBeenCalledWith( + "https://app.roocode.com/api/events", + expect.objectContaining({ + method: "POST", + body: JSON.stringify(mockValidatedData), + }), + ) + }) + + it("should handle fetch errors gracefully", async () => { + const client = new RooCodeTelemetryClient(mockAuthService) + + mockFetch.mockRejectedValue(new Error("Network error")) + + await expect( + client.capture({ + event: TelemetryEventName.TASK_CREATED, + properties: { test: "value" }, + }), + ).resolves.not.toThrow() + }) + }) + + describe("telemetry state methods", () => { + it("should always return true for isTelemetryEnabled", () => { + const client = new RooCodeTelemetryClient(mockAuthService) + expect(client.isTelemetryEnabled()).toBe(true) + }) + + it("should have empty implementations for updateTelemetryState and shutdown", async () => { + const client = new RooCodeTelemetryClient(mockAuthService) + client.updateTelemetryState(true) + await client.shutdown() + }) + }) +}) diff --git a/packages/cloud/src/__tests__/TelemetryClient.test.ts b/packages/cloud/src/__tests__/TelemetryClient.test.ts new file mode 100644 index 0000000000..fa008dbb34 --- /dev/null +++ b/packages/cloud/src/__tests__/TelemetryClient.test.ts @@ -0,0 +1,250 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +// npx vitest run src/__tests__/TelemetryClient.test.ts + +import { describe, it, expect, beforeEach, afterEach, vi } from "vitest" + +import { type TelemetryPropertiesProvider, TelemetryEventName } from "@roo-code/types" + +import { TelemetryClient } from "../TelemetryClient" + +const mockFetch = vi.fn() +global.fetch = mockFetch as any + +describe("TelemetryClient", () => { + const getPrivateProperty = (instance: any, propertyName: string): T => { + return instance[propertyName] + } + + let mockAuthService: any + + beforeEach(() => { + vi.clearAllMocks() + + // Create a mock AuthService instead of using the singleton + mockAuthService = { + getSessionToken: vi.fn().mockReturnValue("mock-token"), + getState: vi.fn().mockReturnValue("active-session"), + isAuthenticated: vi.fn().mockReturnValue(true), + hasActiveSession: vi.fn().mockReturnValue(true), + } + + mockFetch.mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({}), + }) + + vi.spyOn(console, "info").mockImplementation(() => {}) + vi.spyOn(console, "error").mockImplementation(() => {}) + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + describe("isEventCapturable", () => { + it("should return true for events not in exclude list", () => { + const client = new TelemetryClient(mockAuthService) + + const isEventCapturable = getPrivateProperty<(eventName: TelemetryEventName) => boolean>( + client, + "isEventCapturable", + ).bind(client) + + expect(isEventCapturable(TelemetryEventName.TASK_CREATED)).toBe(true) + expect(isEventCapturable(TelemetryEventName.LLM_COMPLETION)).toBe(true) + expect(isEventCapturable(TelemetryEventName.MODE_SWITCH)).toBe(true) + expect(isEventCapturable(TelemetryEventName.TOOL_USED)).toBe(true) + }) + + it("should return false for events in exclude list", () => { + const client = new TelemetryClient(mockAuthService) + + const isEventCapturable = getPrivateProperty<(eventName: TelemetryEventName) => boolean>( + client, + "isEventCapturable", + ).bind(client) + + expect(isEventCapturable(TelemetryEventName.TASK_CONVERSATION_MESSAGE)).toBe(false) + }) + }) + + describe("getEventProperties", () => { + it("should merge provider properties with event properties", async () => { + const client = new TelemetryClient(mockAuthService) + + const mockProvider: TelemetryPropertiesProvider = { + getTelemetryProperties: vi.fn().mockResolvedValue({ + appVersion: "1.0.0", + vscodeVersion: "1.60.0", + platform: "darwin", + editorName: "vscode", + language: "en", + mode: "code", + }), + } + + client.setProvider(mockProvider) + + const getEventProperties = getPrivateProperty< + (event: { event: TelemetryEventName; properties?: Record }) => Promise> + >(client, "getEventProperties").bind(client) + + const result = await getEventProperties({ + event: TelemetryEventName.TASK_CREATED, + properties: { + customProp: "value", + mode: "override", // This should override the provider's mode. + }, + }) + + expect(result).toEqual({ + appVersion: "1.0.0", + vscodeVersion: "1.60.0", + platform: "darwin", + editorName: "vscode", + language: "en", + mode: "override", // Event property takes precedence. + customProp: "value", + }) + + expect(mockProvider.getTelemetryProperties).toHaveBeenCalledTimes(1) + }) + + it("should handle errors from provider gracefully", async () => { + const client = new TelemetryClient(mockAuthService) + + const mockProvider: TelemetryPropertiesProvider = { + getTelemetryProperties: vi.fn().mockRejectedValue(new Error("Provider error")), + } + + const consoleErrorSpy = vi.spyOn(console, "error") + + client.setProvider(mockProvider) + + const getEventProperties = getPrivateProperty< + (event: { event: TelemetryEventName; properties?: Record }) => Promise> + >(client, "getEventProperties").bind(client) + + const result = await getEventProperties({ + event: TelemetryEventName.TASK_CREATED, + properties: { customProp: "value" }, + }) + + expect(result).toEqual({ customProp: "value" }) + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringContaining("Error getting telemetry properties: Provider error"), + ) + }) + + it("should return event properties when no provider is set", async () => { + const client = new TelemetryClient(mockAuthService) + + const getEventProperties = getPrivateProperty< + (event: { event: TelemetryEventName; properties?: Record }) => Promise> + >(client, "getEventProperties").bind(client) + + const result = await getEventProperties({ + event: TelemetryEventName.TASK_CREATED, + properties: { customProp: "value" }, + }) + + expect(result).toEqual({ customProp: "value" }) + }) + }) + + describe("capture", () => { + it("should not capture events that are not capturable", async () => { + const client = new TelemetryClient(mockAuthService) + + await client.capture({ + event: TelemetryEventName.TASK_CONVERSATION_MESSAGE, // In exclude list. + properties: { test: "value" }, + }) + + expect(mockFetch).not.toHaveBeenCalled() + }) + + it("should not send request when schema validation fails", async () => { + const client = new TelemetryClient(mockAuthService) + + await client.capture({ + event: TelemetryEventName.TASK_CREATED, + properties: { test: "value" }, + }) + + expect(mockFetch).not.toHaveBeenCalled() + expect(console.error).toHaveBeenCalledWith(expect.stringContaining("Invalid telemetry event")) + }) + + it("should send request when event is capturable and validation passes", async () => { + const client = new TelemetryClient(mockAuthService) + + const providerProperties = { + appVersion: "1.0.0", + vscodeVersion: "1.60.0", + platform: "darwin", + editorName: "vscode", + language: "en", + mode: "code", + } + + const eventProperties = { + taskId: "test-task-id", + } + + const mockValidatedData = { + type: TelemetryEventName.TASK_CREATED, + properties: { + ...providerProperties, + taskId: "test-task-id", + }, + } + + const mockProvider: TelemetryPropertiesProvider = { + getTelemetryProperties: vi.fn().mockResolvedValue(providerProperties), + } + + client.setProvider(mockProvider) + + await client.capture({ + event: TelemetryEventName.TASK_CREATED, + properties: eventProperties, + }) + + expect(mockFetch).toHaveBeenCalledWith( + "https://app.roocode.com/api/events", + expect.objectContaining({ + method: "POST", + body: JSON.stringify(mockValidatedData), + }), + ) + }) + + it("should handle fetch errors gracefully", async () => { + const client = new TelemetryClient(mockAuthService) + + mockFetch.mockRejectedValue(new Error("Network error")) + + await expect( + client.capture({ + event: TelemetryEventName.TASK_CREATED, + properties: { test: "value" }, + }), + ).resolves.not.toThrow() + }) + }) + + describe("telemetry state methods", () => { + it("should always return true for isTelemetryEnabled", () => { + const client = new TelemetryClient(mockAuthService) + expect(client.isTelemetryEnabled()).toBe(true) + }) + + it("should have empty implementations for updateTelemetryState and shutdown", async () => { + const client = new TelemetryClient(mockAuthService) + client.updateTelemetryState(true) + await client.shutdown() + }) + }) +}) diff --git a/packages/cloud/src/index.ts b/packages/cloud/src/index.ts new file mode 100644 index 0000000000..07ea14c784 --- /dev/null +++ b/packages/cloud/src/index.ts @@ -0,0 +1 @@ +export * from "./CloudService" diff --git a/packages/cloud/src/types.ts b/packages/cloud/src/types.ts new file mode 100644 index 0000000000..9c467d9e31 --- /dev/null +++ b/packages/cloud/src/types.ts @@ -0,0 +1,6 @@ +import { CloudUserInfo } from "@roo-code/types" + +export interface CloudServiceCallbacks { + userChanged?: (userInfo: CloudUserInfo | undefined) => void + settingsChanged?: () => void +} diff --git a/packages/cloud/tsconfig.json b/packages/cloud/tsconfig.json new file mode 100644 index 0000000000..f599e2220d --- /dev/null +++ b/packages/cloud/tsconfig.json @@ -0,0 +1,5 @@ +{ + "extends": "@roo-code/config-typescript/vscode-library.json", + "include": ["src"], + "exclude": ["node_modules"] +} diff --git a/packages/cloud/vitest.config.ts b/packages/cloud/vitest.config.ts new file mode 100644 index 0000000000..ff37ed3110 --- /dev/null +++ b/packages/cloud/vitest.config.ts @@ -0,0 +1,13 @@ +import { defineConfig } from "vitest/config" + +export default defineConfig({ + test: { + globals: true, + environment: "node", + }, + resolve: { + alias: { + vscode: new URL("./src/__mocks__/vscode.ts", import.meta.url).pathname, + }, + }, +}) diff --git a/packages/config-typescript/vscode-library.json b/packages/config-typescript/vscode-library.json new file mode 100644 index 0000000000..bc09b3db6d --- /dev/null +++ b/packages/config-typescript/vscode-library.json @@ -0,0 +1,12 @@ +{ + "$schema": "https://json.schemastore.org/tsconfig", + "extends": "./base.json", + "compilerOptions": { + "types": ["vitest/globals"], + "outDir": "dist", + "module": "esnext", + "moduleResolution": "Bundler", + "noUncheckedIndexedAccess": false, + "useUnknownInCatchVariables": false + } +} diff --git a/packages/telemetry/eslint.config.mjs b/packages/telemetry/eslint.config.mjs new file mode 100644 index 0000000000..694bf73664 --- /dev/null +++ b/packages/telemetry/eslint.config.mjs @@ -0,0 +1,4 @@ +import { config } from "@roo-code/config-eslint/base" + +/** @type {import("eslint").Linter.Config} */ +export default [...config] diff --git a/packages/telemetry/package.json b/packages/telemetry/package.json new file mode 100644 index 0000000000..229e676415 --- /dev/null +++ b/packages/telemetry/package.json @@ -0,0 +1,25 @@ +{ + "name": "@roo-code/telemetry", + "description": "Roo Code telemetry service and clients.", + "private": true, + "type": "module", + "exports": "./src/index.ts", + "scripts": { + "lint": "eslint src --ext=ts --max-warnings=0", + "check-types": "tsc --noEmit", + "test": "vitest --globals --run", + "clean": "rimraf dist .turbo" + }, + "dependencies": { + "@roo-code/types": "workspace:^", + "posthog-node": "^4.7.0", + "zod": "^3.24.2" + }, + "devDependencies": { + "@roo-code/config-eslint": "workspace:^", + "@roo-code/config-typescript": "workspace:^", + "@types/node": "^22.15.20", + "@types/vscode": "^1.84.0", + "vitest": "^3.1.3" + } +} diff --git a/src/services/telemetry/clients/BaseTelemetryClient.ts b/packages/telemetry/src/BaseTelemetryClient.ts similarity index 90% rename from src/services/telemetry/clients/BaseTelemetryClient.ts rename to packages/telemetry/src/BaseTelemetryClient.ts index 24a486a2ea..ab8ab56f59 100644 --- a/src/services/telemetry/clients/BaseTelemetryClient.ts +++ b/packages/telemetry/src/BaseTelemetryClient.ts @@ -1,6 +1,10 @@ -import { TelemetryEvent, TelemetryEventName } from "@roo-code/types" - -import { TelemetryClient, TelemetryPropertiesProvider, TelemetryEventSubscription } from "../types" +import { + TelemetryEvent, + TelemetryEventName, + TelemetryClient, + TelemetryPropertiesProvider, + TelemetryEventSubscription, +} from "@roo-code/types" export abstract class BaseTelemetryClient implements TelemetryClient { protected providerRef: WeakRef | null = null diff --git a/src/services/telemetry/clients/PostHogTelemetryClient.ts b/packages/telemetry/src/PostHogTelemetryClient.ts similarity index 86% rename from src/services/telemetry/clients/PostHogTelemetryClient.ts rename to packages/telemetry/src/PostHogTelemetryClient.ts index b554d962e3..243176ed45 100644 --- a/src/services/telemetry/clients/PostHogTelemetryClient.ts +++ b/packages/telemetry/src/PostHogTelemetryClient.ts @@ -14,11 +14,11 @@ export class PostHogTelemetryClient extends BaseTelemetryClient { private client: PostHog private distinctId: string = vscode.env.machineId - private constructor(debug = false) { + constructor(debug = false) { super( { type: "exclude", - events: [TelemetryEventName.LLM_COMPLETION], + events: [TelemetryEventName.TASK_MESSAGE, TelemetryEventName.LLM_COMPLETION], }, debug, ) @@ -75,14 +75,4 @@ export class PostHogTelemetryClient extends BaseTelemetryClient { public override async shutdown(): Promise { await this.client.shutdown() } - - private static _instance: PostHogTelemetryClient | null = null - - public static getInstance(): PostHogTelemetryClient { - if (!PostHogTelemetryClient._instance) { - PostHogTelemetryClient._instance = new PostHogTelemetryClient() - } - - return PostHogTelemetryClient._instance - } } diff --git a/src/services/telemetry/TelemetryService.ts b/packages/telemetry/src/TelemetryService.ts similarity index 84% rename from src/services/telemetry/TelemetryService.ts rename to packages/telemetry/src/TelemetryService.ts index cc1248f1b7..4f2427d998 100644 --- a/src/services/telemetry/TelemetryService.ts +++ b/packages/telemetry/src/TelemetryService.ts @@ -1,38 +1,17 @@ -import * as vscode from "vscode" import { ZodError } from "zod" -import { TelemetryEventName } from "@roo-code/types" - -import { logger } from "../../utils/logging" - -import { PostHogTelemetryClient } from "./clients/PostHogTelemetryClient" -import { type TelemetryClient, type TelemetryPropertiesProvider } from "./types" +import { type TelemetryClient, type TelemetryPropertiesProvider, TelemetryEventName } from "@roo-code/types" /** * TelemetryService wrapper class that defers initialization. * This ensures that we only create the various clients after environment * variables are loaded. */ -class TelemetryService { - private clients: TelemetryClient[] = [] - private initialized = false +export class TelemetryService { + constructor(private clients: TelemetryClient[]) {} - /** - * Initialize the telemetry client. This should be called after environment - * variables are loaded. - */ - public async initialize(context: vscode.ExtensionContext): Promise { - if (this.initialized) { - return - } - - this.initialized = true - - try { - this.clients.push(PostHogTelemetryClient.getInstance()) - } catch (error) { - console.warn("Failed to initialize telemetry service:", error) - } + public register(client: TelemetryClient): void { + this.clients.push(client) } /** @@ -44,8 +23,6 @@ class TelemetryService { if (this.isReady) { this.clients.forEach((client) => client.setProvider(provider)) } - - logger.debug("TelemetryService: ClineProvider reference set") } /** @@ -54,7 +31,7 @@ class TelemetryService { * @returns Whether the service is ready to use */ private get isReady(): boolean { - return this.initialized && this.clients.length > 0 + return this.clients.length > 0 } /** @@ -74,7 +51,8 @@ class TelemetryService { * @param eventName The event name to capture * @param properties The event properties */ - public captureEvent(eventName: TelemetryEventName, properties?: any): void { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + public captureEvent(eventName: TelemetryEventName, properties?: Record): void { if (!this.isReady) { return } @@ -197,6 +175,27 @@ class TelemetryService { this.clients.forEach((client) => client.shutdown()) } -} -export const telemetryService = new TelemetryService() + private static _instance: TelemetryService | null = null + + static createInstance(clients: TelemetryClient[] = []) { + if (this._instance) { + throw new Error("TelemetryService instance already created") + } + + this._instance = new TelemetryService(clients) + return this._instance + } + + static get instance() { + if (!this._instance) { + throw new Error("TelemetryService not initialized") + } + + return this._instance + } + + static hasInstance(): boolean { + return this._instance !== null + } +} diff --git a/src/services/telemetry/clients/__tests__/PostHogTelemetryClient.test.ts b/packages/telemetry/src/__tests__/PostHogTelemetryClient.test.ts similarity index 71% rename from src/services/telemetry/clients/__tests__/PostHogTelemetryClient.test.ts rename to packages/telemetry/src/__tests__/PostHogTelemetryClient.test.ts index 89bb16e81a..50d7f5be88 100644 --- a/src/services/telemetry/clients/__tests__/PostHogTelemetryClient.test.ts +++ b/packages/telemetry/src/__tests__/PostHogTelemetryClient.test.ts @@ -1,21 +1,23 @@ -// npx jest src/services/telemetry/clients/__tests__/PostHogTelemetryClient.test.ts +/* eslint-disable @typescript-eslint/no-explicit-any */ +// npx vitest run src/__tests__/PostHogTelemetryClient.test.ts + +import { describe, it, expect, beforeEach, vi } from "vitest" import * as vscode from "vscode" import { PostHog } from "posthog-node" -import { TelemetryEventName } from "@roo-code/types" +import { type TelemetryPropertiesProvider, TelemetryEventName } from "@roo-code/types" -import { TelemetryPropertiesProvider } from "../../types" import { PostHogTelemetryClient } from "../PostHogTelemetryClient" -jest.mock("posthog-node") +vi.mock("posthog-node") -jest.mock("vscode", () => ({ +vi.mock("vscode", () => ({ env: { machineId: "test-machine-id", }, workspace: { - getConfiguration: jest.fn(), + getConfiguration: vi.fn(), }, })) @@ -24,37 +26,29 @@ describe("PostHogTelemetryClient", () => { return instance[propertyName] } - let mockPostHogClient: jest.Mocked + let mockPostHogClient: any beforeEach(() => { - jest.clearAllMocks() + vi.clearAllMocks() mockPostHogClient = { - capture: jest.fn(), - optIn: jest.fn(), - optOut: jest.fn(), - shutdown: jest.fn().mockResolvedValue(undefined), - } as unknown as jest.Mocked - ;(PostHog as unknown as jest.Mock).mockImplementation(() => mockPostHogClient) - - // @ts-ignore - Accessing private static property for testing + capture: vi.fn(), + optIn: vi.fn(), + optOut: vi.fn(), + shutdown: vi.fn().mockResolvedValue(undefined), + } + ;(PostHog as any).mockImplementation(() => mockPostHogClient) + + // @ts-expect-error - Accessing private static property for testing PostHogTelemetryClient._instance = undefined - ;(vscode.workspace.getConfiguration as jest.Mock).mockReturnValue({ - get: jest.fn().mockReturnValue("all"), - }) - }) - - describe("getInstance", () => { - it("should return the same instance when called multiple times", () => { - const instance1 = PostHogTelemetryClient.getInstance() - const instance2 = PostHogTelemetryClient.getInstance() - expect(instance1).toBe(instance2) + ;(vscode.workspace.getConfiguration as any).mockReturnValue({ + get: vi.fn().mockReturnValue("all"), }) }) describe("isEventCapturable", () => { it("should return true for events not in exclude list", () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() const isEventCapturable = getPrivateProperty<(eventName: TelemetryEventName) => boolean>( client, @@ -66,7 +60,7 @@ describe("PostHogTelemetryClient", () => { }) it("should return false for events in exclude list", () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() const isEventCapturable = getPrivateProperty<(eventName: TelemetryEventName) => boolean>( client, @@ -79,10 +73,10 @@ describe("PostHogTelemetryClient", () => { describe("getEventProperties", () => { it("should merge provider properties with event properties", async () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() const mockProvider: TelemetryPropertiesProvider = { - getTelemetryProperties: jest.fn().mockResolvedValue({ + getTelemetryProperties: vi.fn().mockResolvedValue({ appVersion: "1.0.0", vscodeVersion: "1.60.0", platform: "darwin", @@ -120,13 +114,13 @@ describe("PostHogTelemetryClient", () => { }) it("should handle errors from provider gracefully", async () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() const mockProvider: TelemetryPropertiesProvider = { - getTelemetryProperties: jest.fn().mockRejectedValue(new Error("Provider error")), + getTelemetryProperties: vi.fn().mockRejectedValue(new Error("Provider error")), } - const consoleErrorSpy = jest.spyOn(console, "error").mockImplementation() + const consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {}) client.setProvider(mockProvider) const getEventProperties = getPrivateProperty< @@ -147,7 +141,7 @@ describe("PostHogTelemetryClient", () => { }) it("should return event properties when no provider is set", async () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() const getEventProperties = getPrivateProperty< (event: { event: TelemetryEventName; properties?: Record }) => Promise> @@ -164,7 +158,7 @@ describe("PostHogTelemetryClient", () => { describe("capture", () => { it("should not capture events when telemetry is disabled", async () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() client.updateTelemetryState(false) await client.capture({ @@ -176,7 +170,7 @@ describe("PostHogTelemetryClient", () => { }) it("should not capture events that are not capturable", async () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() client.updateTelemetryState(true) await client.capture({ @@ -188,11 +182,11 @@ describe("PostHogTelemetryClient", () => { }) it("should capture events when telemetry is enabled and event is capturable", async () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() client.updateTelemetryState(true) const mockProvider: TelemetryPropertiesProvider = { - getTelemetryProperties: jest.fn().mockResolvedValue({ + getTelemetryProperties: vi.fn().mockResolvedValue({ appVersion: "1.0.0", vscodeVersion: "1.60.0", platform: "darwin", @@ -222,10 +216,10 @@ describe("PostHogTelemetryClient", () => { describe("updateTelemetryState", () => { it("should enable telemetry when user opts in and global telemetry is enabled", () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() - ;(vscode.workspace.getConfiguration as jest.Mock).mockReturnValue({ - get: jest.fn().mockReturnValue("all"), + ;(vscode.workspace.getConfiguration as any).mockReturnValue({ + get: vi.fn().mockReturnValue("all"), }) client.updateTelemetryState(true) @@ -235,10 +229,10 @@ describe("PostHogTelemetryClient", () => { }) it("should disable telemetry when user opts out", () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() - ;(vscode.workspace.getConfiguration as jest.Mock).mockReturnValue({ - get: jest.fn().mockReturnValue("all"), + ;(vscode.workspace.getConfiguration as any).mockReturnValue({ + get: vi.fn().mockReturnValue("all"), }) client.updateTelemetryState(false) @@ -248,10 +242,10 @@ describe("PostHogTelemetryClient", () => { }) it("should disable telemetry when global telemetry is disabled, regardless of user opt-in", () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() - ;(vscode.workspace.getConfiguration as jest.Mock).mockReturnValue({ - get: jest.fn().mockReturnValue("off"), + ;(vscode.workspace.getConfiguration as any).mockReturnValue({ + get: vi.fn().mockReturnValue("off"), }) client.updateTelemetryState(true) @@ -262,7 +256,7 @@ describe("PostHogTelemetryClient", () => { describe("shutdown", () => { it("should call shutdown on the PostHog client", async () => { - const client = PostHogTelemetryClient.getInstance() + const client = new PostHogTelemetryClient() await client.shutdown() expect(mockPostHogClient.shutdown).toHaveBeenCalled() }) diff --git a/packages/telemetry/src/index.ts b/packages/telemetry/src/index.ts new file mode 100644 index 0000000000..8795ad46a2 --- /dev/null +++ b/packages/telemetry/src/index.ts @@ -0,0 +1,3 @@ +export * from "./BaseTelemetryClient" +export * from "./PostHogTelemetryClient" +export * from "./TelemetryService" diff --git a/packages/telemetry/tsconfig.json b/packages/telemetry/tsconfig.json new file mode 100644 index 0000000000..f599e2220d --- /dev/null +++ b/packages/telemetry/tsconfig.json @@ -0,0 +1,5 @@ +{ + "extends": "@roo-code/config-typescript/vscode-library.json", + "include": ["src"], + "exclude": ["node_modules"] +} diff --git a/packages/telemetry/vitest.config.ts b/packages/telemetry/vitest.config.ts new file mode 100644 index 0000000000..f749203bfc --- /dev/null +++ b/packages/telemetry/vitest.config.ts @@ -0,0 +1,8 @@ +import { defineConfig } from "vitest/config" + +export default defineConfig({ + test: { + globals: true, + environment: "node", + }, +}) diff --git a/packages/types/npm/package.json b/packages/types/npm/package.json index a1e46de317..013b894afb 100644 --- a/packages/types/npm/package.json +++ b/packages/types/npm/package.json @@ -1,6 +1,6 @@ { "name": "@roo-code/types", - "version": "1.19.0", + "version": "1.22.0", "description": "TypeScript type definitions for Roo Code.", "publishConfig": { "access": "public", diff --git a/packages/types/src/cloud.ts b/packages/types/src/cloud.ts new file mode 100644 index 0000000000..207b510116 --- /dev/null +++ b/packages/types/src/cloud.ts @@ -0,0 +1,49 @@ +import { z } from "zod" + +export interface CloudUserInfo { + name?: string + email?: string + picture?: string +} + +/** + * Organization Allow List + */ + +export const organizationAllowListSchema = z.object({ + allowAll: z.boolean(), + providers: z.record( + z.object({ + allowAll: z.boolean(), + models: z.array(z.string()).optional(), + }), + ), +}) + +export type OrganizationAllowList = z.infer + +export const ORGANIZATION_ALLOW_ALL: OrganizationAllowList = { + allowAll: true, + providers: {}, +} as const + +/** + * Organization Settings + */ + +export const organizationSettingsSchema = z.object({ + version: z.number(), + defaultSettings: z + .object({ + enableCheckpoints: z.boolean().optional(), + maxOpenTabsContext: z.number().optional(), + maxWorkspaceFiles: z.number().optional(), + showRooIgnoredFiles: z.boolean().optional(), + maxReadFileLine: z.number().optional(), + fuzzyMatchThreshold: z.number().optional(), + }) + .optional(), + allowList: organizationAllowListSchema, +}) + +export type OrganizationSettings = z.infer diff --git a/packages/types/src/index.ts b/packages/types/src/index.ts index 8b49dc1d62..8b919d4a30 100644 --- a/packages/types/src/index.ts +++ b/packages/types/src/index.ts @@ -1,5 +1,6 @@ export * from "./api.js" export * from "./codebase-index.js" +export * from "./cloud.js" export * from "./experiment.js" export * from "./global-settings.js" export * from "./history.js" diff --git a/packages/types/src/telemetry.ts b/packages/types/src/telemetry.ts index 78e996f766..967101f85d 100644 --- a/packages/types/src/telemetry.ts +++ b/packages/types/src/telemetry.ts @@ -1,6 +1,7 @@ import { z } from "zod" import { providerNames } from "./provider-settings.js" +import { clineMessageSchema } from "./message.js" /** * TelemetrySetting @@ -20,6 +21,7 @@ export enum TelemetryEventName { TASK_CREATED = "Task Created", TASK_RESTARTED = "Task Reopened", TASK_COMPLETED = "Task Completed", + TASK_MESSAGE = "Task Message", TASK_CONVERSATION_MESSAGE = "Conversation Message", LLM_COMPLETION = "LLM Completion", MODE_SWITCH = "Mode Switched", @@ -87,14 +89,6 @@ export type TelemetryEvent = { * RooCodeTelemetryEvent */ -const completionPropertiesSchema = z.object({ - inputTokens: z.number(), - outputTokens: z.number(), - cacheReadTokens: z.number().optional(), - cacheWriteTokens: z.number().optional(), - cost: z.number().optional(), -}) - export const rooCodeTelemetryEventSchema = z.discriminatedUnion("type", [ z.object({ type: z.enum([ @@ -116,19 +110,56 @@ export const rooCodeTelemetryEventSchema = z.discriminatedUnion("type", [ TelemetryEventName.SHELL_INTEGRATION_ERROR, TelemetryEventName.CONSECUTIVE_MISTAKE_ERROR, ]), + properties: telemetryPropertiesSchema, + }), + z.object({ + type: z.literal(TelemetryEventName.TASK_MESSAGE), properties: z.object({ - ...appPropertiesSchema.shape, - ...taskPropertiesSchema.shape, + taskId: z.string(), + message: clineMessageSchema, }), }), z.object({ type: z.literal(TelemetryEventName.LLM_COMPLETION), properties: z.object({ - ...appPropertiesSchema.shape, - ...taskPropertiesSchema.shape, - ...completionPropertiesSchema.shape, + ...telemetryPropertiesSchema.shape, + inputTokens: z.number(), + outputTokens: z.number(), + cacheReadTokens: z.number().optional(), + cacheWriteTokens: z.number().optional(), + cost: z.number().optional(), }), }), ]) export type RooCodeTelemetryEvent = z.infer + +/** + * TelemetryEventSubscription + */ + +export type TelemetryEventSubscription = + | { type: "include"; events: TelemetryEventName[] } + | { type: "exclude"; events: TelemetryEventName[] } + +/** + * TelemetryPropertiesProvider + */ + +export interface TelemetryPropertiesProvider { + getTelemetryProperties(): Promise +} + +/** + * TelemetryClient + */ + +export interface TelemetryClient { + subscription?: TelemetryEventSubscription + + setProvider(provider: TelemetryPropertiesProvider): void + capture(options: TelemetryEvent): Promise + updateTelemetryState(didUserOptIn: boolean): void + isTelemetryEnabled(): boolean + shutdown(): Promise +} diff --git a/packages/types/src/vscode.ts b/packages/types/src/vscode.ts index f12b71cdf7..5dfe1a6397 100644 --- a/packages/types/src/vscode.ts +++ b/packages/types/src/vscode.ts @@ -34,6 +34,7 @@ export const commandIds = [ "mcpButtonClicked", "historyButtonClicked", "popoutButtonClicked", + "accountButtonClicked", "settingsButtonClicked", "openInNewTab", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9340d50d2f..2f15e13c7a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -118,6 +118,34 @@ importers: specifier: ^3.1.3 version: 3.1.3(@types/debug@4.1.12)(@types/node@22.15.20)(jiti@2.4.2)(jsdom@20.0.3)(lightningcss@1.29.2)(tsx@4.19.4)(yaml@2.8.0) + packages/cloud: + dependencies: + '@roo-code/telemetry': + specifier: workspace:^ + version: link:../telemetry + '@roo-code/types': + specifier: workspace:^ + version: link:../types + axios: + specifier: ^1.7.4 + version: 1.9.0 + devDependencies: + '@roo-code/config-eslint': + specifier: workspace:^ + version: link:../config-eslint + '@roo-code/config-typescript': + specifier: workspace:^ + version: link:../config-typescript + '@types/node': + specifier: ^22.15.20 + version: 22.15.20 + '@types/vscode': + specifier: ^1.84.0 + version: 1.100.0 + vitest: + specifier: ^3.1.3 + version: 3.1.3(@types/debug@4.1.12)(@types/node@22.15.20)(jiti@2.4.2)(jsdom@20.0.3)(lightningcss@1.29.2)(tsx@4.19.4)(yaml@2.8.0) + packages/config-eslint: devDependencies: '@eslint/js': @@ -153,6 +181,34 @@ importers: packages/config-typescript: {} + packages/telemetry: + dependencies: + '@roo-code/types': + specifier: workspace:^ + version: link:../types + posthog-node: + specifier: ^4.7.0 + version: 4.17.2 + zod: + specifier: ^3.24.2 + version: 3.24.4 + devDependencies: + '@roo-code/config-eslint': + specifier: workspace:^ + version: link:../config-eslint + '@roo-code/config-typescript': + specifier: workspace:^ + version: link:../config-typescript + '@types/node': + specifier: ^22.15.20 + version: 22.15.20 + '@types/vscode': + specifier: ^1.84.0 + version: 1.100.0 + vitest: + specifier: ^3.1.3 + version: 3.1.3(@types/debug@4.1.12)(@types/node@22.15.20)(jiti@2.4.2)(jsdom@20.0.3)(lightningcss@1.29.2)(tsx@4.19.4)(yaml@2.8.0) + packages/types: dependencies: zod: @@ -204,6 +260,12 @@ importers: '@qdrant/js-client-rest': specifier: ^1.14.0 version: 1.14.0(typescript@5.8.3) + '@roo-code/cloud': + specifier: workspace:^ + version: link:../packages/cloud + '@roo-code/telemetry': + specifier: workspace:^ + version: link:../packages/telemetry '@roo-code/types': specifier: workspace:^ version: link:../packages/types @@ -300,9 +362,6 @@ importers: pkce-challenge: specifier: ^4.1.0 version: 4.1.0 - posthog-node: - specifier: ^4.7.0 - version: 4.17.2 pretty-bytes: specifier: ^6.1.1 version: 6.1.1 diff --git a/src/activate/handleUri.ts b/src/activate/handleUri.ts index 96a24fe6fa..106bcdb311 100644 --- a/src/activate/handleUri.ts +++ b/src/activate/handleUri.ts @@ -1,11 +1,14 @@ import * as vscode from "vscode" +import { CloudService } from "@roo-code/cloud" + import { ClineProvider } from "../core/webview/ClineProvider" export const handleUri = async (uri: vscode.Uri) => { const path = uri.path const query = new URLSearchParams(uri.query.replace(/\+/g, "%2B")) const visibleProvider = ClineProvider.getVisibleInstance() + if (!visibleProvider) { return } @@ -32,6 +35,12 @@ export const handleUri = async (uri: vscode.Uri) => { } break } + case "/auth/clerk/callback": { + const code = query.get("code") + const state = query.get("state") + await CloudService.instance.handleAuthCallback(code, state) + break + } default: break } diff --git a/src/activate/registerCommands.ts b/src/activate/registerCommands.ts index cd76b11f96..3f575b74cb 100644 --- a/src/activate/registerCommands.ts +++ b/src/activate/registerCommands.ts @@ -2,12 +2,12 @@ import * as vscode from "vscode" import delay from "delay" import type { CommandId } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" import { Package } from "../shared/package" import { getCommand } from "../utils/commands" import { ClineProvider } from "../core/webview/ClineProvider" import { ContextProxy } from "../core/config/ContextProxy" -import { telemetryService } from "../services/telemetry/TelemetryService" import { registerHumanRelayCallback, unregisterHumanRelayCallback, handleHumanRelayResponse } from "./humanRelay" import { handleNewTask } from "./handleTask" @@ -70,6 +70,17 @@ export const registerCommands = (options: RegisterCommandOptions) => { const getCommandsMap = ({ context, outputChannel, provider }: RegisterCommandOptions): Record => ({ activationCompleted: () => {}, + accountButtonClicked: () => { + const visibleProvider = getVisibleProviderOrLog(outputChannel) + + if (!visibleProvider) { + return + } + + TelemetryService.instance.captureTitleButtonClicked("account") + + visibleProvider.postMessageToWebview({ type: "action", action: "accountButtonClicked" }) + }, plusButtonClicked: async () => { const visibleProvider = getVisibleProviderOrLog(outputChannel) @@ -77,7 +88,7 @@ const getCommandsMap = ({ context, outputChannel, provider }: RegisterCommandOpt return } - telemetryService.captureTitleButtonClicked("plus") + TelemetryService.instance.captureTitleButtonClicked("plus") await visibleProvider.removeClineFromStack() await visibleProvider.postStateToWebview() @@ -90,7 +101,7 @@ const getCommandsMap = ({ context, outputChannel, provider }: RegisterCommandOpt return } - telemetryService.captureTitleButtonClicked("mcp") + TelemetryService.instance.captureTitleButtonClicked("mcp") visibleProvider.postMessageToWebview({ type: "action", action: "mcpButtonClicked" }) }, @@ -101,12 +112,12 @@ const getCommandsMap = ({ context, outputChannel, provider }: RegisterCommandOpt return } - telemetryService.captureTitleButtonClicked("prompts") + TelemetryService.instance.captureTitleButtonClicked("prompts") visibleProvider.postMessageToWebview({ type: "action", action: "promptsButtonClicked" }) }, popoutButtonClicked: () => { - telemetryService.captureTitleButtonClicked("popout") + TelemetryService.instance.captureTitleButtonClicked("popout") return openClineInNewTab({ context, outputChannel }) }, @@ -118,7 +129,7 @@ const getCommandsMap = ({ context, outputChannel, provider }: RegisterCommandOpt return } - telemetryService.captureTitleButtonClicked("settings") + TelemetryService.instance.captureTitleButtonClicked("settings") visibleProvider.postMessageToWebview({ type: "action", action: "settingsButtonClicked" }) // Also explicitly post the visibility message to trigger scroll reliably @@ -131,7 +142,7 @@ const getCommandsMap = ({ context, outputChannel, provider }: RegisterCommandOpt return } - telemetryService.captureTitleButtonClicked("history") + TelemetryService.instance.captureTitleButtonClicked("history") visibleProvider.postMessageToWebview({ type: "action", action: "historyButtonClicked" }) }, diff --git a/src/core/assistant-message/presentAssistantMessage.ts b/src/core/assistant-message/presentAssistantMessage.ts index 77c510889b..6f283eb0f3 100644 --- a/src/core/assistant-message/presentAssistantMessage.ts +++ b/src/core/assistant-message/presentAssistantMessage.ts @@ -2,12 +2,11 @@ import cloneDeep from "clone-deep" import { serializeError } from "serialize-error" import type { ToolName, ClineAsk, ToolProgressStatus } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" import { defaultModeSlug, getModeBySlug } from "../../shared/modes" import type { ToolParamName, ToolResponse } from "../../shared/tools" -import { telemetryService } from "../../services/telemetry/TelemetryService" - import { fetchInstructionsTool } from "../tools/fetchInstructionsTool" import { listFilesTool } from "../tools/listFilesTool" import { readFileTool } from "../tools/readFileTool" @@ -320,7 +319,7 @@ export async function presentAssistantMessage(cline: Task) { if (!block.partial) { cline.recordToolUsage(block.name) - telemetryService.captureToolUsage(cline.taskId, block.name) + TelemetryService.instance.captureToolUsage(cline.taskId, block.name) } // Validate tool use before execution. @@ -368,7 +367,7 @@ export async function presentAssistantMessage(cline: Task) { await cline.say("user_feedback", text, images) // Track tool repetition in telemetry. - telemetryService.captureConsecutiveMistakeError(cline.taskId) + TelemetryService.instance.captureConsecutiveMistakeError(cline.taskId) } // Return tool result message about the repetition diff --git a/src/core/checkpoints/index.ts b/src/core/checkpoints/index.ts index 68b25b1256..b811b40c48 100644 --- a/src/core/checkpoints/index.ts +++ b/src/core/checkpoints/index.ts @@ -1,6 +1,8 @@ import pWaitFor from "p-wait-for" import * as vscode from "vscode" +import { TelemetryService } from "@roo-code/telemetry" + import { Task } from "../task/Task" import { getWorkspacePath } from "../../utils/path" @@ -10,7 +12,6 @@ import { getApiMetrics } from "../../shared/getApiMetrics" import { DIFF_VIEW_URI_SCHEME } from "../../integrations/editor/DiffViewProvider" -import { telemetryService } from "../../services/telemetry/TelemetryService" import { CheckpointServiceOptions, RepoPerTaskCheckpointService } from "../../services/checkpoints" export function getCheckpointService(cline: Task) { @@ -166,7 +167,7 @@ export async function checkpointSave(cline: Task, force = false) { return } - telemetryService.captureCheckpointCreated(cline.taskId) + TelemetryService.instance.captureCheckpointCreated(cline.taskId) // Start the checkpoint process in the background. return service.saveCheckpoint(`Task: ${cline.taskId}, Time: ${Date.now()}`, { allowEmpty: force }).catch((err) => { @@ -198,7 +199,7 @@ export async function checkpointRestore(cline: Task, { ts, commitHash, mode }: C try { await service.restoreCheckpoint(commitHash) - telemetryService.captureCheckpointRestored(cline.taskId) + TelemetryService.instance.captureCheckpointRestored(cline.taskId) await provider?.postMessageToWebview({ type: "currentCheckpointUpdated", text: commitHash }) if (mode === "restore") { @@ -256,7 +257,7 @@ export async function checkpointDiff(cline: Task, { ts, previousCommitHash, comm return } - telemetryService.captureCheckpointDiffed(cline.taskId) + TelemetryService.instance.captureCheckpointDiffed(cline.taskId) if (!previousCommitHash && mode === "checkpoint") { const previousCheckpoint = cline.clineMessages diff --git a/src/core/condense/__tests__/index.test.ts b/src/core/condense/__tests__/index.test.ts index e3b613f903..81994f3f5b 100644 --- a/src/core/condense/__tests__/index.test.ts +++ b/src/core/condense/__tests__/index.test.ts @@ -1,18 +1,23 @@ +// npx jest core/condense/__tests__/index.test.ts + import { describe, expect, it, jest, beforeEach } from "@jest/globals" + +import { TelemetryService } from "@roo-code/telemetry" + import { ApiHandler } from "../../../api" import { ApiMessage } from "../../task-persistence/apiMessages" import { maybeRemoveImageBlocks } from "../../../api/transform/image-cleaning" import { summarizeConversation, getMessagesSinceLastSummary, N_MESSAGES_TO_KEEP } from "../index" -import { telemetryService } from "../../../services/telemetry/TelemetryService" -// Mock dependencies jest.mock("../../../api/transform/image-cleaning", () => ({ maybeRemoveImageBlocks: jest.fn((messages: ApiMessage[], _apiHandler: ApiHandler) => [...messages]), })) -jest.mock("../../../services/telemetry/TelemetryService", () => ({ - telemetryService: { - captureContextCondensed: jest.fn(), +jest.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + instance: { + captureContextCondensed: jest.fn(), + }, }, })) @@ -524,7 +529,7 @@ describe("summarizeConversation with custom settings", () => { jest.clearAllMocks() // Reset telemetry mock - ;(telemetryService.captureContextCondensed as jest.Mock).mockClear() + ;(TelemetryService.instance.captureContextCondensed as jest.Mock).mockClear() // Setup mock API handlers mockMainApiHandler = { @@ -729,7 +734,7 @@ describe("summarizeConversation with custom settings", () => { ) // Verify telemetry was called with custom prompt flag - expect(telemetryService.captureContextCondensed).toHaveBeenCalledWith( + expect(TelemetryService.instance.captureContextCondensed).toHaveBeenCalledWith( taskId, false, true, // usedCustomPrompt @@ -753,7 +758,7 @@ describe("summarizeConversation with custom settings", () => { ) // Verify telemetry was called with custom API handler flag - expect(telemetryService.captureContextCondensed).toHaveBeenCalledWith( + expect(TelemetryService.instance.captureContextCondensed).toHaveBeenCalledWith( taskId, false, false, // usedCustomPrompt @@ -777,7 +782,7 @@ describe("summarizeConversation with custom settings", () => { ) // Verify telemetry was called with both flags - expect(telemetryService.captureContextCondensed).toHaveBeenCalledWith( + expect(TelemetryService.instance.captureContextCondensed).toHaveBeenCalledWith( taskId, true, // isAutomaticTrigger true, // usedCustomPrompt diff --git a/src/core/condense/index.ts b/src/core/condense/index.ts index 58a81f3d22..07c52713ef 100644 --- a/src/core/condense/index.ts +++ b/src/core/condense/index.ts @@ -1,9 +1,11 @@ import Anthropic from "@anthropic-ai/sdk" + +import { TelemetryService } from "@roo-code/telemetry" + import { t } from "../../i18n" import { ApiHandler } from "../../api" import { ApiMessage } from "../task-persistence/apiMessages" import { maybeRemoveImageBlocks } from "../../api/transform/image-cleaning" -import { telemetryService } from "../../services/telemetry/TelemetryService" export const N_MESSAGES_TO_KEEP = 3 @@ -88,14 +90,16 @@ export async function summarizeConversation( customCondensingPrompt?: string, condensingApiHandler?: ApiHandler, ): Promise { - telemetryService.captureContextCondensed( + TelemetryService.instance.captureContextCondensed( taskId, isAutomaticTrigger ?? false, !!customCondensingPrompt?.trim(), !!condensingApiHandler, ) + const response: SummarizeResponse = { messages, cost: 0, summary: "" } const messagesToSummarize = getMessagesSinceLastSummary(messages.slice(0, -N_MESSAGES_TO_KEEP)) + if (messagesToSummarize.length <= 1) { const error = messages.length <= N_MESSAGES_TO_KEEP + 1 @@ -103,20 +107,25 @@ export async function summarizeConversation( : t("common:errors.condensed_recently") return { ...response, error } } + const keepMessages = messages.slice(-N_MESSAGES_TO_KEEP) // Check if there's a recent summary in the messages we're keeping const recentSummaryExists = keepMessages.some((message) => message.isSummary) + if (recentSummaryExists) { const error = t("common:errors.condensed_recently") return { ...response, error } } + const finalRequestMessage: Anthropic.MessageParam = { role: "user", content: "Summarize the conversation so far, as described in the prompt instructions.", } + const requestMessages = maybeRemoveImageBlocks([...messagesToSummarize, finalRequestMessage], apiHandler).map( ({ role, content }) => ({ role, content }), ) + // Note: this doesn't need to be a stream, consider using something like apiHandler.completePrompt // Use custom prompt if provided and non-empty, otherwise use the default SUMMARY_PROMPT const promptToUse = customCondensingPrompt?.trim() ? customCondensingPrompt.trim() : SUMMARY_PROMPT @@ -129,7 +138,9 @@ export async function summarizeConversation( console.warn( "Chosen API handler for condensing does not support message creation or is invalid, falling back to main apiHandler.", ) + handlerToUse = apiHandler // Fallback to the main, presumably valid, apiHandler + // Ensure the main apiHandler itself is valid before this point or add another check. if (!handlerToUse || typeof handlerToUse.createMessage !== "function") { // This case should ideally not happen if main apiHandler is always valid. @@ -142,9 +153,11 @@ export async function summarizeConversation( } const stream = handlerToUse.createMessage(promptToUse, requestMessages) + let summary = "" let cost = 0 let outputTokens = 0 + for await (const chunk of stream) { if (chunk.type === "text") { summary += chunk.text @@ -154,28 +167,35 @@ export async function summarizeConversation( outputTokens = chunk.outputTokens ?? 0 } } + summary = summary.trim() + if (summary.length === 0) { const error = t("common:errors.condense_failed") return { ...response, cost, error } } + const summaryMessage: ApiMessage = { role: "assistant", content: summary, ts: keepMessages[0].ts, isSummary: true, } + const newMessages = [...messages.slice(0, -N_MESSAGES_TO_KEEP), summaryMessage, ...keepMessages] // Count the tokens in the context for the next API request // We only estimate the tokens in summaryMesage if outputTokens is 0, otherwise we use outputTokens const systemPromptMessage: ApiMessage = { role: "user", content: systemPrompt } + const contextMessages = outputTokens ? [systemPromptMessage, ...keepMessages] : [systemPromptMessage, summaryMessage, ...keepMessages] + const contextBlocks = contextMessages.flatMap((message) => typeof message.content === "string" ? [{ text: message.content, type: "text" as const }] : message.content, ) + const newContextTokens = outputTokens + (await apiHandler.countTokens(contextBlocks)) if (newContextTokens >= prevContextTokens) { const error = t("common:errors.condense_context_grew") @@ -187,9 +207,11 @@ export async function summarizeConversation( /* Returns the list of all messages since the last summary message, including the summary. Returns all messages if there is no summary. */ export function getMessagesSinceLastSummary(messages: ApiMessage[]): ApiMessage[] { let lastSummaryIndexReverse = [...messages].reverse().findIndex((message) => message.isSummary) + if (lastSummaryIndexReverse === -1) { return messages } + const lastSummaryIndex = messages.length - lastSummaryIndexReverse - 1 return messages.slice(lastSummaryIndex) } diff --git a/src/core/config/ContextProxy.ts b/src/core/config/ContextProxy.ts index 874ed719f2..c4324fbb13 100644 --- a/src/core/config/ContextProxy.ts +++ b/src/core/config/ContextProxy.ts @@ -15,9 +15,9 @@ import { globalSettingsSchema, isSecretStateKey, } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" import { logger } from "../../utils/logging" -import { telemetryService } from "../../services/telemetry/TelemetryService" type GlobalStateKey = keyof GlobalState type SecretStateKey = keyof SecretState @@ -162,7 +162,7 @@ export class ContextProxy { return globalSettingsSchema.parse(values) } catch (error) { if (error instanceof ZodError) { - telemetryService.captureSchemaValidationError({ schemaName: "GlobalSettings", error }) + TelemetryService.instance.captureSchemaValidationError({ schemaName: "GlobalSettings", error }) } return GLOBAL_SETTINGS_KEYS.reduce((acc, key) => ({ ...acc, [key]: values[key] }), {} as GlobalSettings) @@ -180,7 +180,7 @@ export class ContextProxy { return providerSettingsSchema.parse(values) } catch (error) { if (error instanceof ZodError) { - telemetryService.captureSchemaValidationError({ schemaName: "ProviderSettings", error }) + TelemetryService.instance.captureSchemaValidationError({ schemaName: "ProviderSettings", error }) } return PROVIDER_SETTINGS_KEYS.reduce((acc, key) => ({ ...acc, [key]: values[key] }), {} as ProviderSettings) @@ -248,7 +248,7 @@ export class ContextProxy { return Object.fromEntries(Object.entries(globalSettings).filter(([_, value]) => value !== undefined)) } catch (error) { if (error instanceof ZodError) { - telemetryService.captureSchemaValidationError({ schemaName: "GlobalSettings", error }) + TelemetryService.instance.captureSchemaValidationError({ schemaName: "GlobalSettings", error }) } return undefined diff --git a/src/core/config/ProviderSettingsManager.ts b/src/core/config/ProviderSettingsManager.ts index d4f2715318..32c0135d3b 100644 --- a/src/core/config/ProviderSettingsManager.ts +++ b/src/core/config/ProviderSettingsManager.ts @@ -6,9 +6,9 @@ import { providerSettingsSchema, providerSettingsSchemaDiscriminated, } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" import { Mode, modes } from "../../shared/modes" -import { telemetryService } from "../../services/telemetry/TelemetryService" const providerSettingsWithIdSchema = providerSettingsSchema.extend({ id: z.string().optional() }) const discriminatedProviderSettingsWithIdSchema = providerSettingsSchemaDiscriminated.and( @@ -469,7 +469,10 @@ export class ProviderSettingsManager { } } catch (error) { if (error instanceof ZodError) { - telemetryService.captureSchemaValidationError({ schemaName: "ProviderProfiles", error }) + TelemetryService.instance.captureSchemaValidationError({ + schemaName: "ProviderProfiles", + error, + }) } throw new Error(`Failed to read provider profiles from secrets: ${error}`) diff --git a/src/core/config/__tests__/importExport.test.ts b/src/core/config/__tests__/importExport.test.ts index 40def4ebcd..89ac7dfef6 100644 --- a/src/core/config/__tests__/importExport.test.ts +++ b/src/core/config/__tests__/importExport.test.ts @@ -6,6 +6,7 @@ import * as path from "path" import * as vscode from "vscode" import type { ProviderName } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" import { importSettings, exportSettings } from "../importExport" import { ProviderSettingsManager } from "../ProviderSettingsManager" @@ -41,6 +42,10 @@ describe("importExport", () => { beforeEach(() => { jest.clearAllMocks() + if (!TelemetryService.hasInstance()) { + TelemetryService.createInstance([]) + } + mockProviderSettingsManager = { export: jest.fn(), import: jest.fn(), diff --git a/src/core/config/importExport.ts b/src/core/config/importExport.ts index b9caef727e..4830a5f987 100644 --- a/src/core/config/importExport.ts +++ b/src/core/config/importExport.ts @@ -6,8 +6,7 @@ import * as vscode from "vscode" import { z, ZodError } from "zod" import { globalSettingsSchema } from "@roo-code/types" - -import { telemetryService } from "../../services/telemetry/TelemetryService" +import { TelemetryService } from "@roo-code/telemetry" import { ProviderSettingsManager, providerProfilesSchema } from "./ProviderSettingsManager" import { ContextProxy } from "./ContextProxy" @@ -84,7 +83,7 @@ export const importSettings = async ({ providerSettingsManager, contextProxy, cu if (e instanceof ZodError) { error = e.issues.map((issue) => `[${issue.path.join(".")}]: ${issue.message}`).join("\n") - telemetryService.captureSchemaValidationError({ schemaName: "ImportExport", error: e }) + TelemetryService.instance.captureSchemaValidationError({ schemaName: "ImportExport", error: e }) } else if (e instanceof Error) { error = e.message } diff --git a/src/core/sliding-window/__tests__/sliding-window.test.ts b/src/core/sliding-window/__tests__/sliding-window.test.ts index e99e2ed61f..a26ad6b53e 100644 --- a/src/core/sliding-window/__tests__/sliding-window.test.ts +++ b/src/core/sliding-window/__tests__/sliding-window.test.ts @@ -3,6 +3,7 @@ import { Anthropic } from "@anthropic-ai/sdk" import type { ModelInfo } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" import { BaseProvider } from "../../../api/providers/base-provider" import { ApiMessage } from "../../task-persistence/apiMessages" @@ -41,886 +42,948 @@ class MockApiHandler extends BaseProvider { const mockApiHandler = new MockApiHandler() const taskId = "test-task-id" -/** - * Tests for the truncateConversation function - */ -describe("truncateConversation", () => { - it("should retain the first message", () => { - const messages: ApiMessage[] = [ - { role: "user", content: "First message" }, - { role: "assistant", content: "Second message" }, - { role: "user", content: "Third message" }, - ] - - const result = truncateConversation(messages, 0.5, taskId) - - // With 2 messages after the first, 0.5 fraction means remove 1 message - // But 1 is odd, so it rounds down to 0 (to make it even) - expect(result.length).toBe(3) // First message + 2 remaining messages - expect(result[0]).toEqual(messages[0]) - expect(result[1]).toEqual(messages[1]) - expect(result[2]).toEqual(messages[2]) - }) - - it("should remove the specified fraction of messages (rounded to even number)", () => { - const messages: ApiMessage[] = [ - { role: "user", content: "First message" }, - { role: "assistant", content: "Second message" }, - { role: "user", content: "Third message" }, - { role: "assistant", content: "Fourth message" }, - { role: "user", content: "Fifth message" }, - ] - - // 4 messages excluding first, 0.5 fraction = 2 messages to remove - // 2 is already even, so no rounding needed - const result = truncateConversation(messages, 0.5, taskId) - - expect(result.length).toBe(3) - expect(result[0]).toEqual(messages[0]) - expect(result[1]).toEqual(messages[3]) - expect(result[2]).toEqual(messages[4]) - }) - - it("should round to an even number of messages to remove", () => { - const messages: ApiMessage[] = [ - { role: "user", content: "First message" }, - { role: "assistant", content: "Second message" }, - { role: "user", content: "Third message" }, - { role: "assistant", content: "Fourth message" }, - { role: "user", content: "Fifth message" }, - { role: "assistant", content: "Sixth message" }, - { role: "user", content: "Seventh message" }, - ] - - // 6 messages excluding first, 0.3 fraction = 1.8 messages to remove - // 1.8 rounds down to 1, then to 0 to make it even - const result = truncateConversation(messages, 0.3, taskId) - - expect(result.length).toBe(7) // No messages removed - expect(result).toEqual(messages) - }) - - it("should handle edge case with fracToRemove = 0", () => { - const messages: ApiMessage[] = [ - { role: "user", content: "First message" }, - { role: "assistant", content: "Second message" }, - { role: "user", content: "Third message" }, - ] - - const result = truncateConversation(messages, 0, taskId) - - expect(result).toEqual(messages) - }) - - it("should handle edge case with fracToRemove = 1", () => { - const messages: ApiMessage[] = [ - { role: "user", content: "First message" }, - { role: "assistant", content: "Second message" }, - { role: "user", content: "Third message" }, - { role: "assistant", content: "Fourth message" }, - ] - - // 3 messages excluding first, 1.0 fraction = 3 messages to remove - // But 3 is odd, so it rounds down to 2 to make it even - const result = truncateConversation(messages, 1, taskId) - - expect(result.length).toBe(2) - expect(result[0]).toEqual(messages[0]) - expect(result[1]).toEqual(messages[3]) - }) -}) - -/** - * Tests for the estimateTokenCount function - */ -describe("estimateTokenCount", () => { - it("should return 0 for empty or undefined content", async () => { - expect(await estimateTokenCount([], mockApiHandler)).toBe(0) - // @ts-ignore - Testing with undefined - expect(await estimateTokenCount(undefined, mockApiHandler)).toBe(0) - }) - - it("should estimate tokens for text blocks", async () => { - const content: Array = [ - { type: "text", text: "This is a text block with 36 characters" }, - ] - - // With tiktoken, the exact token count may differ from character-based estimation - // Instead of expecting an exact number, we verify it's a reasonable positive number - const result = await estimateTokenCount(content, mockApiHandler) - expect(result).toBeGreaterThan(0) - - // We can also verify that longer text results in more tokens - const longerContent: Array = [ - { - type: "text", - text: "This is a longer text block with significantly more characters to encode into tokens", - }, - ] - const longerResult = await estimateTokenCount(longerContent, mockApiHandler) - expect(longerResult).toBeGreaterThan(result) - }) - - it("should estimate tokens for image blocks based on data size", async () => { - // Small image - const smallImage: Array = [ - { type: "image", source: { type: "base64", media_type: "image/jpeg", data: "small_dummy_data" } }, - ] - // Larger image with more data - const largerImage: Array = [ - { type: "image", source: { type: "base64", media_type: "image/png", data: "X".repeat(1000) } }, - ] - - // Verify the token count scales with the size of the image data - const smallImageTokens = await estimateTokenCount(smallImage, mockApiHandler) - const largerImageTokens = await estimateTokenCount(largerImage, mockApiHandler) - - // Small image should have some tokens - expect(smallImageTokens).toBeGreaterThan(0) - - // Larger image should have proportionally more tokens - expect(largerImageTokens).toBeGreaterThan(smallImageTokens) - - // Verify the larger image calculation matches our formula including the 50% fudge factor - expect(largerImageTokens).toBe(48) - }) - - it("should estimate tokens for mixed content blocks", async () => { - const content: Array = [ - { type: "text", text: "A text block with 30 characters" }, - { type: "image", source: { type: "base64", media_type: "image/jpeg", data: "dummy_data" } }, - { type: "text", text: "Another text with 24 chars" }, - ] - - // We know image tokens calculation should be consistent - const imageTokens = Math.ceil(Math.sqrt("dummy_data".length)) * 1.5 - - // With tiktoken, we can't predict exact text token counts, - // but we can verify the total is greater than just the image tokens - const result = await estimateTokenCount(content, mockApiHandler) - expect(result).toBeGreaterThan(imageTokens) - - // Also test against a version with only the image to verify text adds tokens - const imageOnlyContent: Array = [ - { type: "image", source: { type: "base64", media_type: "image/jpeg", data: "dummy_data" } }, - ] - const imageOnlyResult = await estimateTokenCount(imageOnlyContent, mockApiHandler) - expect(result).toBeGreaterThan(imageOnlyResult) - }) - - it("should handle empty text blocks", async () => { - const content: Array = [{ type: "text", text: "" }] - expect(await estimateTokenCount(content, mockApiHandler)).toBe(0) - }) - - it("should handle plain string messages", async () => { - const content = "This is a plain text message" - expect(await estimateTokenCount([{ type: "text", text: content }], mockApiHandler)).toBeGreaterThan(0) - }) -}) - -/** - * Tests for the truncateConversationIfNeeded function - */ -describe("truncateConversationIfNeeded", () => { - const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({ - contextWindow, - supportsPromptCache: true, - maxTokens, - }) - - const messages: ApiMessage[] = [ - { role: "user", content: "First message" }, - { role: "assistant", content: "Second message" }, - { role: "user", content: "Third message" }, - { role: "assistant", content: "Fourth message" }, - { role: "user", content: "Fifth message" }, - ] - - it("should not truncate if tokens are below max tokens threshold", async () => { - const modelInfo = createModelInfo(100000, 30000) - const dynamicBuffer = modelInfo.contextWindow * TOKEN_BUFFER_PERCENTAGE // 10000 - const totalTokens = 70000 - dynamicBuffer - 1 // Just below threshold - buffer - - // Create messages with very small content in the last one to avoid token overflow - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - const result = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens, - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - - // Check the new return type - expect(result).toEqual({ - messages: messagesWithSmallContent, - summary: "", - cost: 0, - prevContextTokens: totalTokens, - }) - }) - - it("should truncate if tokens are above max tokens threshold", async () => { - const modelInfo = createModelInfo(100000, 30000) - const totalTokens = 70001 // Above threshold - - // Create messages with very small content in the last one to avoid token overflow - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - // When truncating, always uses 0.5 fraction - // With 4 messages after the first, 0.5 fraction means remove 2 messages - const expectedMessages = [messagesWithSmallContent[0], messagesWithSmallContent[3], messagesWithSmallContent[4]] - - const result = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens, - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - - expect(result).toEqual({ - messages: expectedMessages, - summary: "", - cost: 0, - prevContextTokens: totalTokens, - }) - }) - - it("should work with non-prompt caching models the same as prompt caching models", async () => { - // The implementation no longer differentiates between prompt caching and non-prompt caching models - const modelInfo1 = createModelInfo(100000, 30000) - const modelInfo2 = createModelInfo(100000, 30000) - - // Create messages with very small content in the last one to avoid token overflow - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - // Test below threshold - const belowThreshold = 69999 - const result1 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: belowThreshold, - contextWindow: modelInfo1.contextWindow, - maxTokens: modelInfo1.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - - const result2 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: belowThreshold, - contextWindow: modelInfo2.contextWindow, - maxTokens: modelInfo2.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - - expect(result1.messages).toEqual(result2.messages) - expect(result1.summary).toEqual(result2.summary) - expect(result1.cost).toEqual(result2.cost) - expect(result1.prevContextTokens).toEqual(result2.prevContextTokens) - - // Test above threshold - const aboveThreshold = 70001 - const result3 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: aboveThreshold, - contextWindow: modelInfo1.contextWindow, - maxTokens: modelInfo1.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - - const result4 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: aboveThreshold, - contextWindow: modelInfo2.contextWindow, - maxTokens: modelInfo2.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - - expect(result3.messages).toEqual(result4.messages) - expect(result3.summary).toEqual(result4.summary) - expect(result3.cost).toEqual(result4.cost) - expect(result3.prevContextTokens).toEqual(result4.prevContextTokens) +describe("Sliding Window", () => { + beforeEach(() => { + if (!TelemetryService.hasInstance()) { + TelemetryService.createInstance([]) + } }) + /** + * Tests for the truncateConversation function + */ + describe("truncateConversation", () => { + it("should retain the first message", () => { + const messages: ApiMessage[] = [ + { role: "user", content: "First message" }, + { role: "assistant", content: "Second message" }, + { role: "user", content: "Third message" }, + ] - it("should consider incoming content when deciding to truncate", async () => { - const modelInfo = createModelInfo(100000, 30000) - const maxTokens = 30000 - const availableTokens = modelInfo.contextWindow - maxTokens - - // Test case 1: Small content that won't push us over the threshold - const smallContent = [{ type: "text" as const, text: "Small content" }] - const smallContentTokens = await estimateTokenCount(smallContent, mockApiHandler) - const messagesWithSmallContent: ApiMessage[] = [ - ...messages.slice(0, -1), - { role: messages[messages.length - 1].role, content: smallContent }, - ] - - // Set base tokens so total is well below threshold + buffer even with small content added - const dynamicBuffer = modelInfo.contextWindow * TOKEN_BUFFER_PERCENTAGE - const baseTokensForSmall = availableTokens - smallContentTokens - dynamicBuffer - 10 - const resultWithSmall = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: baseTokensForSmall, - contextWindow: modelInfo.contextWindow, - maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - expect(resultWithSmall).toEqual({ - messages: messagesWithSmallContent, - summary: "", - cost: 0, - prevContextTokens: baseTokensForSmall + smallContentTokens, - }) // No truncation - - // Test case 2: Large content that will push us over the threshold - const largeContent = [ - { - type: "text" as const, - text: "A very large incoming message that would consume a significant number of tokens and push us over the threshold", - }, - ] - const largeContentTokens = await estimateTokenCount(largeContent, mockApiHandler) - const messagesWithLargeContent: ApiMessage[] = [ - ...messages.slice(0, -1), - { role: messages[messages.length - 1].role, content: largeContent }, - ] - - // Set base tokens so we're just below threshold without content, but over with content - const baseTokensForLarge = availableTokens - Math.floor(largeContentTokens / 2) - const resultWithLarge = await truncateConversationIfNeeded({ - messages: messagesWithLargeContent, - totalTokens: baseTokensForLarge, - contextWindow: modelInfo.contextWindow, - maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - expect(resultWithLarge.messages).not.toEqual(messagesWithLargeContent) // Should truncate - expect(resultWithLarge.summary).toBe("") - expect(resultWithLarge.cost).toBe(0) - expect(resultWithLarge.prevContextTokens).toBe(baseTokensForLarge + largeContentTokens) - - // Test case 3: Very large content that will definitely exceed threshold - const veryLargeContent = [{ type: "text" as const, text: "X".repeat(1000) }] - const veryLargeContentTokens = await estimateTokenCount(veryLargeContent, mockApiHandler) - const messagesWithVeryLargeContent: ApiMessage[] = [ - ...messages.slice(0, -1), - { role: messages[messages.length - 1].role, content: veryLargeContent }, - ] + const result = truncateConversation(messages, 0.5, taskId) - // Set base tokens so we're just below threshold without content - const baseTokensForVeryLarge = availableTokens - Math.floor(veryLargeContentTokens / 2) - const resultWithVeryLarge = await truncateConversationIfNeeded({ - messages: messagesWithVeryLargeContent, - totalTokens: baseTokensForVeryLarge, - contextWindow: modelInfo.contextWindow, - maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, + // With 2 messages after the first, 0.5 fraction means remove 1 message + // But 1 is odd, so it rounds down to 0 (to make it even) + expect(result.length).toBe(3) // First message + 2 remaining messages + expect(result[0]).toEqual(messages[0]) + expect(result[1]).toEqual(messages[1]) + expect(result[2]).toEqual(messages[2]) }) - expect(resultWithVeryLarge.messages).not.toEqual(messagesWithVeryLargeContent) // Should truncate - expect(resultWithVeryLarge.summary).toBe("") - expect(resultWithVeryLarge.cost).toBe(0) - expect(resultWithVeryLarge.prevContextTokens).toBe(baseTokensForVeryLarge + veryLargeContentTokens) - }) - it("should truncate if tokens are within TOKEN_BUFFER_PERCENTAGE of the threshold", async () => { - const modelInfo = createModelInfo(100000, 30000) - const dynamicBuffer = modelInfo.contextWindow * TOKEN_BUFFER_PERCENTAGE // 10% of 100000 = 10000 - const totalTokens = 70000 - dynamicBuffer + 1 // Just within the dynamic buffer of threshold (70000) - - // Create messages with very small content in the last one to avoid token overflow - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - // When truncating, always uses 0.5 fraction - // With 4 messages after the first, 0.5 fraction means remove 2 messages - const expectedResult = [messagesWithSmallContent[0], messagesWithSmallContent[3], messagesWithSmallContent[4]] - - const result = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens, - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - expect(result).toEqual({ - messages: expectedResult, - summary: "", - cost: 0, - prevContextTokens: totalTokens, - }) - }) - - it("should use summarizeConversation when autoCondenseContext is true and tokens exceed threshold", async () => { - // Mock the summarizeConversation function - const mockSummary = "This is a summary of the conversation" - const mockCost = 0.05 - const mockSummarizeResponse: condenseModule.SummarizeResponse = { - messages: [ + it("should remove the specified fraction of messages (rounded to even number)", () => { + const messages: ApiMessage[] = [ { role: "user", content: "First message" }, - { role: "assistant", content: mockSummary, isSummary: true }, - { role: "user", content: "Last message" }, - ], - summary: mockSummary, - cost: mockCost, - newContextTokens: 100, - } + { role: "assistant", content: "Second message" }, + { role: "user", content: "Third message" }, + { role: "assistant", content: "Fourth message" }, + { role: "user", content: "Fifth message" }, + ] - const summarizeSpy = jest - .spyOn(condenseModule, "summarizeConversation") - .mockResolvedValue(mockSummarizeResponse) - - const modelInfo = createModelInfo(100000, 30000) - const totalTokens = 70001 // Above threshold - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - const result = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens, - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: true, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) + // 4 messages excluding first, 0.5 fraction = 2 messages to remove + // 2 is already even, so no rounding needed + const result = truncateConversation(messages, 0.5, taskId) - // Verify summarizeConversation was called with the right parameters - expect(summarizeSpy).toHaveBeenCalledWith( - messagesWithSmallContent, - mockApiHandler, - "System prompt", - taskId, - 70001, - true, - undefined, // customCondensingPrompt - undefined, // condensingApiHandler - ) - - // Verify the result contains the summary information - expect(result).toMatchObject({ - messages: mockSummarizeResponse.messages, - summary: mockSummary, - cost: mockCost, - prevContextTokens: totalTokens, + expect(result.length).toBe(3) + expect(result[0]).toEqual(messages[0]) + expect(result[1]).toEqual(messages[3]) + expect(result[2]).toEqual(messages[4]) }) - // newContextTokens might be present, but we don't need to verify its exact value - // Clean up - summarizeSpy.mockRestore() - }) + it("should round to an even number of messages to remove", () => { + const messages: ApiMessage[] = [ + { role: "user", content: "First message" }, + { role: "assistant", content: "Second message" }, + { role: "user", content: "Third message" }, + { role: "assistant", content: "Fourth message" }, + { role: "user", content: "Fifth message" }, + { role: "assistant", content: "Sixth message" }, + { role: "user", content: "Seventh message" }, + ] - it("should fall back to truncateConversation when autoCondenseContext is true but summarization fails", async () => { - // Mock the summarizeConversation function to return an error - const mockSummarizeResponse: condenseModule.SummarizeResponse = { - messages: messages, // Original messages unchanged - summary: "", // Empty summary - cost: 0.01, - error: "Summarization failed", // Error indicates failure - } + // 6 messages excluding first, 0.3 fraction = 1.8 messages to remove + // 1.8 rounds down to 1, then to 0 to make it even + const result = truncateConversation(messages, 0.3, taskId) - const summarizeSpy = jest - .spyOn(condenseModule, "summarizeConversation") - .mockResolvedValue(mockSummarizeResponse) - - const modelInfo = createModelInfo(100000, 30000) - const totalTokens = 70001 // Above threshold - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - // When truncating, always uses 0.5 fraction - // With 4 messages after the first, 0.5 fraction means remove 2 messages - const expectedMessages = [messagesWithSmallContent[0], messagesWithSmallContent[3], messagesWithSmallContent[4]] - - const result = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens, - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: true, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, + expect(result.length).toBe(7) // No messages removed + expect(result).toEqual(messages) }) - // Verify summarizeConversation was called - expect(summarizeSpy).toHaveBeenCalled() - - // Verify it fell back to truncation - expect(result.messages).toEqual(expectedMessages) - expect(result.summary).toBe("") - expect(result.prevContextTokens).toBe(totalTokens) - // The cost might be different than expected, so we don't check it - - // Clean up - summarizeSpy.mockRestore() - }) - - it("should not call summarizeConversation when autoCondenseContext is false", async () => { - // Reset any previous mock calls - jest.clearAllMocks() - const summarizeSpy = jest.spyOn(condenseModule, "summarizeConversation") - - const modelInfo = createModelInfo(100000, 30000) - const totalTokens = 70001 // Above threshold - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - // When truncating, always uses 0.5 fraction - // With 4 messages after the first, 0.5 fraction means remove 2 messages - const expectedMessages = [messagesWithSmallContent[0], messagesWithSmallContent[3], messagesWithSmallContent[4]] - - const result = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens, - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 50, // This shouldn't matter since autoCondenseContext is false - systemPrompt: "System prompt", - taskId, - }) + it("should handle edge case with fracToRemove = 0", () => { + const messages: ApiMessage[] = [ + { role: "user", content: "First message" }, + { role: "assistant", content: "Second message" }, + { role: "user", content: "Third message" }, + ] - // Verify summarizeConversation was not called - expect(summarizeSpy).not.toHaveBeenCalled() + const result = truncateConversation(messages, 0, taskId) - // Verify it used truncation - expect(result).toEqual({ - messages: expectedMessages, - summary: "", - cost: 0, - prevContextTokens: totalTokens, + expect(result).toEqual(messages) }) - // Clean up - summarizeSpy.mockRestore() - }) - - it("should use summarizeConversation when autoCondenseContext is true and context percent exceeds threshold", async () => { - // Mock the summarizeConversation function - const mockSummary = "This is a summary of the conversation" - const mockCost = 0.05 - const mockSummarizeResponse: condenseModule.SummarizeResponse = { - messages: [ + it("should handle edge case with fracToRemove = 1", () => { + const messages: ApiMessage[] = [ { role: "user", content: "First message" }, - { role: "assistant", content: mockSummary, isSummary: true }, - { role: "user", content: "Last message" }, - ], - summary: mockSummary, - cost: mockCost, - newContextTokens: 100, - } - - const summarizeSpy = jest - .spyOn(condenseModule, "summarizeConversation") - .mockResolvedValue(mockSummarizeResponse) + { role: "assistant", content: "Second message" }, + { role: "user", content: "Third message" }, + { role: "assistant", content: "Fourth message" }, + ] - const modelInfo = createModelInfo(100000, 30000) - // Set tokens to be below the allowedTokens threshold but above the percentage threshold - const contextWindow = modelInfo.contextWindow - const totalTokens = 60000 // Below allowedTokens but 60% of context window - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] + // 3 messages excluding first, 1.0 fraction = 3 messages to remove + // But 3 is odd, so it rounds down to 2 to make it even + const result = truncateConversation(messages, 1, taskId) - const result = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens, - contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: true, - autoCondenseContextPercent: 50, // Set threshold to 50% - our tokens are at 60% - systemPrompt: "System prompt", - taskId, + expect(result.length).toBe(2) + expect(result[0]).toEqual(messages[0]) + expect(result[1]).toEqual(messages[3]) }) + }) - // Verify summarizeConversation was called with the right parameters - expect(summarizeSpy).toHaveBeenCalledWith( - messagesWithSmallContent, - mockApiHandler, - "System prompt", - taskId, - 60000, - true, - undefined, // customCondensingPrompt - undefined, // condensingApiHandler - ) - - // Verify the result contains the summary information - expect(result).toMatchObject({ - messages: mockSummarizeResponse.messages, - summary: mockSummary, - cost: mockCost, - prevContextTokens: totalTokens, + /** + * Tests for the estimateTokenCount function + */ + describe("estimateTokenCount", () => { + it("should return 0 for empty or undefined content", async () => { + expect(await estimateTokenCount([], mockApiHandler)).toBe(0) + // @ts-ignore - Testing with undefined + expect(await estimateTokenCount(undefined, mockApiHandler)).toBe(0) + }) + + it("should estimate tokens for text blocks", async () => { + const content: Array = [ + { type: "text", text: "This is a text block with 36 characters" }, + ] + + // With tiktoken, the exact token count may differ from character-based estimation + // Instead of expecting an exact number, we verify it's a reasonable positive number + const result = await estimateTokenCount(content, mockApiHandler) + expect(result).toBeGreaterThan(0) + + // We can also verify that longer text results in more tokens + const longerContent: Array = [ + { + type: "text", + text: "This is a longer text block with significantly more characters to encode into tokens", + }, + ] + const longerResult = await estimateTokenCount(longerContent, mockApiHandler) + expect(longerResult).toBeGreaterThan(result) + }) + + it("should estimate tokens for image blocks based on data size", async () => { + // Small image + const smallImage: Array = [ + { type: "image", source: { type: "base64", media_type: "image/jpeg", data: "small_dummy_data" } }, + ] + // Larger image with more data + const largerImage: Array = [ + { type: "image", source: { type: "base64", media_type: "image/png", data: "X".repeat(1000) } }, + ] + + // Verify the token count scales with the size of the image data + const smallImageTokens = await estimateTokenCount(smallImage, mockApiHandler) + const largerImageTokens = await estimateTokenCount(largerImage, mockApiHandler) + + // Small image should have some tokens + expect(smallImageTokens).toBeGreaterThan(0) + + // Larger image should have proportionally more tokens + expect(largerImageTokens).toBeGreaterThan(smallImageTokens) + + // Verify the larger image calculation matches our formula including the 50% fudge factor + expect(largerImageTokens).toBe(48) + }) + + it("should estimate tokens for mixed content blocks", async () => { + const content: Array = [ + { type: "text", text: "A text block with 30 characters" }, + { type: "image", source: { type: "base64", media_type: "image/jpeg", data: "dummy_data" } }, + { type: "text", text: "Another text with 24 chars" }, + ] + + // We know image tokens calculation should be consistent + const imageTokens = Math.ceil(Math.sqrt("dummy_data".length)) * 1.5 + + // With tiktoken, we can't predict exact text token counts, + // but we can verify the total is greater than just the image tokens + const result = await estimateTokenCount(content, mockApiHandler) + expect(result).toBeGreaterThan(imageTokens) + + // Also test against a version with only the image to verify text adds tokens + const imageOnlyContent: Array = [ + { type: "image", source: { type: "base64", media_type: "image/jpeg", data: "dummy_data" } }, + ] + const imageOnlyResult = await estimateTokenCount(imageOnlyContent, mockApiHandler) + expect(result).toBeGreaterThan(imageOnlyResult) + }) + + it("should handle empty text blocks", async () => { + const content: Array = [{ type: "text", text: "" }] + expect(await estimateTokenCount(content, mockApiHandler)).toBe(0) + }) + + it("should handle plain string messages", async () => { + const content = "This is a plain text message" + expect(await estimateTokenCount([{ type: "text", text: content }], mockApiHandler)).toBeGreaterThan(0) }) - - // Clean up - summarizeSpy.mockRestore() }) - it("should not use summarizeConversation when autoCondenseContext is true but context percent is below threshold", async () => { - // Reset any previous mock calls - jest.clearAllMocks() - const summarizeSpy = jest.spyOn(condenseModule, "summarizeConversation") - - const modelInfo = createModelInfo(100000, 30000) - // Set tokens to be below both the allowedTokens threshold and the percentage threshold - const contextWindow = modelInfo.contextWindow - const totalTokens = 40000 // 40% of context window - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - const result = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens, + /** + * Tests for the truncateConversationIfNeeded function + */ + describe("truncateConversationIfNeeded", () => { + const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({ contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: true, - autoCondenseContextPercent: 50, // Set threshold to 50% - our tokens are at 40% - systemPrompt: "System prompt", - taskId, - }) - - // Verify summarizeConversation was not called - expect(summarizeSpy).not.toHaveBeenCalled() - - // Verify no truncation or summarization occurred - expect(result).toEqual({ - messages: messagesWithSmallContent, - summary: "", - cost: 0, - prevContextTokens: totalTokens, + supportsPromptCache: true, + maxTokens, }) - // Clean up - summarizeSpy.mockRestore() - }) -}) - -/** - * Tests for the getMaxTokens function (private but tested through truncateConversationIfNeeded) - */ -describe("getMaxTokens", () => { - // We'll test this indirectly through truncateConversationIfNeeded - const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({ - contextWindow, - supportsPromptCache: true, // Not relevant for getMaxTokens - maxTokens, - }) - - // Reuse across tests for consistency - const messages: ApiMessage[] = [ - { role: "user", content: "First message" }, - { role: "assistant", content: "Second message" }, - { role: "user", content: "Third message" }, - { role: "assistant", content: "Fourth message" }, - { role: "user", content: "Fifth message" }, - ] - - it("should use maxTokens as buffer when specified", async () => { - const modelInfo = createModelInfo(100000, 50000) - // Max tokens = 100000 - 50000 = 50000 - - // Create messages with very small content in the last one to avoid token overflow - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - // Account for the dynamic buffer which is 10% of context window (10,000 tokens) - // Below max tokens and buffer - no truncation - const result1 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: 39999, // Well below threshold + dynamic buffer - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - expect(result1).toEqual({ - messages: messagesWithSmallContent, - summary: "", - cost: 0, - prevContextTokens: 39999, - }) + const messages: ApiMessage[] = [ + { role: "user", content: "First message" }, + { role: "assistant", content: "Second message" }, + { role: "user", content: "Third message" }, + { role: "assistant", content: "Fourth message" }, + { role: "user", content: "Fifth message" }, + ] - // Above max tokens - truncate - const result2 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: 50001, // Above threshold - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, + it("should not truncate if tokens are below max tokens threshold", async () => { + const modelInfo = createModelInfo(100000, 30000) + const dynamicBuffer = modelInfo.contextWindow * TOKEN_BUFFER_PERCENTAGE // 10000 + const totalTokens = 70000 - dynamicBuffer - 1 // Just below threshold - buffer + + // Create messages with very small content in the last one to avoid token overflow + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + const result = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens, + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + + // Check the new return type + expect(result).toEqual({ + messages: messagesWithSmallContent, + summary: "", + cost: 0, + prevContextTokens: totalTokens, + }) + }) + + it("should truncate if tokens are above max tokens threshold", async () => { + const modelInfo = createModelInfo(100000, 30000) + const totalTokens = 70001 // Above threshold + + // Create messages with very small content in the last one to avoid token overflow + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + // When truncating, always uses 0.5 fraction + // With 4 messages after the first, 0.5 fraction means remove 2 messages + const expectedMessages = [ + messagesWithSmallContent[0], + messagesWithSmallContent[3], + messagesWithSmallContent[4], + ] + + const result = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens, + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + + expect(result).toEqual({ + messages: expectedMessages, + summary: "", + cost: 0, + prevContextTokens: totalTokens, + }) + }) + + it("should work with non-prompt caching models the same as prompt caching models", async () => { + // The implementation no longer differentiates between prompt caching and non-prompt caching models + const modelInfo1 = createModelInfo(100000, 30000) + const modelInfo2 = createModelInfo(100000, 30000) + + // Create messages with very small content in the last one to avoid token overflow + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + // Test below threshold + const belowThreshold = 69999 + const result1 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: belowThreshold, + contextWindow: modelInfo1.contextWindow, + maxTokens: modelInfo1.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + + const result2 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: belowThreshold, + contextWindow: modelInfo2.contextWindow, + maxTokens: modelInfo2.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + + expect(result1.messages).toEqual(result2.messages) + expect(result1.summary).toEqual(result2.summary) + expect(result1.cost).toEqual(result2.cost) + expect(result1.prevContextTokens).toEqual(result2.prevContextTokens) + + // Test above threshold + const aboveThreshold = 70001 + const result3 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: aboveThreshold, + contextWindow: modelInfo1.contextWindow, + maxTokens: modelInfo1.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + + const result4 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: aboveThreshold, + contextWindow: modelInfo2.contextWindow, + maxTokens: modelInfo2.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + + expect(result3.messages).toEqual(result4.messages) + expect(result3.summary).toEqual(result4.summary) + expect(result3.cost).toEqual(result4.cost) + expect(result3.prevContextTokens).toEqual(result4.prevContextTokens) + }) + + it("should consider incoming content when deciding to truncate", async () => { + const modelInfo = createModelInfo(100000, 30000) + const maxTokens = 30000 + const availableTokens = modelInfo.contextWindow - maxTokens + + // Test case 1: Small content that won't push us over the threshold + const smallContent = [{ type: "text" as const, text: "Small content" }] + const smallContentTokens = await estimateTokenCount(smallContent, mockApiHandler) + const messagesWithSmallContent: ApiMessage[] = [ + ...messages.slice(0, -1), + { role: messages[messages.length - 1].role, content: smallContent }, + ] + + // Set base tokens so total is well below threshold + buffer even with small content added + const dynamicBuffer = modelInfo.contextWindow * TOKEN_BUFFER_PERCENTAGE + const baseTokensForSmall = availableTokens - smallContentTokens - dynamicBuffer - 10 + const resultWithSmall = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: baseTokensForSmall, + contextWindow: modelInfo.contextWindow, + maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(resultWithSmall).toEqual({ + messages: messagesWithSmallContent, + summary: "", + cost: 0, + prevContextTokens: baseTokensForSmall + smallContentTokens, + }) // No truncation + + // Test case 2: Large content that will push us over the threshold + const largeContent = [ + { + type: "text" as const, + text: "A very large incoming message that would consume a significant number of tokens and push us over the threshold", + }, + ] + const largeContentTokens = await estimateTokenCount(largeContent, mockApiHandler) + const messagesWithLargeContent: ApiMessage[] = [ + ...messages.slice(0, -1), + { role: messages[messages.length - 1].role, content: largeContent }, + ] + + // Set base tokens so we're just below threshold without content, but over with content + const baseTokensForLarge = availableTokens - Math.floor(largeContentTokens / 2) + const resultWithLarge = await truncateConversationIfNeeded({ + messages: messagesWithLargeContent, + totalTokens: baseTokensForLarge, + contextWindow: modelInfo.contextWindow, + maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(resultWithLarge.messages).not.toEqual(messagesWithLargeContent) // Should truncate + expect(resultWithLarge.summary).toBe("") + expect(resultWithLarge.cost).toBe(0) + expect(resultWithLarge.prevContextTokens).toBe(baseTokensForLarge + largeContentTokens) + + // Test case 3: Very large content that will definitely exceed threshold + const veryLargeContent = [{ type: "text" as const, text: "X".repeat(1000) }] + const veryLargeContentTokens = await estimateTokenCount(veryLargeContent, mockApiHandler) + const messagesWithVeryLargeContent: ApiMessage[] = [ + ...messages.slice(0, -1), + { role: messages[messages.length - 1].role, content: veryLargeContent }, + ] + + // Set base tokens so we're just below threshold without content + const baseTokensForVeryLarge = availableTokens - Math.floor(veryLargeContentTokens / 2) + const resultWithVeryLarge = await truncateConversationIfNeeded({ + messages: messagesWithVeryLargeContent, + totalTokens: baseTokensForVeryLarge, + contextWindow: modelInfo.contextWindow, + maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(resultWithVeryLarge.messages).not.toEqual(messagesWithVeryLargeContent) // Should truncate + expect(resultWithVeryLarge.summary).toBe("") + expect(resultWithVeryLarge.cost).toBe(0) + expect(resultWithVeryLarge.prevContextTokens).toBe(baseTokensForVeryLarge + veryLargeContentTokens) + }) + + it("should truncate if tokens are within TOKEN_BUFFER_PERCENTAGE of the threshold", async () => { + const modelInfo = createModelInfo(100000, 30000) + const dynamicBuffer = modelInfo.contextWindow * TOKEN_BUFFER_PERCENTAGE // 10% of 100000 = 10000 + const totalTokens = 70000 - dynamicBuffer + 1 // Just within the dynamic buffer of threshold (70000) + + // Create messages with very small content in the last one to avoid token overflow + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + // When truncating, always uses 0.5 fraction + // With 4 messages after the first, 0.5 fraction means remove 2 messages + const expectedResult = [ + messagesWithSmallContent[0], + messagesWithSmallContent[3], + messagesWithSmallContent[4], + ] + + const result = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens, + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(result).toEqual({ + messages: expectedResult, + summary: "", + cost: 0, + prevContextTokens: totalTokens, + }) + }) + + it("should use summarizeConversation when autoCondenseContext is true and tokens exceed threshold", async () => { + // Mock the summarizeConversation function + const mockSummary = "This is a summary of the conversation" + const mockCost = 0.05 + const mockSummarizeResponse: condenseModule.SummarizeResponse = { + messages: [ + { role: "user", content: "First message" }, + { role: "assistant", content: mockSummary, isSummary: true }, + { role: "user", content: "Last message" }, + ], + summary: mockSummary, + cost: mockCost, + newContextTokens: 100, + } + + const summarizeSpy = jest + .spyOn(condenseModule, "summarizeConversation") + .mockResolvedValue(mockSummarizeResponse) + + const modelInfo = createModelInfo(100000, 30000) + const totalTokens = 70001 // Above threshold + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + const result = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens, + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: true, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + + // Verify summarizeConversation was called with the right parameters + expect(summarizeSpy).toHaveBeenCalledWith( + messagesWithSmallContent, + mockApiHandler, + "System prompt", + taskId, + 70001, + true, + undefined, // customCondensingPrompt + undefined, // condensingApiHandler + ) + + // Verify the result contains the summary information + expect(result).toMatchObject({ + messages: mockSummarizeResponse.messages, + summary: mockSummary, + cost: mockCost, + prevContextTokens: totalTokens, + }) + // newContextTokens might be present, but we don't need to verify its exact value + + // Clean up + summarizeSpy.mockRestore() + }) + + it("should fall back to truncateConversation when autoCondenseContext is true but summarization fails", async () => { + // Mock the summarizeConversation function to return an error + const mockSummarizeResponse: condenseModule.SummarizeResponse = { + messages: messages, // Original messages unchanged + summary: "", // Empty summary + cost: 0.01, + error: "Summarization failed", // Error indicates failure + } + + const summarizeSpy = jest + .spyOn(condenseModule, "summarizeConversation") + .mockResolvedValue(mockSummarizeResponse) + + const modelInfo = createModelInfo(100000, 30000) + const totalTokens = 70001 // Above threshold + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + // When truncating, always uses 0.5 fraction + // With 4 messages after the first, 0.5 fraction means remove 2 messages + const expectedMessages = [ + messagesWithSmallContent[0], + messagesWithSmallContent[3], + messagesWithSmallContent[4], + ] + + const result = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens, + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: true, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + + // Verify summarizeConversation was called + expect(summarizeSpy).toHaveBeenCalled() + + // Verify it fell back to truncation + expect(result.messages).toEqual(expectedMessages) + expect(result.summary).toBe("") + expect(result.prevContextTokens).toBe(totalTokens) + // The cost might be different than expected, so we don't check it + + // Clean up + summarizeSpy.mockRestore() + }) + + it("should not call summarizeConversation when autoCondenseContext is false", async () => { + // Reset any previous mock calls + jest.clearAllMocks() + const summarizeSpy = jest.spyOn(condenseModule, "summarizeConversation") + + const modelInfo = createModelInfo(100000, 30000) + const totalTokens = 70001 // Above threshold + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + // When truncating, always uses 0.5 fraction + // With 4 messages after the first, 0.5 fraction means remove 2 messages + const expectedMessages = [ + messagesWithSmallContent[0], + messagesWithSmallContent[3], + messagesWithSmallContent[4], + ] + + const result = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens, + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 50, // This shouldn't matter since autoCondenseContext is false + systemPrompt: "System prompt", + taskId, + }) + + // Verify summarizeConversation was not called + expect(summarizeSpy).not.toHaveBeenCalled() + + // Verify it used truncation + expect(result).toEqual({ + messages: expectedMessages, + summary: "", + cost: 0, + prevContextTokens: totalTokens, + }) + + // Clean up + summarizeSpy.mockRestore() + }) + + it("should use summarizeConversation when autoCondenseContext is true and context percent exceeds threshold", async () => { + // Mock the summarizeConversation function + const mockSummary = "This is a summary of the conversation" + const mockCost = 0.05 + const mockSummarizeResponse: condenseModule.SummarizeResponse = { + messages: [ + { role: "user", content: "First message" }, + { role: "assistant", content: mockSummary, isSummary: true }, + { role: "user", content: "Last message" }, + ], + summary: mockSummary, + cost: mockCost, + newContextTokens: 100, + } + + const summarizeSpy = jest + .spyOn(condenseModule, "summarizeConversation") + .mockResolvedValue(mockSummarizeResponse) + + const modelInfo = createModelInfo(100000, 30000) + // Set tokens to be below the allowedTokens threshold but above the percentage threshold + const contextWindow = modelInfo.contextWindow + const totalTokens = 60000 // Below allowedTokens but 60% of context window + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + const result = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens, + contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: true, + autoCondenseContextPercent: 50, // Set threshold to 50% - our tokens are at 60% + systemPrompt: "System prompt", + taskId, + }) + + // Verify summarizeConversation was called with the right parameters + expect(summarizeSpy).toHaveBeenCalledWith( + messagesWithSmallContent, + mockApiHandler, + "System prompt", + taskId, + 60000, + true, + undefined, // customCondensingPrompt + undefined, // condensingApiHandler + ) + + // Verify the result contains the summary information + expect(result).toMatchObject({ + messages: mockSummarizeResponse.messages, + summary: mockSummary, + cost: mockCost, + prevContextTokens: totalTokens, + }) + + // Clean up + summarizeSpy.mockRestore() + }) + + it("should not use summarizeConversation when autoCondenseContext is true but context percent is below threshold", async () => { + // Reset any previous mock calls + jest.clearAllMocks() + const summarizeSpy = jest.spyOn(condenseModule, "summarizeConversation") + + const modelInfo = createModelInfo(100000, 30000) + // Set tokens to be below both the allowedTokens threshold and the percentage threshold + const contextWindow = modelInfo.contextWindow + const totalTokens = 40000 // 40% of context window + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + const result = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens, + contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: true, + autoCondenseContextPercent: 50, // Set threshold to 50% - our tokens are at 40% + systemPrompt: "System prompt", + taskId, + }) + + // Verify summarizeConversation was not called + expect(summarizeSpy).not.toHaveBeenCalled() + + // Verify no truncation or summarization occurred + expect(result).toEqual({ + messages: messagesWithSmallContent, + summary: "", + cost: 0, + prevContextTokens: totalTokens, + }) + + // Clean up + summarizeSpy.mockRestore() }) - expect(result2.messages).not.toEqual(messagesWithSmallContent) - expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction - expect(result2.summary).toBe("") - expect(result2.cost).toBe(0) - expect(result2.prevContextTokens).toBe(50001) }) - it("should use 20% of context window as buffer when maxTokens is undefined", async () => { - const modelInfo = createModelInfo(100000, undefined) - // Max tokens = 100000 - (100000 * 0.2) = 80000 - - // Create messages with very small content in the last one to avoid token overflow - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - // Account for the dynamic buffer which is 10% of context window (10,000 tokens) - // Below max tokens and buffer - no truncation - const result1 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: 69999, // Well below threshold + dynamic buffer - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - expect(result1).toEqual({ - messages: messagesWithSmallContent, - summary: "", - cost: 0, - prevContextTokens: 69999, - }) - - // Above max tokens - truncate - const result2 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: 80001, // Above threshold - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, + /** + * Tests for the getMaxTokens function (private but tested through truncateConversationIfNeeded) + */ + describe("getMaxTokens", () => { + // We'll test this indirectly through truncateConversationIfNeeded + const createModelInfo = (contextWindow: number, maxTokens?: number): ModelInfo => ({ + contextWindow, + supportsPromptCache: true, // Not relevant for getMaxTokens + maxTokens, }) - expect(result2.messages).not.toEqual(messagesWithSmallContent) - expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction - expect(result2.summary).toBe("") - expect(result2.cost).toBe(0) - expect(result2.prevContextTokens).toBe(80001) - }) - it("should handle small context windows appropriately", async () => { - const modelInfo = createModelInfo(50000, 10000) - // Max tokens = 50000 - 10000 = 40000 - - // Create messages with very small content in the last one to avoid token overflow - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - // Below max tokens and buffer - no truncation - const result1 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: 34999, // Well below threshold + buffer - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - expect(result1.messages).toEqual(messagesWithSmallContent) - - // Above max tokens - truncate - const result2 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: 40001, // Above threshold - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - expect(result2).not.toEqual(messagesWithSmallContent) - expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction - }) + // Reuse across tests for consistency + const messages: ApiMessage[] = [ + { role: "user", content: "First message" }, + { role: "assistant", content: "Second message" }, + { role: "user", content: "Third message" }, + { role: "assistant", content: "Fourth message" }, + { role: "user", content: "Fifth message" }, + ] - it("should handle large context windows appropriately", async () => { - const modelInfo = createModelInfo(200000, 30000) - // Max tokens = 200000 - 30000 = 170000 - - // Create messages with very small content in the last one to avoid token overflow - const messagesWithSmallContent = [...messages.slice(0, -1), { ...messages[messages.length - 1], content: "" }] - - // Account for the dynamic buffer which is 10% of context window (20,000 tokens for this test) - // Below max tokens and buffer - no truncation - const result1 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: 149999, // Well below threshold + dynamic buffer - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, - }) - expect(result1.messages).toEqual(messagesWithSmallContent) - - // Above max tokens - truncate - const result2 = await truncateConversationIfNeeded({ - messages: messagesWithSmallContent, - totalTokens: 170001, // Above threshold - contextWindow: modelInfo.contextWindow, - maxTokens: modelInfo.maxTokens, - apiHandler: mockApiHandler, - autoCondenseContext: false, - autoCondenseContextPercent: 100, - systemPrompt: "System prompt", - taskId, + it("should use maxTokens as buffer when specified", async () => { + const modelInfo = createModelInfo(100000, 50000) + // Max tokens = 100000 - 50000 = 50000 + + // Create messages with very small content in the last one to avoid token overflow + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + // Account for the dynamic buffer which is 10% of context window (10,000 tokens) + // Below max tokens and buffer - no truncation + const result1 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: 39999, // Well below threshold + dynamic buffer + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(result1).toEqual({ + messages: messagesWithSmallContent, + summary: "", + cost: 0, + prevContextTokens: 39999, + }) + + // Above max tokens - truncate + const result2 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: 50001, // Above threshold + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(result2.messages).not.toEqual(messagesWithSmallContent) + expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction + expect(result2.summary).toBe("") + expect(result2.cost).toBe(0) + expect(result2.prevContextTokens).toBe(50001) + }) + + it("should use 20% of context window as buffer when maxTokens is undefined", async () => { + const modelInfo = createModelInfo(100000, undefined) + // Max tokens = 100000 - (100000 * 0.2) = 80000 + + // Create messages with very small content in the last one to avoid token overflow + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + // Account for the dynamic buffer which is 10% of context window (10,000 tokens) + // Below max tokens and buffer - no truncation + const result1 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: 69999, // Well below threshold + dynamic buffer + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(result1).toEqual({ + messages: messagesWithSmallContent, + summary: "", + cost: 0, + prevContextTokens: 69999, + }) + + // Above max tokens - truncate + const result2 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: 80001, // Above threshold + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(result2.messages).not.toEqual(messagesWithSmallContent) + expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction + expect(result2.summary).toBe("") + expect(result2.cost).toBe(0) + expect(result2.prevContextTokens).toBe(80001) + }) + + it("should handle small context windows appropriately", async () => { + const modelInfo = createModelInfo(50000, 10000) + // Max tokens = 50000 - 10000 = 40000 + + // Create messages with very small content in the last one to avoid token overflow + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + // Below max tokens and buffer - no truncation + const result1 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: 34999, // Well below threshold + buffer + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(result1.messages).toEqual(messagesWithSmallContent) + + // Above max tokens - truncate + const result2 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: 40001, // Above threshold + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(result2).not.toEqual(messagesWithSmallContent) + expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction + }) + + it("should handle large context windows appropriately", async () => { + const modelInfo = createModelInfo(200000, 30000) + // Max tokens = 200000 - 30000 = 170000 + + // Create messages with very small content in the last one to avoid token overflow + const messagesWithSmallContent = [ + ...messages.slice(0, -1), + { ...messages[messages.length - 1], content: "" }, + ] + + // Account for the dynamic buffer which is 10% of context window (20,000 tokens for this test) + // Below max tokens and buffer - no truncation + const result1 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: 149999, // Well below threshold + dynamic buffer + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(result1.messages).toEqual(messagesWithSmallContent) + + // Above max tokens - truncate + const result2 = await truncateConversationIfNeeded({ + messages: messagesWithSmallContent, + totalTokens: 170001, // Above threshold + contextWindow: modelInfo.contextWindow, + maxTokens: modelInfo.maxTokens, + apiHandler: mockApiHandler, + autoCondenseContext: false, + autoCondenseContextPercent: 100, + systemPrompt: "System prompt", + taskId, + }) + expect(result2).not.toEqual(messagesWithSmallContent) + expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction }) - expect(result2).not.toEqual(messagesWithSmallContent) - expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction }) }) diff --git a/src/core/sliding-window/index.ts b/src/core/sliding-window/index.ts index 6a0b0f1b27..dc9eaf718d 100644 --- a/src/core/sliding-window/index.ts +++ b/src/core/sliding-window/index.ts @@ -1,8 +1,10 @@ import { Anthropic } from "@anthropic-ai/sdk" + +import { TelemetryService } from "@roo-code/telemetry" + import { ApiHandler } from "../../api" import { summarizeConversation, SummarizeResponse } from "../condense" import { ApiMessage } from "../task-persistence/apiMessages" -import { telemetryService } from "../../services/telemetry/TelemetryService" /** * Default percentage of the context window to use as a buffer when deciding when to truncate @@ -36,7 +38,7 @@ export async function estimateTokenCount( * @returns {ApiMessage[]} The truncated conversation messages. */ export function truncateConversation(messages: ApiMessage[], fracToRemove: number, taskId: string): ApiMessage[] { - telemetryService.captureSlidingWindowTruncation(taskId) + TelemetryService.instance.captureSlidingWindowTruncation(taskId) const truncatedMessages = [messages[0]] const rawMessagesToRemove = Math.floor((messages.length - 1) * fracToRemove) const messagesToRemove = rawMessagesToRemove - (rawMessagesToRemove % 2) diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 32a34098bc..ac3b1cb7d8 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -8,18 +8,21 @@ import delay from "delay" import pWaitFor from "p-wait-for" import { serializeError } from "serialize-error" -import type { - ProviderSettings, - TokenUsage, - ToolUsage, - ToolName, - ContextCondense, - ClineAsk, - ClineMessage, - ClineSay, - ToolProgressStatus, - HistoryItem, +import { + type ProviderSettings, + type TokenUsage, + type ToolUsage, + type ToolName, + type ContextCondense, + type ClineAsk, + type ClineMessage, + type ClineSay, + type ToolProgressStatus, + type HistoryItem, + TelemetryEventName, } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" +import { CloudService } from "@roo-code/cloud" // api import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "../../api" @@ -41,7 +44,6 @@ import { UrlContentFetcher } from "../../services/browser/UrlContentFetcher" import { BrowserSession } from "../../services/browser/BrowserSession" import { McpHub } from "../../services/mcp/McpHub" import { McpServerManager } from "../../services/mcp/McpServerManager" -import { telemetryService } from "../../services/telemetry/TelemetryService" import { RepoPerTaskCheckpointService } from "../../services/checkpoints" // integrations @@ -243,9 +245,9 @@ export class Task extends EventEmitter { this.taskNumber = taskNumber if (historyItem) { - telemetryService.captureTaskRestarted(this.taskId) + TelemetryService.instance.captureTaskRestarted(this.taskId) } else { - telemetryService.captureTaskCreated(this.taskId) + TelemetryService.instance.captureTaskCreated(this.taskId) } this.diffStrategy = new MultiSearchReplaceDiffStrategy(this.fuzzyMatchThreshold) @@ -321,6 +323,15 @@ export class Task extends EventEmitter { await this.providerRef.deref()?.postStateToWebview() this.emit("message", { action: "created", message }) await this.saveClineMessages() + + const shouldCaptureMessage = message.partial !== true && CloudService.isEnabled() + + if (shouldCaptureMessage) { + CloudService.instance.captureEvent({ + event: TelemetryEventName.TASK_MESSAGE, + properties: { taskId: this.taskId, message }, + }) + } } public async overwriteClineMessages(newMessages: ClineMessage[]) { @@ -331,6 +342,15 @@ export class Task extends EventEmitter { private async updateClineMessage(partialMessage: ClineMessage) { await this.providerRef.deref()?.postMessageToWebview({ type: "partialMessage", partialMessage }) this.emit("message", { action: "updated", message: partialMessage }) + + const shouldCaptureMessage = partialMessage.partial !== true && CloudService.isEnabled() + + if (shouldCaptureMessage) { + CloudService.instance.captureEvent({ + event: TelemetryEventName.TASK_MESSAGE, + properties: { taskId: this.taskId, message: partialMessage }, + }) + } } private async saveClineMessages() { @@ -1066,7 +1086,7 @@ export class Task extends EventEmitter { await this.say("user_feedback", text, images) // Track consecutive mistake errors in telemetry. - telemetryService.captureConsecutiveMistakeError(this.taskId) + TelemetryService.instance.captureConsecutiveMistakeError(this.taskId) } this.consecutiveMistakeCount = 0 @@ -1125,7 +1145,7 @@ export class Task extends EventEmitter { const finalUserContent = [...parsedUserContent, { type: "text" as const, text: environmentDetails }] await this.addToApiConversationHistory({ role: "user", content: finalUserContent }) - telemetryService.captureConversationMessage(this.taskId, "user") + TelemetryService.instance.captureConversationMessage(this.taskId, "user") // Since we sent off a placeholder api_req_started message to update the // webview while waiting to actually start the API request (to load @@ -1345,7 +1365,7 @@ export class Task extends EventEmitter { cacheReadTokens > 0 || typeof totalCost !== "undefined" ) { - telemetryService.captureLlmCompletion(this.taskId, { + TelemetryService.instance.captureLlmCompletion(this.taskId, { inputTokens, outputTokens, cacheWriteTokens, @@ -1399,7 +1419,7 @@ export class Task extends EventEmitter { content: [{ type: "text", text: assistantMessage }], }) - telemetryService.captureConversationMessage(this.taskId, "assistant") + TelemetryService.instance.captureConversationMessage(this.taskId, "assistant") // NOTE: This comment is here for future reference - this was a // workaround for `userMessageContent` not getting set to true. diff --git a/src/core/task/__tests__/Task.test.ts b/src/core/task/__tests__/Task.test.ts index 79641b56f1..8ed57ffcb3 100644 --- a/src/core/task/__tests__/Task.test.ts +++ b/src/core/task/__tests__/Task.test.ts @@ -1,4 +1,4 @@ -// npx jest src/core/task/__tests__/Task.test.ts +// npx jest core/task/__tests__/Task.test.ts import * as os from "os" import * as path from "path" @@ -7,6 +7,7 @@ import * as vscode from "vscode" import { Anthropic } from "@anthropic-ai/sdk" import type { GlobalState, ProviderSettings, ModelInfo } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" import { Task } from "../Task" import { ClineProvider } from "../../webview/ClineProvider" @@ -126,10 +127,9 @@ jest.mock("../../environment/getEnvironmentDetails", () => ({ getEnvironmentDetails: jest.fn().mockResolvedValue(""), })) -// Mock RooIgnoreController jest.mock("../../ignore/RooIgnoreController") -// Mock storagePathManager to prevent dynamic import issues +// Mock storagePathManager to prevent dynamic import issues. jest.mock("../../../utils/storage", () => ({ getTaskDirectoryPath: jest .fn() @@ -139,14 +139,12 @@ jest.mock("../../../utils/storage", () => ({ .mockImplementation((globalStoragePath) => Promise.resolve(`${globalStoragePath}/settings`)), })) -// Mock fileExistsAtPath jest.mock("../../../utils/fs", () => ({ fileExistsAtPath: jest.fn().mockImplementation((filePath) => { return filePath.includes("ui_messages.json") || filePath.includes("api_conversation_history.json") }), })) -// Mock fs/promises const mockMessages = [ { ts: Date.now(), @@ -163,6 +161,10 @@ describe("Cline", () => { let mockExtensionContext: vscode.ExtensionContext beforeEach(() => { + if (!TelemetryService.hasInstance()) { + TelemetryService.createInstance([]) + } + // Setup mock extension context const storageUri = { fsPath: path.join(os.tmpdir(), "test-storage"), diff --git a/src/core/tools/applyDiffTool.ts b/src/core/tools/applyDiffTool.ts index 19d17c81c4..2c637bc219 100644 --- a/src/core/tools/applyDiffTool.ts +++ b/src/core/tools/applyDiffTool.ts @@ -1,6 +1,8 @@ import path from "path" import fs from "fs/promises" +import { TelemetryService } from "@roo-code/telemetry" + import { ClineSayTool } from "../../shared/ExtensionMessage" import { getReadablePath } from "../../utils/path" import { Task } from "../task/Task" @@ -9,7 +11,6 @@ import { formatResponse } from "../prompts/responses" import { fileExistsAtPath } from "../../utils/fs" import { addLineNumbers } from "../../integrations/misc/extract-text" import { RecordSource } from "../context-tracking/FileContextTrackerTypes" -import { telemetryService } from "../../services/telemetry/TelemetryService" import { unescapeHtmlEntities } from "../../utils/text-normalization" export async function applyDiffTool( @@ -103,7 +104,7 @@ export async function applyDiffTool( const currentCount = (cline.consecutiveMistakeCountForApplyDiff.get(relPath) || 0) + 1 cline.consecutiveMistakeCountForApplyDiff.set(relPath, currentCount) let formattedError = "" - telemetryService.captureDiffApplicationError(cline.taskId, currentCount) + TelemetryService.instance.captureDiffApplicationError(cline.taskId, currentCount) if (diffResult.failParts && diffResult.failParts.length > 0) { for (const failPart of diffResult.failParts) { diff --git a/src/core/tools/attemptCompletionTool.ts b/src/core/tools/attemptCompletionTool.ts index a5e469c77f..08859c98c9 100644 --- a/src/core/tools/attemptCompletionTool.ts +++ b/src/core/tools/attemptCompletionTool.ts @@ -1,5 +1,7 @@ import Anthropic from "@anthropic-ai/sdk" +import { TelemetryService } from "@roo-code/telemetry" + import { Task } from "../task/Task" import { ToolResponse, @@ -12,7 +14,6 @@ import { AskFinishSubTaskApproval, } from "../../shared/tools" import { formatResponse } from "../prompts/responses" -import { telemetryService } from "../../services/telemetry/TelemetryService" import { type ExecuteCommandOptions, executeCommand } from "./executeCommandTool" export async function attemptCompletionTool( @@ -45,7 +46,7 @@ export async function attemptCompletionTool( // we have command string, which means we have the result as well, so finish it (doesnt have to exist yet) await cline.say("completion_result", removeClosingTag("result", result), undefined, false) - telemetryService.captureTaskCompleted(cline.taskId) + TelemetryService.instance.captureTaskCompleted(cline.taskId) cline.emit("taskCompleted", cline.taskId, cline.getTokenUsage(), cline.toolUsage) await cline.ask("command", removeClosingTag("command", command), block.partial).catch(() => {}) @@ -71,7 +72,7 @@ export async function attemptCompletionTool( if (lastMessage && lastMessage.ask !== "command") { // Haven't sent a command message yet so first send completion_result then command. await cline.say("completion_result", result, undefined, false) - telemetryService.captureTaskCompleted(cline.taskId) + TelemetryService.instance.captureTaskCompleted(cline.taskId) cline.emit("taskCompleted", cline.taskId, cline.getTokenUsage(), cline.toolUsage) } @@ -96,7 +97,7 @@ export async function attemptCompletionTool( commandResult = execCommandResult } else { await cline.say("completion_result", result, undefined, false) - telemetryService.captureTaskCompleted(cline.taskId) + TelemetryService.instance.captureTaskCompleted(cline.taskId) cline.emit("taskCompleted", cline.taskId, cline.getTokenUsage(), cline.toolUsage) } diff --git a/src/core/tools/executeCommandTool.ts b/src/core/tools/executeCommandTool.ts index f20b283082..e38d3c74f6 100644 --- a/src/core/tools/executeCommandTool.ts +++ b/src/core/tools/executeCommandTool.ts @@ -4,13 +4,13 @@ import * as path from "path" import delay from "delay" import { CommandExecutionStatus } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" import { Task } from "../task/Task" import { ToolUse, AskApproval, HandleError, PushToolResult, RemoveClosingTag, ToolResponse } from "../../shared/tools" import { formatResponse } from "../prompts/responses" import { unescapeHtmlEntities } from "../../utils/text-normalization" -import { telemetryService } from "../../services/telemetry/TelemetryService" import { ExitCodeDetails, RooTerminalCallbacks, RooTerminalProcess } from "../../integrations/terminal/types" import { TerminalRegistry } from "../../integrations/terminal/TerminalRegistry" import { Terminal } from "../../integrations/terminal/Terminal" @@ -192,7 +192,7 @@ export async function executeCommand( if (terminalProvider === "vscode") { callbacks.onNoShellIntegration = async (error: string) => { - telemetryService.captureShellIntegrationError(cline.taskId) + TelemetryService.instance.captureShellIntegrationError(cline.taskId) shellIntegrationError = error } } diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index a8f0473dd5..5f9f650049 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -9,19 +9,23 @@ import axios from "axios" import pWaitFor from "p-wait-for" import * as vscode from "vscode" -import type { - GlobalState, - ProviderName, - ProviderSettings, - RooCodeSettings, - ProviderSettingsEntry, - TelemetryProperties, - CodeActionId, - CodeActionName, - TerminalActionId, - TerminalActionPromptType, - HistoryItem, +import { + type GlobalState, + type ProviderName, + type ProviderSettings, + type RooCodeSettings, + type ProviderSettingsEntry, + type TelemetryProperties, + type TelemetryPropertiesProvider, + type CodeActionId, + type CodeActionName, + type TerminalActionId, + type TerminalActionPromptType, + type HistoryItem, + ORGANIZATION_ALLOW_ALL, } from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" +import { CloudService } from "@roo-code/cloud" import { t } from "../../i18n" import { setPanel } from "../../activate/registerCommands" @@ -53,11 +57,11 @@ import { Task, TaskOptions } from "../task/Task" import { getNonce } from "./getNonce" import { getUri } from "./getUri" import { getSystemPromptFilePath } from "../prompts/sections/custom-system-prompt" -import { TelemetryPropertiesProvider, telemetryService } from "../../services/telemetry" import { getWorkspacePath } from "../../utils/path" import { webviewMessageHandler } from "./webviewMessageHandler" import { WebviewMessage } from "../../shared/WebviewMessage" import { EMBEDDING_MODEL_PROFILES } from "../../shared/embeddingModels" +import { ProfileValidator } from "../../shared/ProfileValidator" /** * https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts @@ -68,6 +72,12 @@ export type ClineProviderEvents = { clineCreated: [cline: Task] } +class OrganizationAllowListViolationError extends Error { + constructor(message: string) { + super(message) + } +} + export class ClineProvider extends EventEmitter implements vscode.WebviewViewProvider, TelemetryPropertiesProvider @@ -114,7 +124,7 @@ export class ClineProvider // Register this provider with the telemetry service to enable it to add // properties like mode and provider. - telemetryService.setProvider(this) + TelemetryService.instance.setProvider(this) this._workspaceTracker = new WorkspaceTracker(this) @@ -288,7 +298,7 @@ export class ClineProvider params: Record, ): Promise { // Capture telemetry for code action usage - telemetryService.captureCodeActionUsed(promptType) + TelemetryService.instance.captureCodeActionUsed(promptType) const visibleProvider = await ClineProvider.getInstance() @@ -314,7 +324,7 @@ export class ClineProvider promptType: TerminalActionPromptType, params: Record, ): Promise { - telemetryService.captureCodeActionUsed(promptType) + TelemetryService.instance.captureCodeActionUsed(promptType) const visibleProvider = await ClineProvider.getInstance() @@ -330,7 +340,15 @@ export class ClineProvider return } - await visibleProvider.initClineWithTask(prompt) + try { + await visibleProvider.initClineWithTask(prompt) + } catch (error) { + if (error instanceof OrganizationAllowListViolationError) { + // Errors from terminal commands seem to get swallowed / ignored. + vscode.window.showErrorMessage(error.message) + } + throw error + } } async resolveWebviewView(webviewView: vscode.WebviewView | vscode.WebviewPanel) { @@ -494,12 +512,17 @@ export class ClineProvider ) { const { apiConfiguration, + organizationAllowList, diffEnabled: enableDiff, enableCheckpoints, fuzzyMatchThreshold, experiments, } = await this.getState() + if (!ProfileValidator.isProfileAllowed(apiConfiguration, organizationAllowList)) { + throw new OrganizationAllowListViolationError(t("common:errors.violated_organization_allowlist")) + } + const cline = new Task({ provider: this, apiConfiguration, @@ -628,7 +651,7 @@ export class ClineProvider "default-src 'none'", `font-src ${webview.cspSource}`, `style-src ${webview.cspSource} 'unsafe-inline' https://* http://${localServerUrl} http://0.0.0.0:${localPort}`, - `img-src ${webview.cspSource} data:`, + `img-src ${webview.cspSource} https://storage.googleapis.com https://img.clerk.com data:`, `media-src ${webview.cspSource}`, `script-src 'unsafe-eval' ${webview.cspSource} https://* https://*.posthog.com http://${localServerUrl} http://0.0.0.0:${localPort} 'nonce-${nonce}'`, `connect-src https://* https://*.posthog.com ws://${localServerUrl} ws://0.0.0.0:${localPort} http://${localServerUrl} http://0.0.0.0:${localPort}`, @@ -713,7 +736,7 @@ export class ClineProvider - +