Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions packages/cloud/src/bridge/BridgeOrchestrator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,28 @@ export class BridgeOrchestrator {
return BridgeOrchestrator.instance
}

public static isEnabled(user?: CloudUserInfo | null, remoteControlEnabled?: boolean): boolean {
return !!(user?.id && user.extensionBridgeEnabled && remoteControlEnabled)
public static isEnabled(user: CloudUserInfo | null, remoteControlEnabled: boolean): boolean {
// Always disabled if signed out.
if (!user) {
return false
}

// Disabled by the user's organization?
if (!user.extensionBridgeEnabled) {
return false
}

// Disabled by the user?
if (!remoteControlEnabled) {
return false
}

return true
}

public static async connectOrDisconnect(
userInfo: CloudUserInfo | null,
remoteControlEnabled: boolean | undefined,
userInfo: CloudUserInfo,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intentional? The connectOrDisconnect() method now requires a non-null userInfo parameter, but it's being called from ClineProvider.remoteControlEnabled() where userInfo could be null. This type mismatch could lead to runtime errors.

Consider either:

  1. Making userInfo nullable here: userInfo: CloudUserInfo | null
  2. Or ensuring all callers check for null before calling this method

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no type mismatches; the build would fail otherwise.

remoteControlEnabled: boolean,
options: BridgeOrchestratorOptions,
): Promise<void> {
if (BridgeOrchestrator.isEnabled(userInfo, remoteControlEnabled)) {
Expand Down
257 changes: 257 additions & 0 deletions src/__tests__/extension.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
// npx vitest run __tests__/extension.spec.ts

import type * as vscode from "vscode"
import type { AuthState } from "@roo-code/types"

vi.mock("vscode", () => ({
window: {
createOutputChannel: vi.fn().mockReturnValue({
appendLine: vi.fn(),
}),
registerWebviewViewProvider: vi.fn(),
registerUriHandler: vi.fn(),
tabGroups: {
onDidChangeTabs: vi.fn(),
},
onDidChangeActiveTextEditor: vi.fn(),
},
workspace: {
registerTextDocumentContentProvider: vi.fn(),
getConfiguration: vi.fn().mockReturnValue({
get: vi.fn().mockReturnValue([]),
}),
createFileSystemWatcher: vi.fn().mockReturnValue({
onDidCreate: vi.fn(),
onDidChange: vi.fn(),
onDidDelete: vi.fn(),
dispose: vi.fn(),
}),
onDidChangeWorkspaceFolders: vi.fn(),
},
languages: {
registerCodeActionsProvider: vi.fn(),
},
commands: {
executeCommand: vi.fn(),
},
env: {
language: "en",
},
ExtensionMode: {
Production: 1,
},
}))

vi.mock("@dotenvx/dotenvx", () => ({
config: vi.fn(),
}))

const mockBridgeOrchestratorDisconnect = vi.fn().mockResolvedValue(undefined)

vi.mock("@roo-code/cloud", () => ({
CloudService: {
createInstance: vi.fn(),
hasInstance: vi.fn().mockReturnValue(true),
get instance() {
return {
off: vi.fn(),
on: vi.fn(),
getUserInfo: vi.fn().mockReturnValue(null),
isTaskSyncEnabled: vi.fn().mockReturnValue(false),
}
},
},
BridgeOrchestrator: {
disconnect: mockBridgeOrchestratorDisconnect,
},
getRooCodeApiUrl: vi.fn().mockReturnValue("https://app.roocode.com"),
}))

vi.mock("@roo-code/telemetry", () => ({
TelemetryService: {
createInstance: vi.fn().mockReturnValue({
register: vi.fn(),
setProvider: vi.fn(),
shutdown: vi.fn(),
}),
get instance() {
return {
register: vi.fn(),
setProvider: vi.fn(),
shutdown: vi.fn(),
}
},
},
PostHogTelemetryClient: vi.fn(),
}))

vi.mock("../utils/outputChannelLogger", () => ({
createOutputChannelLogger: vi.fn().mockReturnValue(vi.fn()),
createDualLogger: vi.fn().mockReturnValue(vi.fn()),
}))

vi.mock("../shared/package", () => ({
Package: {
name: "test-extension",
outputChannel: "Test Output",
version: "1.0.0",
},
}))

vi.mock("../shared/language", () => ({
formatLanguage: vi.fn().mockReturnValue("en"),
}))

vi.mock("../core/config/ContextProxy", () => ({
ContextProxy: {
getInstance: vi.fn().mockResolvedValue({
getValue: vi.fn(),
setValue: vi.fn(),
getValues: vi.fn().mockReturnValue({}),
getProviderSettings: vi.fn().mockReturnValue({}),
}),
},
}))

vi.mock("../integrations/editor/DiffViewProvider", () => ({
DIFF_VIEW_URI_SCHEME: "test-diff-scheme",
}))

vi.mock("../integrations/terminal/TerminalRegistry", () => ({
TerminalRegistry: {
initialize: vi.fn(),
cleanup: vi.fn(),
},
}))

vi.mock("../services/mcp/McpServerManager", () => ({
McpServerManager: {
cleanup: vi.fn().mockResolvedValue(undefined),
getInstance: vi.fn().mockResolvedValue(null),
unregisterProvider: vi.fn(),
},
}))

vi.mock("../services/code-index/manager", () => ({
CodeIndexManager: {
getInstance: vi.fn().mockReturnValue(null),
},
}))

vi.mock("../services/mdm/MdmService", () => ({
MdmService: {
createInstance: vi.fn().mockResolvedValue(null),
},
}))

vi.mock("../utils/migrateSettings", () => ({
migrateSettings: vi.fn().mockResolvedValue(undefined),
}))

vi.mock("../utils/autoImportSettings", () => ({
autoImportSettings: vi.fn().mockResolvedValue(undefined),
}))

vi.mock("../extension/api", () => ({
API: vi.fn().mockImplementation(() => ({})),
}))

vi.mock("../activate", () => ({
handleUri: vi.fn(),
registerCommands: vi.fn(),
registerCodeActions: vi.fn(),
registerTerminalActions: vi.fn(),
CodeActionProvider: vi.fn().mockImplementation(() => ({
providedCodeActionKinds: [],
})),
}))

vi.mock("../i18n", () => ({
initializeI18n: vi.fn(),
}))

describe("extension.ts", () => {
let mockContext: vscode.ExtensionContext
let authStateChangedHandler:
| ((data: { state: AuthState; previousState: AuthState }) => void | Promise<void>)
| undefined

beforeEach(() => {
vi.clearAllMocks()
mockBridgeOrchestratorDisconnect.mockClear()

mockContext = {
extensionPath: "/test/path",
globalState: {
get: vi.fn().mockReturnValue(undefined),
update: vi.fn(),
},
subscriptions: [],
} as unknown as vscode.ExtensionContext

authStateChangedHandler = undefined
})

test("authStateChangedHandler calls BridgeOrchestrator.disconnect when logged-out event fires", async () => {
const { CloudService, BridgeOrchestrator } = await import("@roo-code/cloud")

// Capture the auth state changed handler.
vi.mocked(CloudService.createInstance).mockImplementation(async (_context, _logger, handlers) => {
if (handlers?.["auth-state-changed"]) {
authStateChangedHandler = handlers["auth-state-changed"]
}

return {
off: vi.fn(),
on: vi.fn(),
telemetryClient: null,
} as any
})

// Activate the extension.
const { activate } = await import("../extension")
await activate(mockContext)

// Verify handler was registered.
expect(authStateChangedHandler).toBeDefined()

// Trigger logout.
await authStateChangedHandler!({
state: "logged-out" as AuthState,
previousState: "logged-in" as AuthState,
})

// Verify BridgeOrchestrator.disconnect was called
expect(mockBridgeOrchestratorDisconnect).toHaveBeenCalled()
})

test("authStateChangedHandler does not call BridgeOrchestrator.disconnect for other states", async () => {
const { CloudService } = await import("@roo-code/cloud")

// Capture the auth state changed handler.
vi.mocked(CloudService.createInstance).mockImplementation(async (_context, _logger, handlers) => {
if (handlers?.["auth-state-changed"]) {
authStateChangedHandler = handlers["auth-state-changed"]
}

return {
off: vi.fn(),
on: vi.fn(),
telemetryClient: null,
} as any
})

// Activate the extension.
const { activate } = await import("../extension")
await activate(mockContext)

// Trigger login.
await authStateChangedHandler!({
state: "logged-in" as AuthState,
previousState: "logged-out" as AuthState,
})

// Verify BridgeOrchestrator.disconnect was NOT called.
expect(mockBridgeOrchestratorDisconnect).not.toHaveBeenCalled()
})
})
12 changes: 12 additions & 0 deletions src/core/webview/ClineProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2262,7 +2262,19 @@ export class ClineProvider
}

public async remoteControlEnabled(enabled: boolean) {
if (!enabled) {
await BridgeOrchestrator.disconnect()
return
}

const userInfo = CloudService.instance.getUserInfo()

if (!userInfo) {
this.log("[ClineProvider#remoteControlEnabled] Failed to get user info, disconnecting")
await BridgeOrchestrator.disconnect()
return
}

const config = await CloudService.instance.cloudAPI?.bridgeConfig().catch(() => undefined)

if (!config) {
Expand Down
9 changes: 4 additions & 5 deletions src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,9 @@ export async function activate(context: vscode.ExtensionContext) {
if (data.state === "logged-out") {
try {
await provider.remoteControlEnabled(false)
cloudLogger("[CloudService] BridgeOrchestrator disconnected on logout")
} catch (error) {
cloudLogger(
`[CloudService] Failed to disconnect BridgeOrchestrator on logout: ${error instanceof Error ? error.message : String(error)}`,
`[authStateChangedHandler] remoteControlEnabled(false) failed: ${error instanceof Error ? error.message : String(error)}`,
)
}
}
Expand All @@ -151,7 +150,7 @@ export async function activate(context: vscode.ExtensionContext) {
provider.remoteControlEnabled(CloudService.instance.isTaskSyncEnabled())
} catch (error) {
cloudLogger(
`[CloudService] BridgeOrchestrator#connectOrDisconnect failed on settings change: ${error instanceof Error ? error.message : String(error)}`,
`[settingsUpdatedHandler] remoteControlEnabled failed: ${error instanceof Error ? error.message : String(error)}`,
)
}
}
Expand All @@ -163,15 +162,15 @@ export async function activate(context: vscode.ExtensionContext) {
postStateListener()

if (!CloudService.instance.cloudAPI) {
cloudLogger("[CloudService] CloudAPI is not initialized")
cloudLogger("[userInfoHandler] CloudAPI is not initialized")
return
}

try {
provider.remoteControlEnabled(CloudService.instance.isTaskSyncEnabled())
} catch (error) {
cloudLogger(
`[CloudService] BridgeOrchestrator#connectOrDisconnect failed on user change: ${error instanceof Error ? error.message : String(error)}`,
`[userInfoHandler] remoteControlEnabled failed: ${error instanceof Error ? error.message : String(error)}`,
)
}
}
Expand Down
Loading