diff --git a/packages/types/src/codebase-index.ts b/packages/types/src/codebase-index.ts new file mode 100644 index 0000000000..c9443e2fa7 --- /dev/null +++ b/packages/types/src/codebase-index.ts @@ -0,0 +1,37 @@ +import { z } from "zod" + +/** + * CodebaseIndexConfig + */ + +export const codebaseIndexConfigSchema = z.object({ + codebaseIndexEnabled: z.boolean().optional(), + codebaseIndexQdrantUrl: z.string().optional(), + codebaseIndexEmbedderProvider: z.enum(["openai", "ollama"]).optional(), + codebaseIndexEmbedderBaseUrl: z.string().optional(), + codebaseIndexEmbedderModelId: z.string().optional(), +}) + +export type CodebaseIndexConfig = z.infer + +/** + * CodebaseIndexModels + */ + +export const codebaseIndexModelsSchema = z.object({ + openai: z.record(z.string(), z.object({ dimension: z.number() })).optional(), + ollama: z.record(z.string(), z.object({ dimension: z.number() })).optional(), +}) + +export type CodebaseIndexModels = z.infer + +/** + * CdebaseIndexProvider + */ + +export const codebaseIndexProviderSchema = z.object({ + codeIndexOpenAiKey: z.string().optional(), + codeIndexQdrantApiKey: z.string().optional(), +}) + +export type CodebaseIndexProvider = z.infer diff --git a/packages/types/src/experiment.ts b/packages/types/src/experiment.ts new file mode 100644 index 0000000000..6b43327207 --- /dev/null +++ b/packages/types/src/experiment.ts @@ -0,0 +1,26 @@ +import { z } from "zod" + +import type { Keys, Equals, AssertEqual } from "./type-fu.js" + +/** + * ExperimentId + */ + +export const experimentIds = ["autoCondenseContext", "powerSteering"] as const + +export const experimentIdsSchema = z.enum(experimentIds) + +export type ExperimentId = z.infer + +/** + * Experiments + */ + +export const experimentsSchema = z.object({ + autoCondenseContext: z.boolean(), + powerSteering: z.boolean(), +}) + +export type Experiments = z.infer + +type _AssertExperiments = AssertEqual>> diff --git a/packages/types/src/global-settings.ts b/packages/types/src/global-settings.ts new file mode 100644 index 0000000000..d69a77fa53 --- /dev/null +++ b/packages/types/src/global-settings.ts @@ -0,0 +1,269 @@ +import { z } from "zod" + +import type { Keys } from "./type-fu.js" +import { + type ProviderSettings, + PROVIDER_SETTINGS_KEYS, + providerSettingsEntrySchema, + providerSettingsSchema, +} from "./provider-settings.js" +import { historyItemSchema } from "./history.js" +import { codebaseIndexModelsSchema, codebaseIndexConfigSchema } from "./codebase-index.js" +import { experimentsSchema } from "./experiment.js" +import { telemetrySettingsSchema } from "./telemetry.js" +import { modeConfigSchema } from "./mode.js" +import { customModePromptsSchema, customSupportPromptsSchema } from "./mode.js" +import { languagesSchema } from "./vscode.js" + +/** + * GlobalSettings + */ + +export const globalSettingsSchema = z.object({ + currentApiConfigName: z.string().optional(), + listApiConfigMeta: z.array(providerSettingsEntrySchema).optional(), + pinnedApiConfigs: z.record(z.string(), z.boolean()).optional(), + + lastShownAnnouncementId: z.string().optional(), + customInstructions: z.string().optional(), + taskHistory: z.array(historyItemSchema).optional(), + + condensingApiConfigId: z.string().optional(), + customCondensingPrompt: z.string().optional(), + + autoApprovalEnabled: z.boolean().optional(), + alwaysAllowReadOnly: z.boolean().optional(), + alwaysAllowReadOnlyOutsideWorkspace: z.boolean().optional(), + codebaseIndexModels: codebaseIndexModelsSchema.optional(), + codebaseIndexConfig: codebaseIndexConfigSchema.optional(), + alwaysAllowWrite: z.boolean().optional(), + alwaysAllowWriteOutsideWorkspace: z.boolean().optional(), + writeDelayMs: z.number().optional(), + alwaysAllowBrowser: z.boolean().optional(), + alwaysApproveResubmit: z.boolean().optional(), + requestDelaySeconds: z.number().optional(), + alwaysAllowMcp: z.boolean().optional(), + alwaysAllowModeSwitch: z.boolean().optional(), + alwaysAllowSubtasks: z.boolean().optional(), + alwaysAllowExecute: z.boolean().optional(), + allowedCommands: z.array(z.string()).optional(), + allowedMaxRequests: z.number().nullish(), + autoCondenseContextPercent: z.number().optional(), + + browserToolEnabled: z.boolean().optional(), + browserViewportSize: z.string().optional(), + screenshotQuality: z.number().optional(), + remoteBrowserEnabled: z.boolean().optional(), + remoteBrowserHost: z.string().optional(), + cachedChromeHostUrl: z.string().optional(), + + enableCheckpoints: z.boolean().optional(), + + ttsEnabled: z.boolean().optional(), + ttsSpeed: z.number().optional(), + soundEnabled: z.boolean().optional(), + soundVolume: z.number().optional(), + + maxOpenTabsContext: z.number().optional(), + maxWorkspaceFiles: z.number().optional(), + showRooIgnoredFiles: z.boolean().optional(), + maxReadFileLine: z.number().optional(), + + terminalOutputLineLimit: z.number().optional(), + terminalShellIntegrationTimeout: z.number().optional(), + terminalShellIntegrationDisabled: z.boolean().optional(), + terminalCommandDelay: z.number().optional(), + terminalPowershellCounter: z.boolean().optional(), + terminalZshClearEolMark: z.boolean().optional(), + terminalZshOhMy: z.boolean().optional(), + terminalZshP10k: z.boolean().optional(), + terminalZdotdir: z.boolean().optional(), + terminalCompressProgressBar: z.boolean().optional(), + + rateLimitSeconds: z.number().optional(), + diffEnabled: z.boolean().optional(), + fuzzyMatchThreshold: z.number().optional(), + experiments: experimentsSchema.optional(), + + language: languagesSchema.optional(), + + telemetrySetting: telemetrySettingsSchema.optional(), + + mcpEnabled: z.boolean().optional(), + enableMcpServerCreation: z.boolean().optional(), + + mode: z.string().optional(), + modeApiConfigs: z.record(z.string(), z.string()).optional(), + customModes: z.array(modeConfigSchema).optional(), + customModePrompts: customModePromptsSchema.optional(), + customSupportPrompts: customSupportPromptsSchema.optional(), + enhancementApiConfigId: z.string().optional(), + historyPreviewCollapsed: z.boolean().optional(), +}) + +export type GlobalSettings = z.infer + +type GlobalSettingsRecord = Record, undefined> + +const globalSettingsRecord: GlobalSettingsRecord = { + codebaseIndexModels: undefined, + codebaseIndexConfig: undefined, + currentApiConfigName: undefined, + listApiConfigMeta: undefined, + pinnedApiConfigs: undefined, + + lastShownAnnouncementId: undefined, + customInstructions: undefined, + taskHistory: undefined, + + condensingApiConfigId: undefined, + customCondensingPrompt: undefined, + + autoApprovalEnabled: undefined, + alwaysAllowReadOnly: undefined, + alwaysAllowReadOnlyOutsideWorkspace: undefined, + alwaysAllowWrite: undefined, + alwaysAllowWriteOutsideWorkspace: undefined, + writeDelayMs: undefined, + alwaysAllowBrowser: undefined, + alwaysApproveResubmit: undefined, + requestDelaySeconds: undefined, + alwaysAllowMcp: undefined, + alwaysAllowModeSwitch: undefined, + alwaysAllowSubtasks: undefined, + alwaysAllowExecute: undefined, + allowedCommands: undefined, + allowedMaxRequests: undefined, + autoCondenseContextPercent: undefined, + + browserToolEnabled: undefined, + browserViewportSize: undefined, + screenshotQuality: undefined, + remoteBrowserEnabled: undefined, + remoteBrowserHost: undefined, + + enableCheckpoints: undefined, + + ttsEnabled: undefined, + ttsSpeed: undefined, + soundEnabled: undefined, + soundVolume: undefined, + + maxOpenTabsContext: undefined, + maxWorkspaceFiles: undefined, + showRooIgnoredFiles: undefined, + maxReadFileLine: undefined, + + terminalOutputLineLimit: undefined, + terminalShellIntegrationTimeout: undefined, + terminalShellIntegrationDisabled: undefined, + terminalCommandDelay: undefined, + terminalPowershellCounter: undefined, + terminalZshClearEolMark: undefined, + terminalZshOhMy: undefined, + terminalZshP10k: undefined, + terminalZdotdir: undefined, + terminalCompressProgressBar: undefined, + + rateLimitSeconds: undefined, + diffEnabled: undefined, + fuzzyMatchThreshold: undefined, + experiments: undefined, + + language: undefined, + + telemetrySetting: undefined, + + mcpEnabled: undefined, + enableMcpServerCreation: undefined, + + mode: undefined, + modeApiConfigs: undefined, + customModes: undefined, + customModePrompts: undefined, + customSupportPrompts: undefined, + enhancementApiConfigId: undefined, + cachedChromeHostUrl: undefined, + historyPreviewCollapsed: undefined, +} + +export const GLOBAL_SETTINGS_KEYS = Object.keys(globalSettingsRecord) as Keys[] + +/** + * RooCodeSettings + */ + +export const rooCodeSettingsSchema = providerSettingsSchema.merge(globalSettingsSchema) + +export type RooCodeSettings = GlobalSettings & ProviderSettings + +/** + * SecretState + */ + +export type SecretState = Pick< + ProviderSettings, + | "apiKey" + | "glamaApiKey" + | "openRouterApiKey" + | "awsAccessKey" + | "awsSecretKey" + | "awsSessionToken" + | "openAiApiKey" + | "geminiApiKey" + | "openAiNativeApiKey" + | "deepSeekApiKey" + | "mistralApiKey" + | "unboundApiKey" + | "requestyApiKey" + | "xaiApiKey" + | "groqApiKey" + | "chutesApiKey" + | "litellmApiKey" + | "codeIndexOpenAiKey" + | "codeIndexQdrantApiKey" +> + +export type CodeIndexSecrets = "codeIndexOpenAiKey" | "codeIndexQdrantApiKey" + +type SecretStateRecord = Record, undefined> + +const secretStateRecord: SecretStateRecord = { + apiKey: undefined, + glamaApiKey: undefined, + openRouterApiKey: undefined, + awsAccessKey: undefined, + awsSecretKey: undefined, + awsSessionToken: undefined, + openAiApiKey: undefined, + geminiApiKey: undefined, + openAiNativeApiKey: undefined, + deepSeekApiKey: undefined, + mistralApiKey: undefined, + unboundApiKey: undefined, + requestyApiKey: undefined, + xaiApiKey: undefined, + groqApiKey: undefined, + chutesApiKey: undefined, + litellmApiKey: undefined, + codeIndexOpenAiKey: undefined, + codeIndexQdrantApiKey: undefined, +} + +export const SECRET_STATE_KEYS = Object.keys(secretStateRecord) as Keys[] + +export const isSecretStateKey = (key: string): key is Keys => + SECRET_STATE_KEYS.includes(key as Keys) + +/** + * GlobalState + */ + +export type GlobalState = Omit> + +export const GLOBAL_STATE_KEYS = [...GLOBAL_SETTINGS_KEYS, ...PROVIDER_SETTINGS_KEYS].filter( + (key: Keys) => !SECRET_STATE_KEYS.includes(key as Keys), +) as Keys[] + +export const isGlobalStateKey = (key: string): key is Keys => + GLOBAL_STATE_KEYS.includes(key as Keys) diff --git a/packages/types/src/history.ts b/packages/types/src/history.ts new file mode 100644 index 0000000000..8c75024879 --- /dev/null +++ b/packages/types/src/history.ts @@ -0,0 +1,21 @@ +import { z } from "zod" + +/** + * HistoryItem + */ + +export const historyItemSchema = z.object({ + id: z.string(), + number: z.number(), + ts: z.number(), + task: z.string(), + tokensIn: z.number(), + tokensOut: z.number(), + cacheWrites: z.number().optional(), + cacheReads: z.number().optional(), + totalCost: z.number(), + size: z.number().optional(), + workspace: z.string().optional(), +}) + +export type HistoryItem = z.infer diff --git a/packages/types/src/index.ts b/packages/types/src/index.ts index b3656fb63a..8b49dc1d62 100644 --- a/packages/types/src/index.ts +++ b/packages/types/src/index.ts @@ -1,2 +1,15 @@ -export * from "./types.js" export * from "./api.js" +export * from "./codebase-index.js" +export * from "./experiment.js" +export * from "./global-settings.js" +export * from "./history.js" +export * from "./ipc.js" +export * from "./message.js" +export * from "./mode.js" +export * from "./model.js" +export * from "./provider-settings.js" +export * from "./telemetry.js" +export * from "./terminal.js" +export * from "./tool.js" +export * from "./type-fu.js" +export * from "./vscode.js" diff --git a/packages/types/src/ipc.ts b/packages/types/src/ipc.ts new file mode 100644 index 0000000000..aa35e194a9 --- /dev/null +++ b/packages/types/src/ipc.ts @@ -0,0 +1,183 @@ +import { z } from "zod" + +import { clineMessageSchema, tokenUsageSchema } from "./message.js" +import { toolNamesSchema, toolUsageSchema } from "./tool.js" +import { rooCodeSettingsSchema } from "./global-settings.js" + +/** + * RooCodeEvent + */ + +export enum RooCodeEventName { + Message = "message", + TaskCreated = "taskCreated", + TaskStarted = "taskStarted", + TaskModeSwitched = "taskModeSwitched", + TaskPaused = "taskPaused", + TaskUnpaused = "taskUnpaused", + TaskAskResponded = "taskAskResponded", + TaskAborted = "taskAborted", + TaskSpawned = "taskSpawned", + TaskCompleted = "taskCompleted", + TaskTokenUsageUpdated = "taskTokenUsageUpdated", + TaskToolFailed = "taskToolFailed", +} + +export const rooCodeEventsSchema = z.object({ + [RooCodeEventName.Message]: z.tuple([ + z.object({ + taskId: z.string(), + action: z.union([z.literal("created"), z.literal("updated")]), + message: clineMessageSchema, + }), + ]), + [RooCodeEventName.TaskCreated]: z.tuple([z.string()]), + [RooCodeEventName.TaskStarted]: z.tuple([z.string()]), + [RooCodeEventName.TaskModeSwitched]: z.tuple([z.string(), z.string()]), + [RooCodeEventName.TaskPaused]: z.tuple([z.string()]), + [RooCodeEventName.TaskUnpaused]: z.tuple([z.string()]), + [RooCodeEventName.TaskAskResponded]: z.tuple([z.string()]), + [RooCodeEventName.TaskAborted]: z.tuple([z.string()]), + [RooCodeEventName.TaskSpawned]: z.tuple([z.string(), z.string()]), + [RooCodeEventName.TaskCompleted]: z.tuple([z.string(), tokenUsageSchema, toolUsageSchema]), + [RooCodeEventName.TaskTokenUsageUpdated]: z.tuple([z.string(), tokenUsageSchema]), + [RooCodeEventName.TaskToolFailed]: z.tuple([z.string(), toolNamesSchema, z.string()]), +}) + +export type RooCodeEvents = z.infer + +/** + * Ack + */ + +export const ackSchema = z.object({ + clientId: z.string(), + pid: z.number(), + ppid: z.number(), +}) + +export type Ack = z.infer + +/** + * TaskCommand + */ + +export enum TaskCommandName { + StartNewTask = "StartNewTask", + CancelTask = "CancelTask", + CloseTask = "CloseTask", +} + +export const taskCommandSchema = z.discriminatedUnion("commandName", [ + z.object({ + commandName: z.literal(TaskCommandName.StartNewTask), + data: z.object({ + configuration: rooCodeSettingsSchema, + text: z.string(), + images: z.array(z.string()).optional(), + newTab: z.boolean().optional(), + }), + }), + z.object({ + commandName: z.literal(TaskCommandName.CancelTask), + data: z.string(), + }), + z.object({ + commandName: z.literal(TaskCommandName.CloseTask), + data: z.string(), + }), +]) + +export type TaskCommand = z.infer + +/** + * TaskEvent + */ + +export const taskEventSchema = z.discriminatedUnion("eventName", [ + z.object({ + eventName: z.literal(RooCodeEventName.Message), + payload: rooCodeEventsSchema.shape[RooCodeEventName.Message], + }), + z.object({ + eventName: z.literal(RooCodeEventName.TaskCreated), + payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskCreated], + }), + z.object({ + eventName: z.literal(RooCodeEventName.TaskStarted), + payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskStarted], + }), + z.object({ + eventName: z.literal(RooCodeEventName.TaskModeSwitched), + payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskModeSwitched], + }), + z.object({ + eventName: z.literal(RooCodeEventName.TaskPaused), + payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskPaused], + }), + z.object({ + eventName: z.literal(RooCodeEventName.TaskUnpaused), + payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskUnpaused], + }), + z.object({ + eventName: z.literal(RooCodeEventName.TaskAskResponded), + payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskAskResponded], + }), + z.object({ + eventName: z.literal(RooCodeEventName.TaskAborted), + payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskAborted], + }), + z.object({ + eventName: z.literal(RooCodeEventName.TaskSpawned), + payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskSpawned], + }), + z.object({ + eventName: z.literal(RooCodeEventName.TaskCompleted), + payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskCompleted], + }), + z.object({ + eventName: z.literal(RooCodeEventName.TaskTokenUsageUpdated), + payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskTokenUsageUpdated], + }), +]) + +export type TaskEvent = z.infer + +/** + * IpcMessage + */ + +export enum IpcMessageType { + Connect = "Connect", + Disconnect = "Disconnect", + Ack = "Ack", + TaskCommand = "TaskCommand", + TaskEvent = "TaskEvent", +} + +export enum IpcOrigin { + Client = "client", + Server = "server", +} + +export const ipcMessageSchema = z.discriminatedUnion("type", [ + z.object({ + type: z.literal(IpcMessageType.Ack), + origin: z.literal(IpcOrigin.Server), + data: ackSchema, + }), + z.object({ + type: z.literal(IpcMessageType.TaskCommand), + origin: z.literal(IpcOrigin.Client), + clientId: z.string(), + data: taskCommandSchema, + }), + z.object({ + type: z.literal(IpcMessageType.TaskEvent), + origin: z.literal(IpcOrigin.Server), + relayClientId: z.string().optional(), + data: taskEventSchema, + }), +]) + +export type IpcMessage = z.infer diff --git a/packages/types/src/message.ts b/packages/types/src/message.ts new file mode 100644 index 0000000000..e870e8d707 --- /dev/null +++ b/packages/types/src/message.ts @@ -0,0 +1,118 @@ +import { z } from "zod" + +/** + * ClineAsk + */ + +export const clineAsks = [ + "followup", + "command", + "command_output", + "completion_result", + "tool", + "api_req_failed", + "resume_task", + "resume_completed_task", + "mistake_limit_reached", + "browser_action_launch", + "use_mcp_server", + "auto_approval_max_req_reached", +] as const + +export const clineAskSchema = z.enum(clineAsks) + +export type ClineAsk = z.infer + +/** + * ClineSay + */ + +export const clineSays = [ + "error", + "api_req_started", + "api_req_finished", + "api_req_retried", + "api_req_retry_delayed", + "api_req_deleted", + "text", + "reasoning", + "completion_result", + "user_feedback", + "user_feedback_diff", + "command_output", + "shell_integration_warning", + "browser_action", + "browser_action_result", + "mcp_server_request_started", + "mcp_server_response", + "subtask_result", + "checkpoint_saved", + "rooignore_error", + "diff_error", + "condense_context", + "codebase_search_result", +] as const + +export const clineSaySchema = z.enum(clineSays) + +export type ClineSay = z.infer + +/** + * ToolProgressStatus + */ + +export const toolProgressStatusSchema = z.object({ + icon: z.string().optional(), + text: z.string().optional(), +}) + +export type ToolProgressStatus = z.infer + +/** + * ContextCondense + */ + +export const contextCondenseSchema = z.object({ + cost: z.number(), + prevContextTokens: z.number(), + newContextTokens: z.number(), + summary: z.string(), +}) + +export type ContextCondense = z.infer + +/** + * ClineMessage + */ + +export const clineMessageSchema = z.object({ + ts: z.number(), + type: z.union([z.literal("ask"), z.literal("say")]), + ask: clineAskSchema.optional(), + say: clineSaySchema.optional(), + text: z.string().optional(), + images: z.array(z.string()).optional(), + partial: z.boolean().optional(), + reasoning: z.string().optional(), + conversationHistoryIndex: z.number().optional(), + checkpoint: z.record(z.string(), z.unknown()).optional(), + progressStatus: toolProgressStatusSchema.optional(), + contextCondense: contextCondenseSchema.optional(), +}) + +export type ClineMessage = z.infer + +/** + * TokenUsage + */ + +export const tokenUsageSchema = z.object({ + totalTokensIn: z.number(), + totalTokensOut: z.number(), + totalCacheWrites: z.number().optional(), + totalCacheReads: z.number().optional(), + totalCost: z.number(), + contextTokens: z.number(), +}) + +export type TokenUsage = z.infer diff --git a/packages/types/src/mode.ts b/packages/types/src/mode.ts new file mode 100644 index 0000000000..dfe95f8d7e --- /dev/null +++ b/packages/types/src/mode.ts @@ -0,0 +1,128 @@ +import { z } from "zod" + +import { toolGroupsSchema } from "./tool.js" + +/** + * GroupOptions + */ + +export const groupOptionsSchema = z.object({ + fileRegex: z + .string() + .optional() + .refine( + (pattern) => { + if (!pattern) { + return true // Optional, so empty is valid. + } + + try { + new RegExp(pattern) + return true + } catch { + return false + } + }, + { message: "Invalid regular expression pattern" }, + ), + description: z.string().optional(), +}) + +export type GroupOptions = z.infer + +/** + * GroupEntry + */ + +export const groupEntrySchema = z.union([toolGroupsSchema, z.tuple([toolGroupsSchema, groupOptionsSchema])]) + +export type GroupEntry = z.infer + +/** + * ModeConfig + */ + +const groupEntryArraySchema = z.array(groupEntrySchema).refine( + (groups) => { + const seen = new Set() + + return groups.every((group) => { + // For tuples, check the group name (first element). + const groupName = Array.isArray(group) ? group[0] : group + + if (seen.has(groupName)) { + return false + } + + seen.add(groupName) + return true + }) + }, + { message: "Duplicate groups are not allowed" }, +) + +export const modeConfigSchema = z.object({ + slug: z.string().regex(/^[a-zA-Z0-9-]+$/, "Slug must contain only letters numbers and dashes"), + name: z.string().min(1, "Name is required"), + roleDefinition: z.string().min(1, "Role definition is required"), + whenToUse: z.string().optional(), + customInstructions: z.string().optional(), + groups: groupEntryArraySchema, + source: z.enum(["global", "project"]).optional(), +}) + +export type ModeConfig = z.infer + +/** + * CustomModesSettings + */ + +export const customModesSettingsSchema = z.object({ + customModes: z.array(modeConfigSchema).refine( + (modes) => { + const slugs = new Set() + + return modes.every((mode) => { + if (slugs.has(mode.slug)) { + return false + } + + slugs.add(mode.slug) + return true + }) + }, + { + message: "Duplicate mode slugs are not allowed", + }, + ), +}) + +export type CustomModesSettings = z.infer + +/** + * PromptComponent + */ + +export const promptComponentSchema = z.object({ + roleDefinition: z.string().optional(), + whenToUse: z.string().optional(), + customInstructions: z.string().optional(), +}) + +export type PromptComponent = z.infer + +/** + * CustomModePrompts + */ + +export const customModePromptsSchema = z.record(z.string(), promptComponentSchema.optional()) + +export type CustomModePrompts = z.infer + +/** + * CustomSupportPrompts + */ + +export const customSupportPromptsSchema = z.record(z.string(), z.string().optional()) + +export type CustomSupportPrompts = z.infer diff --git a/packages/types/src/model.ts b/packages/types/src/model.ts new file mode 100644 index 0000000000..3bd66782cf --- /dev/null +++ b/packages/types/src/model.ts @@ -0,0 +1,63 @@ +import { z } from "zod" + +/** + * ReasoningEffort + */ + +export const reasoningEfforts = ["low", "medium", "high"] as const + +export const reasoningEffortsSchema = z.enum(reasoningEfforts) + +export type ReasoningEffort = z.infer + +/** + * ModelParameter + */ + +export const modelParameters = ["max_tokens", "temperature", "reasoning", "include_reasoning"] as const + +export const modelParametersSchema = z.enum(modelParameters) + +export type ModelParameter = z.infer + +export const isModelParameter = (value: string): value is ModelParameter => + modelParameters.includes(value as ModelParameter) + +/** + * ModelInfo + */ + +export const modelInfoSchema = z.object({ + maxTokens: z.number().nullish(), + maxThinkingTokens: z.number().nullish(), + contextWindow: z.number(), + supportsImages: z.boolean().optional(), + supportsComputerUse: z.boolean().optional(), + supportsPromptCache: z.boolean(), + supportsReasoningBudget: z.boolean().optional(), + requiredReasoningBudget: z.boolean().optional(), + supportsReasoningEffort: z.boolean().optional(), + supportedParameters: z.array(modelParametersSchema).optional(), + inputPrice: z.number().optional(), + outputPrice: z.number().optional(), + cacheWritesPrice: z.number().optional(), + cacheReadsPrice: z.number().optional(), + description: z.string().optional(), + reasoningEffort: reasoningEffortsSchema.optional(), + minTokensPerCachePoint: z.number().optional(), + maxCachePoints: z.number().optional(), + cachableFields: z.array(z.string()).optional(), + tiers: z + .array( + z.object({ + contextWindow: z.number(), + inputPrice: z.number().optional(), + outputPrice: z.number().optional(), + cacheWritesPrice: z.number().optional(), + cacheReadsPrice: z.number().optional(), + }), + ) + .optional(), +}) + +export type ModelInfo = z.infer diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts new file mode 100644 index 0000000000..6803cebc11 --- /dev/null +++ b/packages/types/src/provider-settings.ts @@ -0,0 +1,360 @@ +import { z } from "zod" + +import type { Keys } from "./type-fu.js" +import { reasoningEffortsSchema, modelInfoSchema } from "./model.js" +import { codebaseIndexProviderSchema } from "./codebase-index.js" + +/** + * ProviderName + */ + +export const providerNames = [ + "anthropic", + "glama", + "openrouter", + "bedrock", + "vertex", + "openai", + "ollama", + "vscode-lm", + "lmstudio", + "gemini", + "openai-native", + "mistral", + "deepseek", + "unbound", + "requesty", + "human-relay", + "fake-ai", + "xai", + "groq", + "chutes", + "litellm", +] as const + +export const providerNamesSchema = z.enum(providerNames) + +export type ProviderName = z.infer + +/** + * ProviderSettingsEntry + */ + +export const providerSettingsEntrySchema = z.object({ + id: z.string(), + name: z.string(), + apiProvider: providerNamesSchema.optional(), +}) + +export type ProviderSettingsEntry = z.infer + +/** + * ProviderSettings + */ + +const baseProviderSettingsSchema = z.object({ + includeMaxTokens: z.boolean().optional(), + diffEnabled: z.boolean().optional(), + fuzzyMatchThreshold: z.number().optional(), + modelTemperature: z.number().nullish(), + rateLimitSeconds: z.number().optional(), + + // Model reasoning. + enableReasoningEffort: z.boolean().optional(), + reasoningEffort: reasoningEffortsSchema.optional(), + modelMaxTokens: z.number().optional(), + modelMaxThinkingTokens: z.number().optional(), +}) + +// Several of the providers share common model config properties. +const apiModelIdProviderModelSchema = baseProviderSettingsSchema.extend({ + apiModelId: z.string().optional(), +}) + +const anthropicSchema = apiModelIdProviderModelSchema.extend({ + apiKey: z.string().optional(), + anthropicBaseUrl: z.string().optional(), + anthropicUseAuthToken: z.boolean().optional(), +}) + +const glamaSchema = baseProviderSettingsSchema.extend({ + glamaModelId: z.string().optional(), + glamaApiKey: z.string().optional(), +}) + +const openRouterSchema = baseProviderSettingsSchema.extend({ + openRouterApiKey: z.string().optional(), + openRouterModelId: z.string().optional(), + openRouterBaseUrl: z.string().optional(), + openRouterSpecificProvider: z.string().optional(), + openRouterUseMiddleOutTransform: z.boolean().optional(), +}) + +const bedrockSchema = apiModelIdProviderModelSchema.extend({ + awsAccessKey: z.string().optional(), + awsSecretKey: z.string().optional(), + awsSessionToken: z.string().optional(), + awsRegion: z.string().optional(), + awsUseCrossRegionInference: z.boolean().optional(), + awsUsePromptCache: z.boolean().optional(), + awsProfile: z.string().optional(), + awsUseProfile: z.boolean().optional(), + awsCustomArn: z.string().optional(), +}) + +const vertexSchema = apiModelIdProviderModelSchema.extend({ + vertexKeyFile: z.string().optional(), + vertexJsonCredentials: z.string().optional(), + vertexProjectId: z.string().optional(), + vertexRegion: z.string().optional(), +}) + +const openAiSchema = baseProviderSettingsSchema.extend({ + openAiBaseUrl: z.string().optional(), + openAiApiKey: z.string().optional(), + openAiLegacyFormat: z.boolean().optional(), + openAiR1FormatEnabled: z.boolean().optional(), + openAiModelId: z.string().optional(), + openAiCustomModelInfo: modelInfoSchema.nullish(), + openAiUseAzure: z.boolean().optional(), + azureApiVersion: z.string().optional(), + openAiStreamingEnabled: z.boolean().optional(), + openAiHostHeader: z.string().optional(), // Keep temporarily for backward compatibility during migration. + openAiHeaders: z.record(z.string(), z.string()).optional(), +}) + +const ollamaSchema = baseProviderSettingsSchema.extend({ + ollamaModelId: z.string().optional(), + ollamaBaseUrl: z.string().optional(), +}) + +const vsCodeLmSchema = baseProviderSettingsSchema.extend({ + vsCodeLmModelSelector: z + .object({ + vendor: z.string().optional(), + family: z.string().optional(), + version: z.string().optional(), + id: z.string().optional(), + }) + .optional(), +}) + +const lmStudioSchema = baseProviderSettingsSchema.extend({ + lmStudioModelId: z.string().optional(), + lmStudioBaseUrl: z.string().optional(), + lmStudioDraftModelId: z.string().optional(), + lmStudioSpeculativeDecodingEnabled: z.boolean().optional(), +}) + +const geminiSchema = apiModelIdProviderModelSchema.extend({ + geminiApiKey: z.string().optional(), + googleGeminiBaseUrl: z.string().optional(), +}) + +const openAiNativeSchema = apiModelIdProviderModelSchema.extend({ + openAiNativeApiKey: z.string().optional(), + openAiNativeBaseUrl: z.string().optional(), +}) + +const mistralSchema = apiModelIdProviderModelSchema.extend({ + mistralApiKey: z.string().optional(), + mistralCodestralUrl: z.string().optional(), +}) + +const deepSeekSchema = apiModelIdProviderModelSchema.extend({ + deepSeekBaseUrl: z.string().optional(), + deepSeekApiKey: z.string().optional(), +}) + +const unboundSchema = baseProviderSettingsSchema.extend({ + unboundApiKey: z.string().optional(), + unboundModelId: z.string().optional(), +}) + +const requestySchema = baseProviderSettingsSchema.extend({ + requestyApiKey: z.string().optional(), + requestyModelId: z.string().optional(), +}) + +const humanRelaySchema = baseProviderSettingsSchema + +const fakeAiSchema = baseProviderSettingsSchema.extend({ + fakeAi: z.unknown().optional(), +}) + +const xaiSchema = apiModelIdProviderModelSchema.extend({ + xaiApiKey: z.string().optional(), +}) + +const groqSchema = apiModelIdProviderModelSchema.extend({ + groqApiKey: z.string().optional(), +}) + +const chutesSchema = apiModelIdProviderModelSchema.extend({ + chutesApiKey: z.string().optional(), +}) + +const litellmSchema = baseProviderSettingsSchema.extend({ + litellmBaseUrl: z.string().optional(), + litellmApiKey: z.string().optional(), + litellmModelId: z.string().optional(), +}) + +const defaultSchema = z.object({ + apiProvider: z.undefined(), +}) + +export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [ + anthropicSchema.merge(z.object({ apiProvider: z.literal("anthropic") })), + glamaSchema.merge(z.object({ apiProvider: z.literal("glama") })), + openRouterSchema.merge(z.object({ apiProvider: z.literal("openrouter") })), + bedrockSchema.merge(z.object({ apiProvider: z.literal("bedrock") })), + vertexSchema.merge(z.object({ apiProvider: z.literal("vertex") })), + openAiSchema.merge(z.object({ apiProvider: z.literal("openai") })), + ollamaSchema.merge(z.object({ apiProvider: z.literal("ollama") })), + vsCodeLmSchema.merge(z.object({ apiProvider: z.literal("vscode-lm") })), + lmStudioSchema.merge(z.object({ apiProvider: z.literal("lmstudio") })), + geminiSchema.merge(z.object({ apiProvider: z.literal("gemini") })), + openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })), + mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })), + deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })), + unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })), + requestySchema.merge(z.object({ apiProvider: z.literal("requesty") })), + humanRelaySchema.merge(z.object({ apiProvider: z.literal("human-relay") })), + fakeAiSchema.merge(z.object({ apiProvider: z.literal("fake-ai") })), + xaiSchema.merge(z.object({ apiProvider: z.literal("xai") })), + groqSchema.merge(z.object({ apiProvider: z.literal("groq") })), + chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })), + litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })), + defaultSchema, +]) + +export const providerSettingsSchema = z.object({ + apiProvider: providerNamesSchema.optional(), + ...anthropicSchema.shape, + ...glamaSchema.shape, + ...openRouterSchema.shape, + ...bedrockSchema.shape, + ...vertexSchema.shape, + ...openAiSchema.shape, + ...ollamaSchema.shape, + ...vsCodeLmSchema.shape, + ...lmStudioSchema.shape, + ...geminiSchema.shape, + ...openAiNativeSchema.shape, + ...mistralSchema.shape, + ...deepSeekSchema.shape, + ...unboundSchema.shape, + ...requestySchema.shape, + ...humanRelaySchema.shape, + ...fakeAiSchema.shape, + ...xaiSchema.shape, + ...groqSchema.shape, + ...chutesSchema.shape, + ...litellmSchema.shape, + ...codebaseIndexProviderSchema.shape, +}) + +export type ProviderSettings = z.infer + +type ProviderSettingsRecord = Record, undefined> + +const providerSettingsRecord: ProviderSettingsRecord = { + apiProvider: undefined, + // Anthropic + apiModelId: undefined, + apiKey: undefined, + anthropicBaseUrl: undefined, + anthropicUseAuthToken: undefined, + // Glama + glamaModelId: undefined, + glamaApiKey: undefined, + // OpenRouter + openRouterApiKey: undefined, + openRouterModelId: undefined, + openRouterBaseUrl: undefined, + openRouterSpecificProvider: undefined, + openRouterUseMiddleOutTransform: undefined, + // Amazon Bedrock + awsAccessKey: undefined, + awsSecretKey: undefined, + awsSessionToken: undefined, + awsRegion: undefined, + awsUseCrossRegionInference: undefined, + awsUsePromptCache: undefined, + awsProfile: undefined, + awsUseProfile: undefined, + awsCustomArn: undefined, + // Google Vertex + vertexKeyFile: undefined, + vertexJsonCredentials: undefined, + vertexProjectId: undefined, + vertexRegion: undefined, + // OpenAI + openAiBaseUrl: undefined, + openAiApiKey: undefined, + openAiLegacyFormat: undefined, + openAiR1FormatEnabled: undefined, + openAiModelId: undefined, + openAiCustomModelInfo: undefined, + openAiUseAzure: undefined, + azureApiVersion: undefined, + openAiStreamingEnabled: undefined, + openAiHostHeader: undefined, // Keep temporarily for backward compatibility during migration + openAiHeaders: undefined, + // Ollama + ollamaModelId: undefined, + ollamaBaseUrl: undefined, + // VS Code LM + vsCodeLmModelSelector: undefined, + lmStudioModelId: undefined, + lmStudioBaseUrl: undefined, + lmStudioDraftModelId: undefined, + lmStudioSpeculativeDecodingEnabled: undefined, + // Gemini + geminiApiKey: undefined, + googleGeminiBaseUrl: undefined, + // OpenAI Native + openAiNativeApiKey: undefined, + openAiNativeBaseUrl: undefined, + // Mistral + mistralApiKey: undefined, + mistralCodestralUrl: undefined, + // DeepSeek + deepSeekBaseUrl: undefined, + deepSeekApiKey: undefined, + // Unbound + unboundApiKey: undefined, + unboundModelId: undefined, + // Requesty + requestyApiKey: undefined, + requestyModelId: undefined, + // Code Index + codeIndexOpenAiKey: undefined, + codeIndexQdrantApiKey: undefined, + // Reasoning + enableReasoningEffort: undefined, + reasoningEffort: undefined, + modelMaxTokens: undefined, + modelMaxThinkingTokens: undefined, + // Generic + includeMaxTokens: undefined, + diffEnabled: undefined, + fuzzyMatchThreshold: undefined, + modelTemperature: undefined, + rateLimitSeconds: undefined, + // Fake AI + fakeAi: undefined, + // X.AI (Grok) + xaiApiKey: undefined, + // Groq + groqApiKey: undefined, + // Chutes AI + chutesApiKey: undefined, + // LiteLLM + litellmBaseUrl: undefined, + litellmApiKey: undefined, + litellmModelId: undefined, +} + +export const PROVIDER_SETTINGS_KEYS = Object.keys(providerSettingsRecord) as Keys[] diff --git a/packages/types/src/telemetry.ts b/packages/types/src/telemetry.ts new file mode 100644 index 0000000000..78e996f766 --- /dev/null +++ b/packages/types/src/telemetry.ts @@ -0,0 +1,134 @@ +import { z } from "zod" + +import { providerNames } from "./provider-settings.js" + +/** + * TelemetrySetting + */ + +export const telemetrySettings = ["unset", "enabled", "disabled"] as const + +export const telemetrySettingsSchema = z.enum(telemetrySettings) + +export type TelemetrySetting = z.infer + +/** + * TelemetryEventName + */ + +export enum TelemetryEventName { + TASK_CREATED = "Task Created", + TASK_RESTARTED = "Task Reopened", + TASK_COMPLETED = "Task Completed", + TASK_CONVERSATION_MESSAGE = "Conversation Message", + LLM_COMPLETION = "LLM Completion", + MODE_SWITCH = "Mode Switched", + TOOL_USED = "Tool Used", + + CHECKPOINT_CREATED = "Checkpoint Created", + CHECKPOINT_RESTORED = "Checkpoint Restored", + CHECKPOINT_DIFFED = "Checkpoint Diffed", + + CONTEXT_CONDENSED = "Context Condensed", + SLIDING_WINDOW_TRUNCATION = "Sliding Window Truncation", + + CODE_ACTION_USED = "Code Action Used", + PROMPT_ENHANCED = "Prompt Enhanced", + + TITLE_BUTTON_CLICKED = "Title Button Clicked", + + AUTHENTICATION_INITIATED = "Authentication Initiated", + + SCHEMA_VALIDATION_ERROR = "Schema Validation Error", + DIFF_APPLICATION_ERROR = "Diff Application Error", + SHELL_INTEGRATION_ERROR = "Shell Integration Error", + CONSECUTIVE_MISTAKE_ERROR = "Consecutive Mistake Error", +} + +/** + * TelemetryProperties + */ + +export const appPropertiesSchema = z.object({ + appVersion: z.string(), + vscodeVersion: z.string(), + platform: z.string(), + editorName: z.string(), + language: z.string(), + mode: z.string(), +}) + +export const taskPropertiesSchema = z.object({ + taskId: z.string().optional(), + apiProvider: z.enum(providerNames).optional(), + modelId: z.string().optional(), + diffStrategy: z.string().optional(), + isSubtask: z.boolean().optional(), +}) + +export const telemetryPropertiesSchema = z.object({ + ...appPropertiesSchema.shape, + ...taskPropertiesSchema.shape, +}) + +export type TelemetryProperties = z.infer + +/** + * TelemetryEvent + */ + +export type TelemetryEvent = { + event: TelemetryEventName + // eslint-disable-next-line @typescript-eslint/no-explicit-any + properties?: Record +} + +/** + * RooCodeTelemetryEvent + */ + +const completionPropertiesSchema = z.object({ + inputTokens: z.number(), + outputTokens: z.number(), + cacheReadTokens: z.number().optional(), + cacheWriteTokens: z.number().optional(), + cost: z.number().optional(), +}) + +export const rooCodeTelemetryEventSchema = z.discriminatedUnion("type", [ + z.object({ + type: z.enum([ + TelemetryEventName.TASK_CREATED, + TelemetryEventName.TASK_RESTARTED, + TelemetryEventName.TASK_COMPLETED, + TelemetryEventName.TASK_CONVERSATION_MESSAGE, + TelemetryEventName.MODE_SWITCH, + TelemetryEventName.TOOL_USED, + TelemetryEventName.CHECKPOINT_CREATED, + TelemetryEventName.CHECKPOINT_RESTORED, + TelemetryEventName.CHECKPOINT_DIFFED, + TelemetryEventName.CODE_ACTION_USED, + TelemetryEventName.PROMPT_ENHANCED, + TelemetryEventName.TITLE_BUTTON_CLICKED, + TelemetryEventName.AUTHENTICATION_INITIATED, + TelemetryEventName.SCHEMA_VALIDATION_ERROR, + TelemetryEventName.DIFF_APPLICATION_ERROR, + TelemetryEventName.SHELL_INTEGRATION_ERROR, + TelemetryEventName.CONSECUTIVE_MISTAKE_ERROR, + ]), + properties: z.object({ + ...appPropertiesSchema.shape, + ...taskPropertiesSchema.shape, + }), + }), + z.object({ + type: z.literal(TelemetryEventName.LLM_COMPLETION), + properties: z.object({ + ...appPropertiesSchema.shape, + ...taskPropertiesSchema.shape, + ...completionPropertiesSchema.shape, + }), + }), +]) + +export type RooCodeTelemetryEvent = z.infer diff --git a/packages/types/src/terminal.ts b/packages/types/src/terminal.ts new file mode 100644 index 0000000000..51d6f252a9 --- /dev/null +++ b/packages/types/src/terminal.ts @@ -0,0 +1,30 @@ +import { z } from "zod" + +/** + * CommandExecutionStatus + */ + +export const commandExecutionStatusSchema = z.discriminatedUnion("status", [ + z.object({ + executionId: z.string(), + status: z.literal("started"), + pid: z.number().optional(), + command: z.string(), + }), + z.object({ + executionId: z.string(), + status: z.literal("output"), + output: z.string(), + }), + z.object({ + executionId: z.string(), + status: z.literal("exited"), + exitCode: z.number().optional(), + }), + z.object({ + executionId: z.string(), + status: z.literal("fallback"), + }), +]) + +export type CommandExecutionStatus = z.infer diff --git a/packages/types/src/tool.ts b/packages/types/src/tool.ts new file mode 100644 index 0000000000..9e807d639d --- /dev/null +++ b/packages/types/src/tool.ts @@ -0,0 +1,54 @@ +import { z } from "zod" + +/** + * ToolGroup + */ + +export const toolGroups = ["read", "edit", "browser", "command", "mcp", "modes"] as const + +export const toolGroupsSchema = z.enum(toolGroups) + +export type ToolGroup = z.infer + +/** + * ToolName + */ + +export const toolNames = [ + "execute_command", + "read_file", + "write_to_file", + "apply_diff", + "insert_content", + "search_and_replace", + "search_files", + "list_files", + "list_code_definition_names", + "browser_action", + "use_mcp_tool", + "access_mcp_resource", + "ask_followup_question", + "attempt_completion", + "switch_mode", + "new_task", + "fetch_instructions", + "codebase_search", +] as const + +export const toolNamesSchema = z.enum(toolNames) + +export type ToolName = z.infer + +/** + * ToolUsage + */ + +export const toolUsageSchema = z.record( + toolNamesSchema, + z.object({ + attempts: z.number(), + failures: z.number(), + }), +) + +export type ToolUsage = z.infer diff --git a/packages/types/src/type-fu.ts b/packages/types/src/type-fu.ts new file mode 100644 index 0000000000..0014e9b187 --- /dev/null +++ b/packages/types/src/type-fu.ts @@ -0,0 +1,11 @@ +/** + * TS + */ + +export type Keys = keyof T + +export type Values = T[keyof T] + +export type Equals = (() => T extends X ? 1 : 2) extends () => T extends Y ? 1 : 2 ? true : false + +export type AssertEqual = T diff --git a/packages/types/src/types.ts b/packages/types/src/types.ts deleted file mode 100644 index 0bb2f71de2..0000000000 --- a/packages/types/src/types.ts +++ /dev/null @@ -1,1344 +0,0 @@ -import { z } from "zod" - -/** - * TS - */ - -export type Keys = keyof T - -export type Values = T[keyof T] - -export type Equals = (() => T extends X ? 1 : 2) extends () => T extends Y ? 1 : 2 ? true : false - -export type AssertEqual = T - -/** - * CodeAction - */ - -export const codeActionIds = ["explainCode", "fixCode", "improveCode", "addToContext", "newTask"] as const - -export type CodeActionId = (typeof codeActionIds)[number] - -export type CodeActionName = "EXPLAIN" | "FIX" | "IMPROVE" | "ADD_TO_CONTEXT" | "NEW_TASK" - -/** - * TerminalAction - */ - -export const terminalActionIds = ["terminalAddToContext", "terminalFixCommand", "terminalExplainCommand"] as const - -export type TerminalActionId = (typeof terminalActionIds)[number] - -export type TerminalActionName = "ADD_TO_CONTEXT" | "FIX" | "EXPLAIN" - -export type TerminalActionPromptType = `TERMINAL_${TerminalActionName}` - -/** - * Command - */ - -export const commandIds = [ - "activationCompleted", - - "plusButtonClicked", - "promptsButtonClicked", - "mcpButtonClicked", - "historyButtonClicked", - "popoutButtonClicked", - "settingsButtonClicked", - - "openInNewTab", - - "showHumanRelayDialog", - "registerHumanRelayCallback", - "unregisterHumanRelayCallback", - "handleHumanRelayResponse", - - "newTask", - - "setCustomStoragePath", - - "focusInput", - "acceptInput", -] as const - -export type CommandId = (typeof commandIds)[number] - -/** - * ProviderName - */ - -export const providerNames = [ - "anthropic", - "glama", - "openrouter", - "bedrock", - "vertex", - "openai", - "ollama", - "vscode-lm", - "lmstudio", - "gemini", - "openai-native", - "mistral", - "deepseek", - "unbound", - "requesty", - "human-relay", - "fake-ai", - "xai", - "groq", - "chutes", - "litellm", -] as const - -export const providerNamesSchema = z.enum(providerNames) - -export type ProviderName = z.infer - -/** - * ToolGroup - */ - -export const toolGroups = ["read", "edit", "browser", "command", "mcp", "modes"] as const - -export const toolGroupsSchema = z.enum(toolGroups) - -export type ToolGroup = z.infer - -/** - * Language - */ - -export const languages = [ - "ca", - "de", - "en", - "es", - "fr", - "hi", - "it", - "ja", - "ko", - "nl", - "pl", - "pt-BR", - "ru", - "tr", - "vi", - "zh-CN", - "zh-TW", -] as const - -export const languagesSchema = z.enum(languages) - -export type Language = z.infer - -export const isLanguage = (value: string): value is Language => languages.includes(value as Language) - -/** - * TelemetrySetting - */ - -export const telemetrySettings = ["unset", "enabled", "disabled"] as const - -export const telemetrySettingsSchema = z.enum(telemetrySettings) - -export type TelemetrySetting = z.infer - -/** - * ReasoningEffort - */ - -export const reasoningEfforts = ["low", "medium", "high"] as const - -export const reasoningEffortsSchema = z.enum(reasoningEfforts) - -export type ReasoningEffort = z.infer - -/** - * ModelParameter - */ - -export const modelParameters = ["max_tokens", "temperature", "reasoning", "include_reasoning"] as const - -export const modelParametersSchema = z.enum(modelParameters) - -export type ModelParameter = z.infer - -export const isModelParameter = (value: string): value is ModelParameter => - modelParameters.includes(value as ModelParameter) - -/** - * ModelInfo - */ - -export const modelInfoSchema = z.object({ - maxTokens: z.number().nullish(), - maxThinkingTokens: z.number().nullish(), - contextWindow: z.number(), - supportsImages: z.boolean().optional(), - supportsComputerUse: z.boolean().optional(), - supportsPromptCache: z.boolean(), - supportsReasoningBudget: z.boolean().optional(), - requiredReasoningBudget: z.boolean().optional(), - supportsReasoningEffort: z.boolean().optional(), - supportedParameters: z.array(modelParametersSchema).optional(), - inputPrice: z.number().optional(), - outputPrice: z.number().optional(), - cacheWritesPrice: z.number().optional(), - cacheReadsPrice: z.number().optional(), - description: z.string().optional(), - reasoningEffort: reasoningEffortsSchema.optional(), - minTokensPerCachePoint: z.number().optional(), - maxCachePoints: z.number().optional(), - cachableFields: z.array(z.string()).optional(), - tiers: z - .array( - z.object({ - contextWindow: z.number(), - inputPrice: z.number().optional(), - outputPrice: z.number().optional(), - cacheWritesPrice: z.number().optional(), - cacheReadsPrice: z.number().optional(), - }), - ) - .optional(), -}) - -export type ModelInfo = z.infer - -/** - * Codebase Index Config - */ -export const codebaseIndexConfigSchema = z.object({ - codebaseIndexEnabled: z.boolean().optional(), - codebaseIndexQdrantUrl: z.string().optional(), - codebaseIndexEmbedderProvider: z.enum(["openai", "ollama"]).optional(), - codebaseIndexEmbedderBaseUrl: z.string().optional(), - codebaseIndexEmbedderModelId: z.string().optional(), -}) - -export type CodebaseIndexConfig = z.infer - -export const codebaseIndexModelsSchema = z.object({ - openai: z.record(z.string(), z.object({ dimension: z.number() })).optional(), - ollama: z.record(z.string(), z.object({ dimension: z.number() })).optional(), -}) - -export type CodebaseIndexModels = z.infer - -export const codebaseIndexProviderSchema = z.object({ - codeIndexOpenAiKey: z.string().optional(), - codeIndexQdrantApiKey: z.string().optional(), -}) - -/** - * HistoryItem - */ - -export const historyItemSchema = z.object({ - id: z.string(), - number: z.number(), - ts: z.number(), - task: z.string(), - tokensIn: z.number(), - tokensOut: z.number(), - cacheWrites: z.number().optional(), - cacheReads: z.number().optional(), - totalCost: z.number(), - size: z.number().optional(), - workspace: z.string().optional(), -}) - -export type HistoryItem = z.infer - -/** - * GroupOptions - */ - -export const groupOptionsSchema = z.object({ - fileRegex: z - .string() - .optional() - .refine( - (pattern) => { - if (!pattern) { - return true // Optional, so empty is valid. - } - - try { - new RegExp(pattern) - return true - } catch { - return false - } - }, - { message: "Invalid regular expression pattern" }, - ), - description: z.string().optional(), -}) - -export type GroupOptions = z.infer - -/** - * GroupEntry - */ - -export const groupEntrySchema = z.union([toolGroupsSchema, z.tuple([toolGroupsSchema, groupOptionsSchema])]) - -export type GroupEntry = z.infer - -/** - * ModeConfig - */ - -const groupEntryArraySchema = z.array(groupEntrySchema).refine( - (groups) => { - const seen = new Set() - - return groups.every((group) => { - // For tuples, check the group name (first element). - const groupName = Array.isArray(group) ? group[0] : group - - if (seen.has(groupName)) { - return false - } - - seen.add(groupName) - return true - }) - }, - { message: "Duplicate groups are not allowed" }, -) - -export const modeConfigSchema = z.object({ - slug: z.string().regex(/^[a-zA-Z0-9-]+$/, "Slug must contain only letters numbers and dashes"), - name: z.string().min(1, "Name is required"), - roleDefinition: z.string().min(1, "Role definition is required"), - whenToUse: z.string().optional(), - customInstructions: z.string().optional(), - groups: groupEntryArraySchema, - source: z.enum(["global", "project"]).optional(), -}) - -export type ModeConfig = z.infer - -/** - * CustomModesSettings - */ - -export const customModesSettingsSchema = z.object({ - customModes: z.array(modeConfigSchema).refine( - (modes) => { - const slugs = new Set() - - return modes.every((mode) => { - if (slugs.has(mode.slug)) { - return false - } - - slugs.add(mode.slug) - return true - }) - }, - { - message: "Duplicate mode slugs are not allowed", - }, - ), -}) - -export type CustomModesSettings = z.infer - -/** - * PromptComponent - */ - -export const promptComponentSchema = z.object({ - roleDefinition: z.string().optional(), - whenToUse: z.string().optional(), - customInstructions: z.string().optional(), -}) - -export type PromptComponent = z.infer - -/** - * CustomModePrompts - */ - -export const customModePromptsSchema = z.record(z.string(), promptComponentSchema.optional()) - -export type CustomModePrompts = z.infer - -/** - * CustomSupportPrompts - */ - -export const customSupportPromptsSchema = z.record(z.string(), z.string().optional()) - -export type CustomSupportPrompts = z.infer - -/** - * CommandExecutionStatus - */ - -export const commandExecutionStatusSchema = z.discriminatedUnion("status", [ - z.object({ - executionId: z.string(), - status: z.literal("started"), - pid: z.number().optional(), - command: z.string(), - }), - z.object({ - executionId: z.string(), - status: z.literal("output"), - output: z.string(), - }), - z.object({ - executionId: z.string(), - status: z.literal("exited"), - exitCode: z.number().optional(), - }), - z.object({ - executionId: z.string(), - status: z.literal("fallback"), - }), -]) - -export type CommandExecutionStatus = z.infer - -/** - * ExperimentId - */ - -export const experimentIds = ["autoCondenseContext", "powerSteering"] as const - -export const experimentIdsSchema = z.enum(experimentIds) - -export type ExperimentId = z.infer - -/** - * Experiments - */ - -const experimentsSchema = z.object({ - autoCondenseContext: z.boolean(), - powerSteering: z.boolean(), -}) - -export type Experiments = z.infer - -type _AssertExperiments = AssertEqual>> - -/** - * ProviderSettingsEntry - */ - -export const providerSettingsEntrySchema = z.object({ - id: z.string(), - name: z.string(), - apiProvider: providerNamesSchema.optional(), -}) - -export type ProviderSettingsEntry = z.infer - -/** - * ProviderSettings - */ - -const baseProviderSettingsSchema = z.object({ - includeMaxTokens: z.boolean().optional(), - diffEnabled: z.boolean().optional(), - fuzzyMatchThreshold: z.number().optional(), - modelTemperature: z.number().nullish(), - rateLimitSeconds: z.number().optional(), - - // Model reasoning. - enableReasoningEffort: z.boolean().optional(), - reasoningEffort: reasoningEffortsSchema.optional(), - modelMaxTokens: z.number().optional(), - modelMaxThinkingTokens: z.number().optional(), -}) - -// Several of the providers share common model config properties. -const apiModelIdProviderModelSchema = baseProviderSettingsSchema.extend({ - apiModelId: z.string().optional(), -}) - -const anthropicSchema = apiModelIdProviderModelSchema.extend({ - apiKey: z.string().optional(), - anthropicBaseUrl: z.string().optional(), - anthropicUseAuthToken: z.boolean().optional(), -}) - -const glamaSchema = baseProviderSettingsSchema.extend({ - glamaModelId: z.string().optional(), - glamaApiKey: z.string().optional(), -}) - -const openRouterSchema = baseProviderSettingsSchema.extend({ - openRouterApiKey: z.string().optional(), - openRouterModelId: z.string().optional(), - openRouterBaseUrl: z.string().optional(), - openRouterSpecificProvider: z.string().optional(), - openRouterUseMiddleOutTransform: z.boolean().optional(), -}) - -const bedrockSchema = apiModelIdProviderModelSchema.extend({ - awsAccessKey: z.string().optional(), - awsSecretKey: z.string().optional(), - awsSessionToken: z.string().optional(), - awsRegion: z.string().optional(), - awsUseCrossRegionInference: z.boolean().optional(), - awsUsePromptCache: z.boolean().optional(), - awsProfile: z.string().optional(), - awsUseProfile: z.boolean().optional(), - awsCustomArn: z.string().optional(), -}) - -const vertexSchema = apiModelIdProviderModelSchema.extend({ - vertexKeyFile: z.string().optional(), - vertexJsonCredentials: z.string().optional(), - vertexProjectId: z.string().optional(), - vertexRegion: z.string().optional(), -}) - -const openAiSchema = baseProviderSettingsSchema.extend({ - openAiBaseUrl: z.string().optional(), - openAiApiKey: z.string().optional(), - openAiLegacyFormat: z.boolean().optional(), - openAiR1FormatEnabled: z.boolean().optional(), - openAiModelId: z.string().optional(), - openAiCustomModelInfo: modelInfoSchema.nullish(), - openAiUseAzure: z.boolean().optional(), - azureApiVersion: z.string().optional(), - openAiStreamingEnabled: z.boolean().optional(), - openAiHostHeader: z.string().optional(), // Keep temporarily for backward compatibility during migration. - openAiHeaders: z.record(z.string(), z.string()).optional(), -}) - -const ollamaSchema = baseProviderSettingsSchema.extend({ - ollamaModelId: z.string().optional(), - ollamaBaseUrl: z.string().optional(), -}) - -const vsCodeLmSchema = baseProviderSettingsSchema.extend({ - vsCodeLmModelSelector: z - .object({ - vendor: z.string().optional(), - family: z.string().optional(), - version: z.string().optional(), - id: z.string().optional(), - }) - .optional(), -}) - -const lmStudioSchema = baseProviderSettingsSchema.extend({ - lmStudioModelId: z.string().optional(), - lmStudioBaseUrl: z.string().optional(), - lmStudioDraftModelId: z.string().optional(), - lmStudioSpeculativeDecodingEnabled: z.boolean().optional(), -}) - -const geminiSchema = apiModelIdProviderModelSchema.extend({ - geminiApiKey: z.string().optional(), - googleGeminiBaseUrl: z.string().optional(), -}) - -const openAiNativeSchema = apiModelIdProviderModelSchema.extend({ - openAiNativeApiKey: z.string().optional(), - openAiNativeBaseUrl: z.string().optional(), -}) - -const mistralSchema = apiModelIdProviderModelSchema.extend({ - mistralApiKey: z.string().optional(), - mistralCodestralUrl: z.string().optional(), -}) - -const deepSeekSchema = apiModelIdProviderModelSchema.extend({ - deepSeekBaseUrl: z.string().optional(), - deepSeekApiKey: z.string().optional(), -}) - -const unboundSchema = baseProviderSettingsSchema.extend({ - unboundApiKey: z.string().optional(), - unboundModelId: z.string().optional(), -}) - -const requestySchema = baseProviderSettingsSchema.extend({ - requestyApiKey: z.string().optional(), - requestyModelId: z.string().optional(), -}) - -const humanRelaySchema = baseProviderSettingsSchema - -const fakeAiSchema = baseProviderSettingsSchema.extend({ - fakeAi: z.unknown().optional(), -}) - -const xaiSchema = apiModelIdProviderModelSchema.extend({ - xaiApiKey: z.string().optional(), -}) - -const groqSchema = apiModelIdProviderModelSchema.extend({ - groqApiKey: z.string().optional(), -}) - -const chutesSchema = apiModelIdProviderModelSchema.extend({ - chutesApiKey: z.string().optional(), -}) - -const litellmSchema = baseProviderSettingsSchema.extend({ - litellmBaseUrl: z.string().optional(), - litellmApiKey: z.string().optional(), - litellmModelId: z.string().optional(), -}) - -const defaultSchema = z.object({ - apiProvider: z.undefined(), -}) - -export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [ - anthropicSchema.merge(z.object({ apiProvider: z.literal("anthropic") })), - glamaSchema.merge(z.object({ apiProvider: z.literal("glama") })), - openRouterSchema.merge(z.object({ apiProvider: z.literal("openrouter") })), - bedrockSchema.merge(z.object({ apiProvider: z.literal("bedrock") })), - vertexSchema.merge(z.object({ apiProvider: z.literal("vertex") })), - openAiSchema.merge(z.object({ apiProvider: z.literal("openai") })), - ollamaSchema.merge(z.object({ apiProvider: z.literal("ollama") })), - vsCodeLmSchema.merge(z.object({ apiProvider: z.literal("vscode-lm") })), - lmStudioSchema.merge(z.object({ apiProvider: z.literal("lmstudio") })), - geminiSchema.merge(z.object({ apiProvider: z.literal("gemini") })), - openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })), - mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })), - deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })), - unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })), - requestySchema.merge(z.object({ apiProvider: z.literal("requesty") })), - humanRelaySchema.merge(z.object({ apiProvider: z.literal("human-relay") })), - fakeAiSchema.merge(z.object({ apiProvider: z.literal("fake-ai") })), - xaiSchema.merge(z.object({ apiProvider: z.literal("xai") })), - groqSchema.merge(z.object({ apiProvider: z.literal("groq") })), - chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })), - litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })), - defaultSchema, -]) - -export const providerSettingsSchema = z.object({ - apiProvider: providerNamesSchema.optional(), - ...anthropicSchema.shape, - ...glamaSchema.shape, - ...openRouterSchema.shape, - ...bedrockSchema.shape, - ...vertexSchema.shape, - ...openAiSchema.shape, - ...ollamaSchema.shape, - ...vsCodeLmSchema.shape, - ...lmStudioSchema.shape, - ...geminiSchema.shape, - ...openAiNativeSchema.shape, - ...mistralSchema.shape, - ...deepSeekSchema.shape, - ...unboundSchema.shape, - ...requestySchema.shape, - ...humanRelaySchema.shape, - ...fakeAiSchema.shape, - ...xaiSchema.shape, - ...groqSchema.shape, - ...chutesSchema.shape, - ...litellmSchema.shape, - ...codebaseIndexProviderSchema.shape, -}) - -export type ProviderSettings = z.infer - -type ProviderSettingsRecord = Record, undefined> - -const providerSettingsRecord: ProviderSettingsRecord = { - apiProvider: undefined, - // Anthropic - apiModelId: undefined, - apiKey: undefined, - anthropicBaseUrl: undefined, - anthropicUseAuthToken: undefined, - // Glama - glamaModelId: undefined, - glamaApiKey: undefined, - // OpenRouter - openRouterApiKey: undefined, - openRouterModelId: undefined, - openRouterBaseUrl: undefined, - openRouterSpecificProvider: undefined, - openRouterUseMiddleOutTransform: undefined, - // Amazon Bedrock - awsAccessKey: undefined, - awsSecretKey: undefined, - awsSessionToken: undefined, - awsRegion: undefined, - awsUseCrossRegionInference: undefined, - awsUsePromptCache: undefined, - awsProfile: undefined, - awsUseProfile: undefined, - awsCustomArn: undefined, - // Google Vertex - vertexKeyFile: undefined, - vertexJsonCredentials: undefined, - vertexProjectId: undefined, - vertexRegion: undefined, - // OpenAI - openAiBaseUrl: undefined, - openAiApiKey: undefined, - openAiLegacyFormat: undefined, - openAiR1FormatEnabled: undefined, - openAiModelId: undefined, - openAiCustomModelInfo: undefined, - openAiUseAzure: undefined, - azureApiVersion: undefined, - openAiStreamingEnabled: undefined, - openAiHostHeader: undefined, // Keep temporarily for backward compatibility during migration - openAiHeaders: undefined, - // Ollama - ollamaModelId: undefined, - ollamaBaseUrl: undefined, - // VS Code LM - vsCodeLmModelSelector: undefined, - lmStudioModelId: undefined, - lmStudioBaseUrl: undefined, - lmStudioDraftModelId: undefined, - lmStudioSpeculativeDecodingEnabled: undefined, - // Gemini - geminiApiKey: undefined, - googleGeminiBaseUrl: undefined, - // OpenAI Native - openAiNativeApiKey: undefined, - openAiNativeBaseUrl: undefined, - // Mistral - mistralApiKey: undefined, - mistralCodestralUrl: undefined, - // DeepSeek - deepSeekBaseUrl: undefined, - deepSeekApiKey: undefined, - // Unbound - unboundApiKey: undefined, - unboundModelId: undefined, - // Requesty - requestyApiKey: undefined, - requestyModelId: undefined, - // Code Index - codeIndexOpenAiKey: undefined, - codeIndexQdrantApiKey: undefined, - // Reasoning - enableReasoningEffort: undefined, - reasoningEffort: undefined, - modelMaxTokens: undefined, - modelMaxThinkingTokens: undefined, - // Generic - includeMaxTokens: undefined, - diffEnabled: undefined, - fuzzyMatchThreshold: undefined, - modelTemperature: undefined, - rateLimitSeconds: undefined, - // Fake AI - fakeAi: undefined, - // X.AI (Grok) - xaiApiKey: undefined, - // Groq - groqApiKey: undefined, - // Chutes AI - chutesApiKey: undefined, - // LiteLLM - litellmBaseUrl: undefined, - litellmApiKey: undefined, - litellmModelId: undefined, -} - -export const PROVIDER_SETTINGS_KEYS = Object.keys(providerSettingsRecord) as Keys[] - -/** - * GlobalSettings - */ - -export const globalSettingsSchema = z.object({ - currentApiConfigName: z.string().optional(), - listApiConfigMeta: z.array(providerSettingsEntrySchema).optional(), - pinnedApiConfigs: z.record(z.string(), z.boolean()).optional(), - - lastShownAnnouncementId: z.string().optional(), - customInstructions: z.string().optional(), - taskHistory: z.array(historyItemSchema).optional(), - - condensingApiConfigId: z.string().optional(), - customCondensingPrompt: z.string().optional(), - - autoApprovalEnabled: z.boolean().optional(), - alwaysAllowReadOnly: z.boolean().optional(), - alwaysAllowReadOnlyOutsideWorkspace: z.boolean().optional(), - codebaseIndexModels: codebaseIndexModelsSchema.optional(), - codebaseIndexConfig: codebaseIndexConfigSchema.optional(), - alwaysAllowWrite: z.boolean().optional(), - alwaysAllowWriteOutsideWorkspace: z.boolean().optional(), - writeDelayMs: z.number().optional(), - alwaysAllowBrowser: z.boolean().optional(), - alwaysApproveResubmit: z.boolean().optional(), - requestDelaySeconds: z.number().optional(), - alwaysAllowMcp: z.boolean().optional(), - alwaysAllowModeSwitch: z.boolean().optional(), - alwaysAllowSubtasks: z.boolean().optional(), - alwaysAllowExecute: z.boolean().optional(), - allowedCommands: z.array(z.string()).optional(), - allowedMaxRequests: z.number().nullish(), - autoCondenseContextPercent: z.number().optional(), - - browserToolEnabled: z.boolean().optional(), - browserViewportSize: z.string().optional(), - screenshotQuality: z.number().optional(), - remoteBrowserEnabled: z.boolean().optional(), - remoteBrowserHost: z.string().optional(), - cachedChromeHostUrl: z.string().optional(), - - enableCheckpoints: z.boolean().optional(), - - ttsEnabled: z.boolean().optional(), - ttsSpeed: z.number().optional(), - soundEnabled: z.boolean().optional(), - soundVolume: z.number().optional(), - - maxOpenTabsContext: z.number().optional(), - maxWorkspaceFiles: z.number().optional(), - showRooIgnoredFiles: z.boolean().optional(), - maxReadFileLine: z.number().optional(), - - terminalOutputLineLimit: z.number().optional(), - terminalShellIntegrationTimeout: z.number().optional(), - terminalShellIntegrationDisabled: z.boolean().optional(), - terminalCommandDelay: z.number().optional(), - terminalPowershellCounter: z.boolean().optional(), - terminalZshClearEolMark: z.boolean().optional(), - terminalZshOhMy: z.boolean().optional(), - terminalZshP10k: z.boolean().optional(), - terminalZdotdir: z.boolean().optional(), - terminalCompressProgressBar: z.boolean().optional(), - - rateLimitSeconds: z.number().optional(), - diffEnabled: z.boolean().optional(), - fuzzyMatchThreshold: z.number().optional(), - experiments: experimentsSchema.optional(), - - language: languagesSchema.optional(), - - telemetrySetting: telemetrySettingsSchema.optional(), - - mcpEnabled: z.boolean().optional(), - enableMcpServerCreation: z.boolean().optional(), - - mode: z.string().optional(), - modeApiConfigs: z.record(z.string(), z.string()).optional(), - customModes: z.array(modeConfigSchema).optional(), - customModePrompts: customModePromptsSchema.optional(), - customSupportPrompts: customSupportPromptsSchema.optional(), - enhancementApiConfigId: z.string().optional(), - historyPreviewCollapsed: z.boolean().optional(), -}) - -export type GlobalSettings = z.infer - -type GlobalSettingsRecord = Record, undefined> - -const globalSettingsRecord: GlobalSettingsRecord = { - codebaseIndexModels: undefined, - codebaseIndexConfig: undefined, - currentApiConfigName: undefined, - listApiConfigMeta: undefined, - pinnedApiConfigs: undefined, - - lastShownAnnouncementId: undefined, - customInstructions: undefined, - taskHistory: undefined, - - condensingApiConfigId: undefined, - customCondensingPrompt: undefined, - - autoApprovalEnabled: undefined, - alwaysAllowReadOnly: undefined, - alwaysAllowReadOnlyOutsideWorkspace: undefined, - alwaysAllowWrite: undefined, - alwaysAllowWriteOutsideWorkspace: undefined, - writeDelayMs: undefined, - alwaysAllowBrowser: undefined, - alwaysApproveResubmit: undefined, - requestDelaySeconds: undefined, - alwaysAllowMcp: undefined, - alwaysAllowModeSwitch: undefined, - alwaysAllowSubtasks: undefined, - alwaysAllowExecute: undefined, - allowedCommands: undefined, - allowedMaxRequests: undefined, - autoCondenseContextPercent: undefined, - - browserToolEnabled: undefined, - browserViewportSize: undefined, - screenshotQuality: undefined, - remoteBrowserEnabled: undefined, - remoteBrowserHost: undefined, - - enableCheckpoints: undefined, - - ttsEnabled: undefined, - ttsSpeed: undefined, - soundEnabled: undefined, - soundVolume: undefined, - - maxOpenTabsContext: undefined, - maxWorkspaceFiles: undefined, - showRooIgnoredFiles: undefined, - maxReadFileLine: undefined, - - terminalOutputLineLimit: undefined, - terminalShellIntegrationTimeout: undefined, - terminalShellIntegrationDisabled: undefined, - terminalCommandDelay: undefined, - terminalPowershellCounter: undefined, - terminalZshClearEolMark: undefined, - terminalZshOhMy: undefined, - terminalZshP10k: undefined, - terminalZdotdir: undefined, - terminalCompressProgressBar: undefined, - - rateLimitSeconds: undefined, - diffEnabled: undefined, - fuzzyMatchThreshold: undefined, - experiments: undefined, - - language: undefined, - - telemetrySetting: undefined, - - mcpEnabled: undefined, - enableMcpServerCreation: undefined, - - mode: undefined, - modeApiConfigs: undefined, - customModes: undefined, - customModePrompts: undefined, - customSupportPrompts: undefined, - enhancementApiConfigId: undefined, - cachedChromeHostUrl: undefined, - historyPreviewCollapsed: undefined, -} - -export const GLOBAL_SETTINGS_KEYS = Object.keys(globalSettingsRecord) as Keys[] - -/** - * RooCodeSettings - */ - -export const rooCodeSettingsSchema = providerSettingsSchema.merge(globalSettingsSchema) - -export type RooCodeSettings = GlobalSettings & ProviderSettings - -/** - * SecretState - */ - -export type SecretState = Pick< - ProviderSettings, - | "apiKey" - | "glamaApiKey" - | "openRouterApiKey" - | "awsAccessKey" - | "awsSecretKey" - | "awsSessionToken" - | "openAiApiKey" - | "geminiApiKey" - | "openAiNativeApiKey" - | "deepSeekApiKey" - | "mistralApiKey" - | "unboundApiKey" - | "requestyApiKey" - | "xaiApiKey" - | "groqApiKey" - | "chutesApiKey" - | "litellmApiKey" - | "codeIndexOpenAiKey" - | "codeIndexQdrantApiKey" -> - -export type CodeIndexSecrets = "codeIndexOpenAiKey" | "codeIndexQdrantApiKey" - -type SecretStateRecord = Record, undefined> - -const secretStateRecord: SecretStateRecord = { - apiKey: undefined, - glamaApiKey: undefined, - openRouterApiKey: undefined, - awsAccessKey: undefined, - awsSecretKey: undefined, - awsSessionToken: undefined, - openAiApiKey: undefined, - geminiApiKey: undefined, - openAiNativeApiKey: undefined, - deepSeekApiKey: undefined, - mistralApiKey: undefined, - unboundApiKey: undefined, - requestyApiKey: undefined, - xaiApiKey: undefined, - groqApiKey: undefined, - chutesApiKey: undefined, - litellmApiKey: undefined, - codeIndexOpenAiKey: undefined, - codeIndexQdrantApiKey: undefined, -} - -export const SECRET_STATE_KEYS = Object.keys(secretStateRecord) as Keys[] - -export const isSecretStateKey = (key: string): key is Keys => - SECRET_STATE_KEYS.includes(key as Keys) - -/** - * GlobalState - */ - -export type GlobalState = Omit> - -export const GLOBAL_STATE_KEYS = [...GLOBAL_SETTINGS_KEYS, ...PROVIDER_SETTINGS_KEYS].filter( - (key: Keys) => !SECRET_STATE_KEYS.includes(key as Keys), -) as Keys[] - -export const isGlobalStateKey = (key: string): key is Keys => - GLOBAL_STATE_KEYS.includes(key as Keys) - -/** - * ClineAsk - */ - -export const clineAsks = [ - "followup", - "command", - "command_output", - "completion_result", - "tool", - "api_req_failed", - "resume_task", - "resume_completed_task", - "mistake_limit_reached", - "browser_action_launch", - "use_mcp_server", - "auto_approval_max_req_reached", -] as const - -export const clineAskSchema = z.enum(clineAsks) - -export type ClineAsk = z.infer - -// ClineSay - -export const clineSays = [ - "error", - "api_req_started", - "api_req_finished", - "api_req_retried", - "api_req_retry_delayed", - "api_req_deleted", - "text", - "reasoning", - "completion_result", - "user_feedback", - "user_feedback_diff", - "command_output", - "shell_integration_warning", - "browser_action", - "browser_action_result", - "mcp_server_request_started", - "mcp_server_response", - "subtask_result", - "checkpoint_saved", - "rooignore_error", - "diff_error", - "condense_context", - "codebase_search_result", -] as const - -export const clineSaySchema = z.enum(clineSays) - -export type ClineSay = z.infer - -/** - * ToolProgressStatus - */ - -export const toolProgressStatusSchema = z.object({ - icon: z.string().optional(), - text: z.string().optional(), -}) - -export type ToolProgressStatus = z.infer - -/** - * ContextCondense - */ - -export const contextCondenseSchema = z.object({ - cost: z.number(), - prevContextTokens: z.number(), - newContextTokens: z.number(), - summary: z.string(), -}) - -export type ContextCondense = z.infer - -/** - * ClineMessage - */ - -export const clineMessageSchema = z.object({ - ts: z.number(), - type: z.union([z.literal("ask"), z.literal("say")]), - ask: clineAskSchema.optional(), - say: clineSaySchema.optional(), - text: z.string().optional(), - images: z.array(z.string()).optional(), - partial: z.boolean().optional(), - reasoning: z.string().optional(), - conversationHistoryIndex: z.number().optional(), - checkpoint: z.record(z.string(), z.unknown()).optional(), - progressStatus: toolProgressStatusSchema.optional(), - contextCondense: contextCondenseSchema.optional(), -}) - -export type ClineMessage = z.infer - -/** - * TokenUsage - */ - -export const tokenUsageSchema = z.object({ - totalTokensIn: z.number(), - totalTokensOut: z.number(), - totalCacheWrites: z.number().optional(), - totalCacheReads: z.number().optional(), - totalCost: z.number(), - contextTokens: z.number(), -}) - -export type TokenUsage = z.infer - -/** - * ToolName - */ - -export const toolNames = [ - "execute_command", - "read_file", - "write_to_file", - "apply_diff", - "insert_content", - "search_and_replace", - "search_files", - "list_files", - "list_code_definition_names", - "browser_action", - "use_mcp_tool", - "access_mcp_resource", - "ask_followup_question", - "attempt_completion", - "switch_mode", - "new_task", - "fetch_instructions", - "codebase_search", -] as const - -export const toolNamesSchema = z.enum(toolNames) - -export type ToolName = z.infer - -/** - * ToolUsage - */ - -export const toolUsageSchema = z.record( - toolNamesSchema, - z.object({ - attempts: z.number(), - failures: z.number(), - }), -) - -export type ToolUsage = z.infer - -/** - * RooCodeEvent - */ - -export enum RooCodeEventName { - Message = "message", - TaskCreated = "taskCreated", - TaskStarted = "taskStarted", - TaskModeSwitched = "taskModeSwitched", - TaskPaused = "taskPaused", - TaskUnpaused = "taskUnpaused", - TaskAskResponded = "taskAskResponded", - TaskAborted = "taskAborted", - TaskSpawned = "taskSpawned", - TaskCompleted = "taskCompleted", - TaskTokenUsageUpdated = "taskTokenUsageUpdated", - TaskToolFailed = "taskToolFailed", -} - -export const rooCodeEventsSchema = z.object({ - [RooCodeEventName.Message]: z.tuple([ - z.object({ - taskId: z.string(), - action: z.union([z.literal("created"), z.literal("updated")]), - message: clineMessageSchema, - }), - ]), - [RooCodeEventName.TaskCreated]: z.tuple([z.string()]), - [RooCodeEventName.TaskStarted]: z.tuple([z.string()]), - [RooCodeEventName.TaskModeSwitched]: z.tuple([z.string(), z.string()]), - [RooCodeEventName.TaskPaused]: z.tuple([z.string()]), - [RooCodeEventName.TaskUnpaused]: z.tuple([z.string()]), - [RooCodeEventName.TaskAskResponded]: z.tuple([z.string()]), - [RooCodeEventName.TaskAborted]: z.tuple([z.string()]), - [RooCodeEventName.TaskSpawned]: z.tuple([z.string(), z.string()]), - [RooCodeEventName.TaskCompleted]: z.tuple([z.string(), tokenUsageSchema, toolUsageSchema]), - [RooCodeEventName.TaskTokenUsageUpdated]: z.tuple([z.string(), tokenUsageSchema]), - [RooCodeEventName.TaskToolFailed]: z.tuple([z.string(), toolNamesSchema, z.string()]), -}) - -export type RooCodeEvents = z.infer - -/** - * Ack - */ - -export const ackSchema = z.object({ - clientId: z.string(), - pid: z.number(), - ppid: z.number(), -}) - -export type Ack = z.infer - -/** - * TaskCommand - */ - -export enum TaskCommandName { - StartNewTask = "StartNewTask", - CancelTask = "CancelTask", - CloseTask = "CloseTask", -} - -export const taskCommandSchema = z.discriminatedUnion("commandName", [ - z.object({ - commandName: z.literal(TaskCommandName.StartNewTask), - data: z.object({ - configuration: rooCodeSettingsSchema, - text: z.string(), - images: z.array(z.string()).optional(), - newTab: z.boolean().optional(), - }), - }), - z.object({ - commandName: z.literal(TaskCommandName.CancelTask), - data: z.string(), - }), - z.object({ - commandName: z.literal(TaskCommandName.CloseTask), - data: z.string(), - }), -]) - -export type TaskCommand = z.infer - -/** - * TaskEvent - */ - -export const taskEventSchema = z.discriminatedUnion("eventName", [ - z.object({ - eventName: z.literal(RooCodeEventName.Message), - payload: rooCodeEventsSchema.shape[RooCodeEventName.Message], - }), - z.object({ - eventName: z.literal(RooCodeEventName.TaskCreated), - payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskCreated], - }), - z.object({ - eventName: z.literal(RooCodeEventName.TaskStarted), - payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskStarted], - }), - z.object({ - eventName: z.literal(RooCodeEventName.TaskModeSwitched), - payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskModeSwitched], - }), - z.object({ - eventName: z.literal(RooCodeEventName.TaskPaused), - payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskPaused], - }), - z.object({ - eventName: z.literal(RooCodeEventName.TaskUnpaused), - payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskUnpaused], - }), - z.object({ - eventName: z.literal(RooCodeEventName.TaskAskResponded), - payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskAskResponded], - }), - z.object({ - eventName: z.literal(RooCodeEventName.TaskAborted), - payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskAborted], - }), - z.object({ - eventName: z.literal(RooCodeEventName.TaskSpawned), - payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskSpawned], - }), - z.object({ - eventName: z.literal(RooCodeEventName.TaskCompleted), - payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskCompleted], - }), - z.object({ - eventName: z.literal(RooCodeEventName.TaskTokenUsageUpdated), - payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskTokenUsageUpdated], - }), -]) - -export type TaskEvent = z.infer - -/** - * IpcMessage - */ - -export enum IpcMessageType { - Connect = "Connect", - Disconnect = "Disconnect", - Ack = "Ack", - TaskCommand = "TaskCommand", - TaskEvent = "TaskEvent", -} - -export enum IpcOrigin { - Client = "client", - Server = "server", -} - -export const ipcMessageSchema = z.discriminatedUnion("type", [ - z.object({ - type: z.literal(IpcMessageType.Ack), - origin: z.literal(IpcOrigin.Server), - data: ackSchema, - }), - z.object({ - type: z.literal(IpcMessageType.TaskCommand), - origin: z.literal(IpcOrigin.Client), - clientId: z.string(), - data: taskCommandSchema, - }), - z.object({ - type: z.literal(IpcMessageType.TaskEvent), - origin: z.literal(IpcOrigin.Server), - relayClientId: z.string().optional(), - data: taskEventSchema, - }), -]) - -export type IpcMessage = z.infer diff --git a/packages/types/src/vscode.ts b/packages/types/src/vscode.ts new file mode 100644 index 0000000000..f12b71cdf7 --- /dev/null +++ b/packages/types/src/vscode.ts @@ -0,0 +1,84 @@ +import { z } from "zod" + +/** + * CodeAction + */ + +export const codeActionIds = ["explainCode", "fixCode", "improveCode", "addToContext", "newTask"] as const + +export type CodeActionId = (typeof codeActionIds)[number] + +export type CodeActionName = "EXPLAIN" | "FIX" | "IMPROVE" | "ADD_TO_CONTEXT" | "NEW_TASK" + +/** + * TerminalAction + */ + +export const terminalActionIds = ["terminalAddToContext", "terminalFixCommand", "terminalExplainCommand"] as const + +export type TerminalActionId = (typeof terminalActionIds)[number] + +export type TerminalActionName = "ADD_TO_CONTEXT" | "FIX" | "EXPLAIN" + +export type TerminalActionPromptType = `TERMINAL_${TerminalActionName}` + +/** + * Command + */ + +export const commandIds = [ + "activationCompleted", + + "plusButtonClicked", + "promptsButtonClicked", + "mcpButtonClicked", + "historyButtonClicked", + "popoutButtonClicked", + "settingsButtonClicked", + + "openInNewTab", + + "showHumanRelayDialog", + "registerHumanRelayCallback", + "unregisterHumanRelayCallback", + "handleHumanRelayResponse", + + "newTask", + + "setCustomStoragePath", + + "focusInput", + "acceptInput", +] as const + +export type CommandId = (typeof commandIds)[number] + +/** + * Language + */ + +export const languages = [ + "ca", + "de", + "en", + "es", + "fr", + "hi", + "it", + "ja", + "ko", + "nl", + "pl", + "pt-BR", + "ru", + "tr", + "vi", + "zh-CN", + "zh-TW", +] as const + +export const languagesSchema = z.enum(languages) + +export type Language = z.infer + +export const isLanguage = (value: string): value is Language => languages.includes(value as Language) diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index c53385d2e3..c9f2b4a100 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -1323,6 +1323,21 @@ export class Task extends EventEmitter { } finally { this.isStreaming = false } + if ( + inputTokens > 0 || + outputTokens > 0 || + cacheWriteTokens > 0 || + cacheReadTokens > 0 || + typeof totalCost !== "undefined" + ) { + telemetryService.captureLlmCompletion(this.taskId, { + inputTokens, + outputTokens, + cacheWriteTokens, + cacheReadTokens, + cost: totalCost, + }) + } // Need to call here in case the stream was aborted. if (this.abort || this.abandoned) { diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 25db95ac2a..13121531a7 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -15,6 +15,7 @@ import type { ProviderSettings, RooCodeSettings, ProviderSettingsEntry, + TelemetryProperties, CodeActionId, CodeActionName, TerminalActionId, @@ -52,7 +53,7 @@ import { Task, TaskOptions } from "../task/Task" import { getNonce } from "./getNonce" import { getUri } from "./getUri" import { getSystemPromptFilePath } from "../prompts/sections/custom-system-prompt" -import { telemetryService } from "../../services/telemetry/TelemetryService" +import { TelemetryPropertiesProvider, telemetryService } from "../../services/telemetry" import { getWorkspacePath } from "../../utils/path" import { webviewMessageHandler } from "./webviewMessageHandler" import { WebviewMessage } from "../../shared/WebviewMessage" @@ -67,7 +68,10 @@ export type ClineProviderEvents = { clineCreated: [cline: Task] } -export class ClineProvider extends EventEmitter implements vscode.WebviewViewProvider { +export class ClineProvider + extends EventEmitter + implements vscode.WebviewViewProvider, TelemetryPropertiesProvider +{ // Used in package.json as the view's id. This value cannot be changed due // to how VSCode caches views based on their id, and updating the id would // break existing instances of the extension. @@ -1566,59 +1570,21 @@ export class ClineProvider extends EventEmitter implements * This method is called by the telemetry service to get context information * like the current mode, API provider, etc. */ - public async getTelemetryProperties(): Promise> { + public async getTelemetryProperties(): Promise { const { mode, apiConfiguration, language } = await this.getState() - const appVersion = this.context.extension?.packageJSON?.version - const vscodeVersion = vscode.version - const platform = process.platform - const editorName = vscode.env.appName // Get the editor name (VS Code, Cursor, etc.) - - const properties: Record = { - vscodeVersion, - platform, - editorName, - } - - // Add extension version - if (appVersion) { - properties.appVersion = appVersion - } - - // Add language - if (language) { - properties.language = language - } - - // Add current mode - if (mode) { - properties.mode = mode - } - - // Add API provider - if (apiConfiguration?.apiProvider) { - properties.apiProvider = apiConfiguration.apiProvider - } - - // Add model ID if available - const currentCline = this.getCurrentCline() - - if (currentCline?.api) { - const { id: modelId } = currentCline.api.getModel() - - if (modelId) { - properties.modelId = modelId - } - } - - if (currentCline?.diffStrategy) { - properties.diffStrategy = currentCline.diffStrategy.getName() - } + const task = this.getCurrentCline() - // Add isSubtask property that indicates whether this task is a subtask - if (currentCline) { - properties.isSubtask = !!currentCline.parentTask + return { + appVersion: this.context.extension?.packageJSON?.version, + vscodeVersion: vscode.version, + platform: process.platform, + editorName: vscode.env.appName, + language, + mode, + apiProvider: apiConfiguration?.apiProvider, + modelId: task?.api?.getModel().id, + diffStrategy: task?.diffStrategy?.getName(), + isSubtask: task ? !!task.parentTask : undefined, } - - return properties } } diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 0acae75884..1a0b64605d 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -29,7 +29,7 @@ import { getOllamaModels } from "../../api/providers/ollama" import { getVsCodeLmModels } from "../../api/providers/vscode-lm" import { getLmStudioModels } from "../../api/providers/lmstudio" import { openMention } from "../mentions" -import { telemetryService } from "../../services/telemetry/TelemetryService" +import { telemetryService } from "../../services/telemetry" import { TelemetrySetting } from "../../shared/TelemetrySetting" import { getWorkspacePath } from "../../utils/path" import { Mode, defaultModeSlug } from "../../shared/modes" diff --git a/src/extension.ts b/src/extension.ts index db4edd7b26..70d078363d 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -58,7 +58,7 @@ export async function activate(context: vscode.ExtensionContext) { await migrateSettings(context, outputChannel) // Initialize telemetry service after environment variables are loaded. - telemetryService.initialize() + telemetryService.initialize(context) // Initialize i18n for internationalization support initializeI18n(context.globalState.get("language") ?? formatLanguage(vscode.env.language)) diff --git a/src/services/telemetry/PostHogClient.ts b/src/services/telemetry/PostHogClient.ts deleted file mode 100644 index 22fce0beb3..0000000000 --- a/src/services/telemetry/PostHogClient.ts +++ /dev/null @@ -1,150 +0,0 @@ -import { PostHog } from "posthog-node" -import * as vscode from "vscode" - -import { logger } from "../../utils/logging" - -// This forward declaration is needed to avoid circular dependencies -export interface ClineProviderInterface { - // Gets telemetry properties to attach to every event - getTelemetryProperties(): Promise> -} - -/** - * PostHogClient handles telemetry event tracking for the Roo Code extension - * Uses PostHog analytics to track user interactions and system events - * Respects user privacy settings and VSCode's global telemetry configuration - */ -export class PostHogClient { - public static readonly EVENTS = { - TASK: { - CREATED: "Task Created", - RESTARTED: "Task Reopened", - COMPLETED: "Task Completed", - CONVERSATION_MESSAGE: "Conversation Message", - MODE_SWITCH: "Mode Switched", - TOOL_USED: "Tool Used", - CHECKPOINT_CREATED: "Checkpoint Created", - CHECKPOINT_RESTORED: "Checkpoint Restored", - CHECKPOINT_DIFFED: "Checkpoint Diffed", - CODE_ACTION_USED: "Code Action Used", - PROMPT_ENHANCED: "Prompt Enhanced", - CONTEXT_CONDENSED: "Context Condensed", - SLIDING_WINDOW_TRUNCATION: "Sliding Window Truncation", - }, - ERRORS: { - SCHEMA_VALIDATION_ERROR: "Schema Validation Error", - DIFF_APPLICATION_ERROR: "Diff Application Error", - SHELL_INTEGRATION_ERROR: "Shell Integration Error", - CONSECUTIVE_MISTAKE_ERROR: "Consecutive Mistake Error", - }, - } - - private static instance: PostHogClient - private client: PostHog - private distinctId: string = vscode.env.machineId - private telemetryEnabled: boolean = false - private providerRef: WeakRef | null = null - - private constructor() { - this.client = new PostHog(process.env.POSTHOG_API_KEY || "", { host: "https://us.i.posthog.com" }) - } - - /** - * Updates the telemetry state based on user preferences and VSCode settings - * Only enables telemetry if both VSCode global telemetry is enabled and user has opted in - * @param didUserOptIn Whether the user has explicitly opted into telemetry - */ - public updateTelemetryState(didUserOptIn: boolean): void { - this.telemetryEnabled = false - - // First check global telemetry level - telemetry should only be enabled when level is "all" - const telemetryLevel = vscode.workspace.getConfiguration("telemetry").get("telemetryLevel", "all") - const globalTelemetryEnabled = telemetryLevel === "all" - - // We only enable telemetry if global vscode telemetry is enabled - if (globalTelemetryEnabled) { - this.telemetryEnabled = didUserOptIn - } - - // Update PostHog client state based on telemetry preference - if (this.telemetryEnabled) { - this.client.optIn() - } else { - this.client.optOut() - } - } - - /** - * Gets or creates the singleton instance of PostHogClient - * @returns The PostHogClient instance - */ - public static getInstance(): PostHogClient { - if (!PostHogClient.instance) { - PostHogClient.instance = new PostHogClient() - } - - return PostHogClient.instance - } - - /** - * Sets the ClineProvider reference to use for global properties - * @param provider A ClineProvider instance to use - */ - public setProvider(provider: ClineProviderInterface): void { - this.providerRef = new WeakRef(provider) - logger.debug("PostHogClient: ClineProvider reference set") - } - - /** - * Captures a telemetry event if telemetry is enabled - * @param event The event to capture with its properties - */ - public async capture(event: { event: string; properties?: any }): Promise { - // Only send events if telemetry is enabled - if (this.telemetryEnabled) { - // Get global properties from ClineProvider if available - let globalProperties: Record = {} - const provider = this.providerRef?.deref() - - if (provider) { - try { - // Get the telemetry properties directly from the provider - globalProperties = await provider.getTelemetryProperties() - } catch (error) { - // Log error but continue with capturing the event - logger.error( - `Error getting telemetry properties: ${error instanceof Error ? error.message : String(error)}`, - ) - } - } - - // Merge global properties with event-specific properties - // Event properties take precedence in case of conflicts - const mergedProperties = { - ...globalProperties, - ...(event.properties || {}), - } - - this.client.capture({ - distinctId: this.distinctId, - event: event.event, - properties: mergedProperties, - }) - } - } - - /** - * Checks if telemetry is currently enabled - * @returns Whether telemetry is enabled - */ - public isTelemetryEnabled(): boolean { - return this.telemetryEnabled - } - - /** - * Shuts down the PostHog client - */ - public async shutdown(): Promise { - await this.client.shutdown() - } -} diff --git a/src/services/telemetry/TelemetryService.ts b/src/services/telemetry/TelemetryService.ts index c0ddbe4edc..cc1248f1b7 100644 --- a/src/services/telemetry/TelemetryService.ts +++ b/src/services/telemetry/TelemetryService.ts @@ -1,28 +1,35 @@ +import * as vscode from "vscode" import { ZodError } from "zod" +import { TelemetryEventName } from "@roo-code/types" + import { logger } from "../../utils/logging" -import { PostHogClient, ClineProviderInterface } from "./PostHogClient" + +import { PostHogTelemetryClient } from "./clients/PostHogTelemetryClient" +import { type TelemetryClient, type TelemetryPropertiesProvider } from "./types" /** - * TelemetryService wrapper class that defers PostHogClient initialization - * This ensures that we only create the PostHogClient after environment variables are loaded + * TelemetryService wrapper class that defers initialization. + * This ensures that we only create the various clients after environment + * variables are loaded. */ class TelemetryService { - private client: PostHogClient | null = null + private clients: TelemetryClient[] = [] private initialized = false /** - * Initialize the telemetry service with the PostHog client - * This should be called after environment variables are loaded + * Initialize the telemetry client. This should be called after environment + * variables are loaded. */ - public initialize(): void { + public async initialize(context: vscode.ExtensionContext): Promise { if (this.initialized) { return } + this.initialized = true + try { - this.client = PostHogClient.getInstance() - this.initialized = true + this.clients.push(PostHogTelemetryClient.getInstance()) } catch (error) { console.warn("Failed to initialize telemetry service:", error) } @@ -32,10 +39,10 @@ class TelemetryService { * Sets the ClineProvider reference to use for global properties * @param provider A ClineProvider instance to use */ - public setProvider(provider: ClineProviderInterface): void { - // If client is initialized, pass the provider reference + public setProvider(provider: TelemetryPropertiesProvider): void { + // If client is initialized, pass the provider reference. if (this.isReady) { - this.client!.setProvider(provider) + this.clients.forEach((client) => client.setProvider(provider)) } logger.debug("TelemetryService: ClineProvider reference set") @@ -47,7 +54,7 @@ class TelemetryService { * @returns Whether the service is ready to use */ private get isReady(): boolean { - return this.initialized && this.client !== null + return this.initialized && this.clients.length > 0 } /** @@ -59,65 +66,69 @@ class TelemetryService { return } - this.client!.updateTelemetryState(didUserOptIn) + this.clients.forEach((client) => client.updateTelemetryState(didUserOptIn)) } /** - * Captures a telemetry event if telemetry is enabled - * @param event The event to capture with its properties + * Generic method to capture any type of event with specified properties + * @param eventName The event name to capture + * @param properties The event properties */ - public capture(event: { event: string; properties?: any }): void { + public captureEvent(eventName: TelemetryEventName, properties?: any): void { if (!this.isReady) { return } - this.client!.capture(event) + this.clients.forEach((client) => client.capture({ event: eventName, properties })) } - /** - * Generic method to capture any type of event with specified properties - * @param eventName The event name to capture - * @param properties The event properties - */ - public captureEvent(eventName: string, properties?: any): void { - this.capture({ event: eventName, properties }) - } - - // Task events convenience methods public captureTaskCreated(taskId: string): void { - this.captureEvent(PostHogClient.EVENTS.TASK.CREATED, { taskId }) + this.captureEvent(TelemetryEventName.TASK_CREATED, { taskId }) } public captureTaskRestarted(taskId: string): void { - this.captureEvent(PostHogClient.EVENTS.TASK.RESTARTED, { taskId }) + this.captureEvent(TelemetryEventName.TASK_RESTARTED, { taskId }) } public captureTaskCompleted(taskId: string): void { - this.captureEvent(PostHogClient.EVENTS.TASK.COMPLETED, { taskId }) + this.captureEvent(TelemetryEventName.TASK_COMPLETED, { taskId }) } public captureConversationMessage(taskId: string, source: "user" | "assistant"): void { - this.captureEvent(PostHogClient.EVENTS.TASK.CONVERSATION_MESSAGE, { taskId, source }) + this.captureEvent(TelemetryEventName.TASK_CONVERSATION_MESSAGE, { taskId, source }) + } + + public captureLlmCompletion( + taskId: string, + properties: { + inputTokens: number + outputTokens: number + cacheWriteTokens: number + cacheReadTokens: number + cost?: number + }, + ): void { + this.captureEvent(TelemetryEventName.LLM_COMPLETION, { taskId, ...properties }) } public captureModeSwitch(taskId: string, newMode: string): void { - this.captureEvent(PostHogClient.EVENTS.TASK.MODE_SWITCH, { taskId, newMode }) + this.captureEvent(TelemetryEventName.MODE_SWITCH, { taskId, newMode }) } public captureToolUsage(taskId: string, tool: string): void { - this.captureEvent(PostHogClient.EVENTS.TASK.TOOL_USED, { taskId, tool }) + this.captureEvent(TelemetryEventName.TOOL_USED, { taskId, tool }) } public captureCheckpointCreated(taskId: string): void { - this.captureEvent(PostHogClient.EVENTS.TASK.CHECKPOINT_CREATED, { taskId }) + this.captureEvent(TelemetryEventName.CHECKPOINT_CREATED, { taskId }) } public captureCheckpointDiffed(taskId: string): void { - this.captureEvent(PostHogClient.EVENTS.TASK.CHECKPOINT_DIFFED, { taskId }) + this.captureEvent(TelemetryEventName.CHECKPOINT_DIFFED, { taskId }) } public captureCheckpointRestored(taskId: string): void { - this.captureEvent(PostHogClient.EVENTS.TASK.CHECKPOINT_RESTORED, { taskId }) + this.captureEvent(TelemetryEventName.CHECKPOINT_RESTORED, { taskId }) } public captureContextCondensed( @@ -126,7 +137,7 @@ class TelemetryService { usedCustomPrompt?: boolean, usedCustomApiHandler?: boolean, ): void { - this.captureEvent(PostHogClient.EVENTS.TASK.CONTEXT_CONDENSED, { + this.captureEvent(TelemetryEventName.CONTEXT_CONDENSED, { taskId, isAutomaticTrigger, ...(usedCustomPrompt !== undefined && { usedCustomPrompt }), @@ -135,32 +146,32 @@ class TelemetryService { } public captureSlidingWindowTruncation(taskId: string): void { - this.captureEvent(PostHogClient.EVENTS.TASK.SLIDING_WINDOW_TRUNCATION, { taskId }) + this.captureEvent(TelemetryEventName.SLIDING_WINDOW_TRUNCATION, { taskId }) } public captureCodeActionUsed(actionType: string): void { - this.captureEvent(PostHogClient.EVENTS.TASK.CODE_ACTION_USED, { actionType }) + this.captureEvent(TelemetryEventName.CODE_ACTION_USED, { actionType }) } public capturePromptEnhanced(taskId?: string): void { - this.captureEvent(PostHogClient.EVENTS.TASK.PROMPT_ENHANCED, { ...(taskId && { taskId }) }) + this.captureEvent(TelemetryEventName.PROMPT_ENHANCED, { ...(taskId && { taskId }) }) } public captureSchemaValidationError({ schemaName, error }: { schemaName: string; error: ZodError }): void { // https://zod.dev/ERROR_HANDLING?id=formatting-errors - this.captureEvent(PostHogClient.EVENTS.ERRORS.SCHEMA_VALIDATION_ERROR, { schemaName, error: error.format() }) + this.captureEvent(TelemetryEventName.SCHEMA_VALIDATION_ERROR, { schemaName, error: error.format() }) } public captureDiffApplicationError(taskId: string, consecutiveMistakeCount: number): void { - this.captureEvent(PostHogClient.EVENTS.ERRORS.DIFF_APPLICATION_ERROR, { taskId, consecutiveMistakeCount }) + this.captureEvent(TelemetryEventName.DIFF_APPLICATION_ERROR, { taskId, consecutiveMistakeCount }) } public captureShellIntegrationError(taskId: string): void { - this.captureEvent(PostHogClient.EVENTS.ERRORS.SHELL_INTEGRATION_ERROR, { taskId }) + this.captureEvent(TelemetryEventName.SHELL_INTEGRATION_ERROR, { taskId }) } public captureConsecutiveMistakeError(taskId: string): void { - this.captureEvent(PostHogClient.EVENTS.ERRORS.CONSECUTIVE_MISTAKE_ERROR, { taskId }) + this.captureEvent(TelemetryEventName.CONSECUTIVE_MISTAKE_ERROR, { taskId }) } /** @@ -168,7 +179,7 @@ class TelemetryService { * @param button The button that was clicked */ public captureTitleButtonClicked(button: string): void { - this.captureEvent("Title Button Clicked", { button }) + this.captureEvent(TelemetryEventName.TITLE_BUTTON_CLICKED, { button }) } /** @@ -176,20 +187,16 @@ class TelemetryService { * @returns Whether telemetry is enabled */ public isTelemetryEnabled(): boolean { - return this.isReady && this.client!.isTelemetryEnabled() + return this.isReady && this.clients.some((client) => client.isTelemetryEnabled()) } - /** - * Shuts down the PostHog client - */ public async shutdown(): Promise { if (!this.isReady) { return } - await this.client!.shutdown() + this.clients.forEach((client) => client.shutdown()) } } -// Export a singleton instance of the telemetry service wrapper export const telemetryService = new TelemetryService() diff --git a/src/services/telemetry/clients/BaseTelemetryClient.ts b/src/services/telemetry/clients/BaseTelemetryClient.ts new file mode 100644 index 0000000000..24a486a2ea --- /dev/null +++ b/src/services/telemetry/clients/BaseTelemetryClient.ts @@ -0,0 +1,58 @@ +import { TelemetryEvent, TelemetryEventName } from "@roo-code/types" + +import { TelemetryClient, TelemetryPropertiesProvider, TelemetryEventSubscription } from "../types" + +export abstract class BaseTelemetryClient implements TelemetryClient { + protected providerRef: WeakRef | null = null + protected telemetryEnabled: boolean = false + + constructor( + public readonly subscription?: TelemetryEventSubscription, + protected readonly debug = false, + ) {} + + protected isEventCapturable(eventName: TelemetryEventName): boolean { + if (!this.subscription) { + return true + } + + return this.subscription.type === "include" + ? this.subscription.events.includes(eventName) + : !this.subscription.events.includes(eventName) + } + + protected async getEventProperties(event: TelemetryEvent): Promise { + let providerProperties: TelemetryEvent["properties"] = {} + const provider = this.providerRef?.deref() + + if (provider) { + try { + // Get the telemetry properties directly from the provider. + providerProperties = await provider.getTelemetryProperties() + } catch (error) { + // Log error but continue with capturing the event. + console.error( + `Error getting telemetry properties: ${error instanceof Error ? error.message : String(error)}`, + ) + } + } + + // Merge provider properties with event-specific properties. + // Event properties take precedence in case of conflicts. + return { ...providerProperties, ...(event.properties || {}) } + } + + public abstract capture(event: TelemetryEvent): Promise + + public setProvider(provider: TelemetryPropertiesProvider): void { + this.providerRef = new WeakRef(provider) + } + + public abstract updateTelemetryState(didUserOptIn: boolean): void + + public isTelemetryEnabled(): boolean { + return this.telemetryEnabled + } + + public abstract shutdown(): Promise +} diff --git a/src/services/telemetry/clients/PostHogTelemetryClient.ts b/src/services/telemetry/clients/PostHogTelemetryClient.ts new file mode 100644 index 0000000000..b554d962e3 --- /dev/null +++ b/src/services/telemetry/clients/PostHogTelemetryClient.ts @@ -0,0 +1,88 @@ +import { PostHog } from "posthog-node" +import * as vscode from "vscode" + +import { TelemetryEventName, type TelemetryEvent } from "@roo-code/types" + +import { BaseTelemetryClient } from "./BaseTelemetryClient" + +/** + * PostHogTelemetryClient handles telemetry event tracking for the Roo Code extension. + * Uses PostHog analytics to track user interactions and system events. + * Respects user privacy settings and VSCode's global telemetry configuration. + */ +export class PostHogTelemetryClient extends BaseTelemetryClient { + private client: PostHog + private distinctId: string = vscode.env.machineId + + private constructor(debug = false) { + super( + { + type: "exclude", + events: [TelemetryEventName.LLM_COMPLETION], + }, + debug, + ) + + this.client = new PostHog(process.env.POSTHOG_API_KEY || "", { host: "https://us.i.posthog.com" }) + } + + public override async capture(event: TelemetryEvent): Promise { + if (!this.isTelemetryEnabled() || !this.isEventCapturable(event.event)) { + if (this.debug) { + console.info(`[PostHogTelemetryClient#capture] Skipping event: ${event.event}`) + } + + return + } + + if (this.debug) { + console.info(`[PostHogTelemetryClient#capture] ${event.event}`) + } + + this.client.capture({ + distinctId: this.distinctId, + event: event.event, + properties: await this.getEventProperties(event), + }) + } + + /** + * Updates the telemetry state based on user preferences and VSCode settings. + * Only enables telemetry if both VSCode global telemetry is enabled and + * user has opted in. + * @param didUserOptIn Whether the user has explicitly opted into telemetry + */ + public override updateTelemetryState(didUserOptIn: boolean): void { + this.telemetryEnabled = false + + // First check global telemetry level - telemetry should only be enabled when level is "all". + const telemetryLevel = vscode.workspace.getConfiguration("telemetry").get("telemetryLevel", "all") + const globalTelemetryEnabled = telemetryLevel === "all" + + // We only enable telemetry if global vscode telemetry is enabled. + if (globalTelemetryEnabled) { + this.telemetryEnabled = didUserOptIn + } + + // Update PostHog client state based on telemetry preference. + if (this.telemetryEnabled) { + this.client.optIn() + } else { + this.client.optOut() + } + } + + public override async shutdown(): Promise { + await this.client.shutdown() + } + + private static _instance: PostHogTelemetryClient | null = null + + public static getInstance(): PostHogTelemetryClient { + if (!PostHogTelemetryClient._instance) { + PostHogTelemetryClient._instance = new PostHogTelemetryClient() + } + + return PostHogTelemetryClient._instance + } +} diff --git a/src/services/telemetry/clients/__tests__/PostHogTelemetryClient.test.ts b/src/services/telemetry/clients/__tests__/PostHogTelemetryClient.test.ts new file mode 100644 index 0000000000..89bb16e81a --- /dev/null +++ b/src/services/telemetry/clients/__tests__/PostHogTelemetryClient.test.ts @@ -0,0 +1,270 @@ +// npx jest src/services/telemetry/clients/__tests__/PostHogTelemetryClient.test.ts + +import * as vscode from "vscode" +import { PostHog } from "posthog-node" + +import { TelemetryEventName } from "@roo-code/types" + +import { TelemetryPropertiesProvider } from "../../types" +import { PostHogTelemetryClient } from "../PostHogTelemetryClient" + +jest.mock("posthog-node") + +jest.mock("vscode", () => ({ + env: { + machineId: "test-machine-id", + }, + workspace: { + getConfiguration: jest.fn(), + }, +})) + +describe("PostHogTelemetryClient", () => { + const getPrivateProperty = (instance: any, propertyName: string): T => { + return instance[propertyName] + } + + let mockPostHogClient: jest.Mocked + + beforeEach(() => { + jest.clearAllMocks() + + mockPostHogClient = { + capture: jest.fn(), + optIn: jest.fn(), + optOut: jest.fn(), + shutdown: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked + ;(PostHog as unknown as jest.Mock).mockImplementation(() => mockPostHogClient) + + // @ts-ignore - Accessing private static property for testing + PostHogTelemetryClient._instance = undefined + ;(vscode.workspace.getConfiguration as jest.Mock).mockReturnValue({ + get: jest.fn().mockReturnValue("all"), + }) + }) + + describe("getInstance", () => { + it("should return the same instance when called multiple times", () => { + const instance1 = PostHogTelemetryClient.getInstance() + const instance2 = PostHogTelemetryClient.getInstance() + expect(instance1).toBe(instance2) + }) + }) + + describe("isEventCapturable", () => { + it("should return true for events not in exclude list", () => { + const client = PostHogTelemetryClient.getInstance() + + const isEventCapturable = getPrivateProperty<(eventName: TelemetryEventName) => boolean>( + client, + "isEventCapturable", + ).bind(client) + + expect(isEventCapturable(TelemetryEventName.TASK_CREATED)).toBe(true) + expect(isEventCapturable(TelemetryEventName.MODE_SWITCH)).toBe(true) + }) + + it("should return false for events in exclude list", () => { + const client = PostHogTelemetryClient.getInstance() + + const isEventCapturable = getPrivateProperty<(eventName: TelemetryEventName) => boolean>( + client, + "isEventCapturable", + ).bind(client) + + expect(isEventCapturable(TelemetryEventName.LLM_COMPLETION)).toBe(false) + }) + }) + + describe("getEventProperties", () => { + it("should merge provider properties with event properties", async () => { + const client = PostHogTelemetryClient.getInstance() + + const mockProvider: TelemetryPropertiesProvider = { + getTelemetryProperties: jest.fn().mockResolvedValue({ + appVersion: "1.0.0", + vscodeVersion: "1.60.0", + platform: "darwin", + editorName: "vscode", + language: "en", + mode: "code", + }), + } + + client.setProvider(mockProvider) + + const getEventProperties = getPrivateProperty< + (event: { event: TelemetryEventName; properties?: Record }) => Promise> + >(client, "getEventProperties").bind(client) + + const result = await getEventProperties({ + event: TelemetryEventName.TASK_CREATED, + properties: { + customProp: "value", + mode: "override", // This should override the provider's mode. + }, + }) + + expect(result).toEqual({ + appVersion: "1.0.0", + vscodeVersion: "1.60.0", + platform: "darwin", + editorName: "vscode", + language: "en", + mode: "override", // Event property takes precedence. + customProp: "value", + }) + + expect(mockProvider.getTelemetryProperties).toHaveBeenCalledTimes(1) + }) + + it("should handle errors from provider gracefully", async () => { + const client = PostHogTelemetryClient.getInstance() + + const mockProvider: TelemetryPropertiesProvider = { + getTelemetryProperties: jest.fn().mockRejectedValue(new Error("Provider error")), + } + + const consoleErrorSpy = jest.spyOn(console, "error").mockImplementation() + client.setProvider(mockProvider) + + const getEventProperties = getPrivateProperty< + (event: { event: TelemetryEventName; properties?: Record }) => Promise> + >(client, "getEventProperties").bind(client) + + const result = await getEventProperties({ + event: TelemetryEventName.TASK_CREATED, + properties: { customProp: "value" }, + }) + + expect(result).toEqual({ customProp: "value" }) + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringContaining("Error getting telemetry properties: Provider error"), + ) + + consoleErrorSpy.mockRestore() + }) + + it("should return event properties when no provider is set", async () => { + const client = PostHogTelemetryClient.getInstance() + + const getEventProperties = getPrivateProperty< + (event: { event: TelemetryEventName; properties?: Record }) => Promise> + >(client, "getEventProperties").bind(client) + + const result = await getEventProperties({ + event: TelemetryEventName.TASK_CREATED, + properties: { customProp: "value" }, + }) + + expect(result).toEqual({ customProp: "value" }) + }) + }) + + describe("capture", () => { + it("should not capture events when telemetry is disabled", async () => { + const client = PostHogTelemetryClient.getInstance() + client.updateTelemetryState(false) + + await client.capture({ + event: TelemetryEventName.TASK_CREATED, + properties: { test: "value" }, + }) + + expect(mockPostHogClient.capture).not.toHaveBeenCalled() + }) + + it("should not capture events that are not capturable", async () => { + const client = PostHogTelemetryClient.getInstance() + client.updateTelemetryState(true) + + await client.capture({ + event: TelemetryEventName.LLM_COMPLETION, // This is in the exclude list. + properties: { test: "value" }, + }) + + expect(mockPostHogClient.capture).not.toHaveBeenCalled() + }) + + it("should capture events when telemetry is enabled and event is capturable", async () => { + const client = PostHogTelemetryClient.getInstance() + client.updateTelemetryState(true) + + const mockProvider: TelemetryPropertiesProvider = { + getTelemetryProperties: jest.fn().mockResolvedValue({ + appVersion: "1.0.0", + vscodeVersion: "1.60.0", + platform: "darwin", + editorName: "vscode", + language: "en", + mode: "code", + }), + } + + client.setProvider(mockProvider) + + await client.capture({ + event: TelemetryEventName.TASK_CREATED, + properties: { test: "value" }, + }) + + expect(mockPostHogClient.capture).toHaveBeenCalledWith({ + distinctId: "test-machine-id", + event: TelemetryEventName.TASK_CREATED, + properties: expect.objectContaining({ + appVersion: "1.0.0", + test: "value", + }), + }) + }) + }) + + describe("updateTelemetryState", () => { + it("should enable telemetry when user opts in and global telemetry is enabled", () => { + const client = PostHogTelemetryClient.getInstance() + + ;(vscode.workspace.getConfiguration as jest.Mock).mockReturnValue({ + get: jest.fn().mockReturnValue("all"), + }) + + client.updateTelemetryState(true) + + expect(client.isTelemetryEnabled()).toBe(true) + expect(mockPostHogClient.optIn).toHaveBeenCalled() + }) + + it("should disable telemetry when user opts out", () => { + const client = PostHogTelemetryClient.getInstance() + + ;(vscode.workspace.getConfiguration as jest.Mock).mockReturnValue({ + get: jest.fn().mockReturnValue("all"), + }) + + client.updateTelemetryState(false) + + expect(client.isTelemetryEnabled()).toBe(false) + expect(mockPostHogClient.optOut).toHaveBeenCalled() + }) + + it("should disable telemetry when global telemetry is disabled, regardless of user opt-in", () => { + const client = PostHogTelemetryClient.getInstance() + + ;(vscode.workspace.getConfiguration as jest.Mock).mockReturnValue({ + get: jest.fn().mockReturnValue("off"), + }) + + client.updateTelemetryState(true) + expect(client.isTelemetryEnabled()).toBe(false) + expect(mockPostHogClient.optOut).toHaveBeenCalled() + }) + }) + + describe("shutdown", () => { + it("should call shutdown on the PostHog client", async () => { + const client = PostHogTelemetryClient.getInstance() + await client.shutdown() + expect(mockPostHogClient.shutdown).toHaveBeenCalled() + }) + }) +}) diff --git a/src/services/telemetry/index.ts b/src/services/telemetry/index.ts new file mode 100644 index 0000000000..6700a8085e --- /dev/null +++ b/src/services/telemetry/index.ts @@ -0,0 +1,2 @@ +export * from "./TelemetryService" +export * from "./types" diff --git a/src/services/telemetry/types.ts b/src/services/telemetry/types.ts new file mode 100644 index 0000000000..b6fb038484 --- /dev/null +++ b/src/services/telemetry/types.ts @@ -0,0 +1,19 @@ +import { TelemetryEventName, type TelemetryProperties, type TelemetryEvent } from "@roo-code/types" + +export type TelemetryEventSubscription = + | { type: "include"; events: TelemetryEventName[] } + | { type: "exclude"; events: TelemetryEventName[] } + +export interface TelemetryPropertiesProvider { + getTelemetryProperties(): Promise +} + +export interface TelemetryClient { + subscription?: TelemetryEventSubscription + + setProvider(provider: TelemetryPropertiesProvider): void + capture(options: TelemetryEvent): Promise + updateTelemetryState(didUserOptIn: boolean): void + isTelemetryEnabled(): boolean + shutdown(): Promise +} diff --git a/src/utils/__tests__/refresh-timer.test.ts b/src/utils/__tests__/refresh-timer.test.ts new file mode 100644 index 0000000000..11911494f6 --- /dev/null +++ b/src/utils/__tests__/refresh-timer.test.ts @@ -0,0 +1,210 @@ +import { RefreshTimer } from "../refresh-timer" + +// Mock timers +jest.useFakeTimers() + +describe("RefreshTimer", () => { + let mockCallback: jest.Mock + let refreshTimer: RefreshTimer + + beforeEach(() => { + // Reset mocks before each test + mockCallback = jest.fn() + + // Default mock implementation returns success + mockCallback.mockResolvedValue(true) + }) + + afterEach(() => { + // Clean up after each test + if (refreshTimer) { + refreshTimer.stop() + } + jest.clearAllTimers() + jest.clearAllMocks() + }) + + it("should execute callback immediately when started", () => { + refreshTimer = new RefreshTimer({ + callback: mockCallback, + }) + + refreshTimer.start() + + expect(mockCallback).toHaveBeenCalledTimes(1) + }) + + it("should schedule next attempt after success interval when callback succeeds", async () => { + mockCallback.mockResolvedValue(true) + + refreshTimer = new RefreshTimer({ + callback: mockCallback, + successInterval: 50000, // 50 seconds + }) + + refreshTimer.start() + + // Fast-forward to execute the first callback + await Promise.resolve() + + expect(mockCallback).toHaveBeenCalledTimes(1) + + // Fast-forward 50 seconds + jest.advanceTimersByTime(50000) + + // Callback should be called again + expect(mockCallback).toHaveBeenCalledTimes(2) + }) + + it("should use exponential backoff when callback fails", async () => { + mockCallback.mockResolvedValue(false) + + refreshTimer = new RefreshTimer({ + callback: mockCallback, + initialBackoffMs: 1000, // 1 second + }) + + refreshTimer.start() + + // Fast-forward to execute the first callback + await Promise.resolve() + + expect(mockCallback).toHaveBeenCalledTimes(1) + + // Fast-forward 1 second + jest.advanceTimersByTime(1000) + + // Callback should be called again + expect(mockCallback).toHaveBeenCalledTimes(2) + + // Fast-forward to execute the second callback + await Promise.resolve() + + // Fast-forward 2 seconds + jest.advanceTimersByTime(2000) + + // Callback should be called again + expect(mockCallback).toHaveBeenCalledTimes(3) + + // Fast-forward to execute the third callback + await Promise.resolve() + }) + + it("should not exceed maximum backoff interval", async () => { + mockCallback.mockResolvedValue(false) + + refreshTimer = new RefreshTimer({ + callback: mockCallback, + initialBackoffMs: 1000, // 1 second + maxBackoffMs: 5000, // 5 seconds + }) + + refreshTimer.start() + + // Fast-forward through multiple failures to reach max backoff + await Promise.resolve() // First attempt + jest.advanceTimersByTime(1000) + + await Promise.resolve() // Second attempt (backoff = 2000ms) + jest.advanceTimersByTime(2000) + + await Promise.resolve() // Third attempt (backoff = 4000ms) + jest.advanceTimersByTime(4000) + + await Promise.resolve() // Fourth attempt (backoff would be 8000ms but max is 5000ms) + + // Should be capped at maxBackoffMs (no way to verify without logger) + }) + + it("should reset backoff after a successful attempt", async () => { + // First call fails, second succeeds, third fails + mockCallback.mockResolvedValueOnce(false).mockResolvedValueOnce(true).mockResolvedValueOnce(false) + + refreshTimer = new RefreshTimer({ + callback: mockCallback, + initialBackoffMs: 1000, + successInterval: 5000, + }) + + refreshTimer.start() + + // First attempt (fails) + await Promise.resolve() + + // Fast-forward 1 second + jest.advanceTimersByTime(1000) + + // Second attempt (succeeds) + await Promise.resolve() + + // Fast-forward 5 seconds + jest.advanceTimersByTime(5000) + + // Third attempt (fails) + await Promise.resolve() + + // Backoff should be reset to initial value (no way to verify without logger) + }) + + it("should handle errors in callback as failures", async () => { + mockCallback.mockRejectedValue(new Error("Test error")) + + refreshTimer = new RefreshTimer({ + callback: mockCallback, + initialBackoffMs: 1000, + }) + + refreshTimer.start() + + // Fast-forward to execute the callback + await Promise.resolve() + + // Error should be treated as a failure (no way to verify without logger) + }) + + it("should stop the timer and cancel pending executions", () => { + refreshTimer = new RefreshTimer({ + callback: mockCallback, + }) + + refreshTimer.start() + + // Stop the timer + refreshTimer.stop() + + // Fast-forward a long time + jest.advanceTimersByTime(1000000) + + // Callback should only have been called once (the initial call) + expect(mockCallback).toHaveBeenCalledTimes(1) + }) + + it("should reset the backoff state", async () => { + mockCallback.mockResolvedValue(false) + + refreshTimer = new RefreshTimer({ + callback: mockCallback, + initialBackoffMs: 1000, + }) + + refreshTimer.start() + + // Fast-forward through a few failures + await Promise.resolve() + jest.advanceTimersByTime(1000) + + await Promise.resolve() + jest.advanceTimersByTime(2000) + + // Reset the timer + refreshTimer.reset() + + // Stop and restart to trigger a new execution + refreshTimer.stop() + refreshTimer.start() + + await Promise.resolve() + + // Backoff should be back to initial value (no way to verify without logger) + }) +}) diff --git a/src/utils/refresh-timer.ts b/src/utils/refresh-timer.ts new file mode 100644 index 0000000000..3138031665 --- /dev/null +++ b/src/utils/refresh-timer.ts @@ -0,0 +1,154 @@ +/** + * RefreshTimer - A utility for executing a callback with configurable retry behavior + * + * This timer executes a callback function and schedules the next execution based on the result: + * - If the callback succeeds (returns true), it schedules the next attempt after a fixed interval + * - If the callback fails (returns false), it uses exponential backoff up to a maximum interval + */ + +/** + * Configuration options for the RefreshTimer + */ +export interface RefreshTimerOptions { + /** + * The callback function to execute + * Should return a Promise that resolves to a boolean indicating success (true) or failure (false) + */ + callback: () => Promise + + /** + * Time in milliseconds to wait before next attempt after success + * @default 50000 (50 seconds) + */ + successInterval?: number + + /** + * Initial backoff time in milliseconds for the first failure + * @default 1000 (1 second) + */ + initialBackoffMs?: number + + /** + * Maximum backoff time in milliseconds + * @default 300000 (5 minutes) + */ + maxBackoffMs?: number +} + +/** + * A timer utility that executes a callback with configurable retry behavior + */ +export class RefreshTimer { + private callback: () => Promise + private successInterval: number + private initialBackoffMs: number + private maxBackoffMs: number + private currentBackoffMs: number + private attemptCount: number + private timerId: NodeJS.Timeout | null + private isRunning: boolean + + /** + * Creates a new RefreshTimer + * + * @param options Configuration options for the timer + */ + constructor(options: RefreshTimerOptions) { + this.callback = options.callback + this.successInterval = options.successInterval ?? 50000 // 50 seconds + this.initialBackoffMs = options.initialBackoffMs ?? 1000 // 1 second + this.maxBackoffMs = options.maxBackoffMs ?? 300000 // 5 minutes + this.currentBackoffMs = this.initialBackoffMs + this.attemptCount = 0 + this.timerId = null + this.isRunning = false + } + + /** + * Starts the timer and executes the callback immediately + */ + public start(): void { + if (this.isRunning) { + return + } + + this.isRunning = true + + // Execute the callback immediately + this.executeCallback() + } + + /** + * Stops the timer and cancels any pending execution + */ + public stop(): void { + if (!this.isRunning) { + return + } + + if (this.timerId) { + clearTimeout(this.timerId) + this.timerId = null + } + + this.isRunning = false + } + + /** + * Resets the backoff state and attempt count + * Does not affect whether the timer is running + */ + public reset(): void { + this.currentBackoffMs = this.initialBackoffMs + this.attemptCount = 0 + } + + /** + * Schedules the next attempt based on the success/failure of the current attempt + * + * @param wasSuccessful Whether the current attempt was successful + */ + private scheduleNextAttempt(wasSuccessful: boolean): void { + if (!this.isRunning) { + return + } + + if (wasSuccessful) { + // Reset backoff on success + this.currentBackoffMs = this.initialBackoffMs + this.attemptCount = 0 + + this.timerId = setTimeout(() => this.executeCallback(), this.successInterval) + } else { + // Increment attempt count + this.attemptCount++ + + // Calculate backoff time with exponential increase + // Formula: initialBackoff * 2^(attemptCount - 1) + this.currentBackoffMs = Math.min( + this.initialBackoffMs * Math.pow(2, this.attemptCount - 1), + this.maxBackoffMs, + ) + + this.timerId = setTimeout(() => this.executeCallback(), this.currentBackoffMs) + } + } + + /** + * Executes the callback and handles the result + */ + private async executeCallback(): Promise { + if (!this.isRunning) { + return + } + + try { + const result = await this.callback() + + this.scheduleNextAttempt(result) + } catch (error) { + // Treat errors as failed attempts + this.scheduleNextAttempt(false) + } + } +}