Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions packages/cloud/src/AuthService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { z } from "zod"

import type { CloudUserInfo, CloudOrganizationMembership } from "@roo-code/types"

import { getClerkBaseUrl, getRooCodeApiUrl } from "./Config"
import { getClerkBaseUrl, getRooCodeApiUrl, PRODUCTION_CLERK_BASE_URL } from "./Config"
import { RefreshTimer } from "./RefreshTimer"
import { getUserAgent } from "./utils"

Expand All @@ -24,7 +24,6 @@ const authCredentialsSchema = z.object({

type AuthCredentials = z.infer<typeof authCredentialsSchema>

const AUTH_CREDENTIALS_KEY = "clerk-auth-credentials"
const AUTH_STATE_KEY = "clerk-auth-state"

type AuthState = "initializing" | "logged-out" | "active-session" | "inactive-session"
Expand Down Expand Up @@ -89,6 +88,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
private timer: RefreshTimer
private state: AuthState = "initializing"
private log: (...args: unknown[]) => void
private readonly authCredentialsKey: string

private credentials: AuthCredentials | null = null
private sessionToken: string | null = null
Expand All @@ -100,6 +100,14 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
this.context = context
this.log = log || console.log

// Calculate auth credentials key based on Clerk base URL
const clerkBaseUrl = getClerkBaseUrl()
if (clerkBaseUrl !== PRODUCTION_CLERK_BASE_URL) {
this.authCredentialsKey = `clerk-auth-credentials-${clerkBaseUrl}`
} else {
this.authCredentialsKey = "clerk-auth-credentials"
}

this.timer = new RefreshTimer({
callback: async () => {
await this.refreshSession()
Expand Down Expand Up @@ -180,19 +188,19 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {

this.context.subscriptions.push(
this.context.secrets.onDidChange((e) => {
if (e.key === AUTH_CREDENTIALS_KEY) {
if (e.key === this.authCredentialsKey) {
this.handleCredentialsChange()
}
}),
)
}

private async storeCredentials(credentials: AuthCredentials): Promise<void> {
await this.context.secrets.store(AUTH_CREDENTIALS_KEY, JSON.stringify(credentials))
await this.context.secrets.store(this.authCredentialsKey, JSON.stringify(credentials))
}

private async loadCredentials(): Promise<AuthCredentials | null> {
const credentialsJson = await this.context.secrets.get(AUTH_CREDENTIALS_KEY)
const credentialsJson = await this.context.secrets.get(this.authCredentialsKey)
if (!credentialsJson) return null

try {
Expand All @@ -209,7 +217,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
}

private async clearCredentials(): Promise<void> {
await this.context.secrets.delete(AUTH_CREDENTIALS_KEY)
await this.context.secrets.delete(this.authCredentialsKey)
}

/**
Expand Down
143 changes: 140 additions & 3 deletions packages/cloud/src/__tests__/AuthService.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ describe("AuthService", () => {
}
vi.mocked(RefreshTimer).mockImplementation(() => mockTimer as unknown as RefreshTimer)

// Setup config mocks
vi.mocked(Config.getClerkBaseUrl).mockReturnValue("https://clerk.test.com")
// Setup config mocks - use production URL by default to maintain existing test behavior
vi.mocked(Config.getClerkBaseUrl).mockReturnValue("https://clerk.roocode.com")
vi.mocked(Config.getRooCodeApiUrl).mockReturnValue("https://api.test.com")

// Setup utils mock
Expand Down Expand Up @@ -377,7 +377,7 @@ describe("AuthService", () => {
expect(mockContext.secrets.delete).toHaveBeenCalledWith("clerk-auth-credentials")
expect(mockContext.globalState.update).toHaveBeenCalledWith("clerk-auth-state", undefined)
expect(mockFetch).toHaveBeenCalledWith(
"https://clerk.test.com/v1/client/sessions/test-session/remove",
"https://clerk.roocode.com/v1/client/sessions/test-session/remove",
expect.objectContaining({
method: "POST",
headers: expect.objectContaining({
Expand Down Expand Up @@ -812,4 +812,141 @@ describe("AuthService", () => {
expect(mockTimer.start).toHaveBeenCalled()
})
})

describe("auth credentials key scoping", () => {
it("should use default key when getClerkBaseUrl returns production URL", async () => {
// Mock getClerkBaseUrl to return production URL
vi.mocked(Config.getClerkBaseUrl).mockReturnValue("https://clerk.roocode.com")

const service = new AuthService(mockContext as unknown as vscode.ExtensionContext, mockLog)
const credentials = { clientToken: "test-token", sessionId: "test-session" }

await service.initialize()
await service["storeCredentials"](credentials)

expect(mockContext.secrets.store).toHaveBeenCalledWith(
"clerk-auth-credentials",
JSON.stringify(credentials),
)
})

it("should use scoped key when getClerkBaseUrl returns custom URL", async () => {
const customUrl = "https://custom.clerk.com"
// Mock getClerkBaseUrl to return custom URL
vi.mocked(Config.getClerkBaseUrl).mockReturnValue(customUrl)

const service = new AuthService(mockContext as unknown as vscode.ExtensionContext, mockLog)
const credentials = { clientToken: "test-token", sessionId: "test-session" }

await service.initialize()
await service["storeCredentials"](credentials)

expect(mockContext.secrets.store).toHaveBeenCalledWith(
`clerk-auth-credentials-${customUrl}`,
JSON.stringify(credentials),
)
})

it("should load credentials using scoped key", async () => {
const customUrl = "https://custom.clerk.com"
vi.mocked(Config.getClerkBaseUrl).mockReturnValue(customUrl)

const service = new AuthService(mockContext as unknown as vscode.ExtensionContext, mockLog)
const credentials = { clientToken: "test-token", sessionId: "test-session" }
mockContext.secrets.get.mockResolvedValue(JSON.stringify(credentials))

await service.initialize()
const loadedCredentials = await service["loadCredentials"]()

expect(mockContext.secrets.get).toHaveBeenCalledWith(`clerk-auth-credentials-${customUrl}`)
expect(loadedCredentials).toEqual(credentials)
})

it("should clear credentials using scoped key", async () => {
const customUrl = "https://custom.clerk.com"
vi.mocked(Config.getClerkBaseUrl).mockReturnValue(customUrl)

const service = new AuthService(mockContext as unknown as vscode.ExtensionContext, mockLog)

await service.initialize()
await service["clearCredentials"]()

expect(mockContext.secrets.delete).toHaveBeenCalledWith(`clerk-auth-credentials-${customUrl}`)
})

it("should listen for changes on scoped key", async () => {
const customUrl = "https://custom.clerk.com"
vi.mocked(Config.getClerkBaseUrl).mockReturnValue(customUrl)

let onDidChangeCallback: (e: { key: string }) => void

mockContext.secrets.onDidChange.mockImplementation((callback: (e: { key: string }) => void) => {
onDidChangeCallback = callback
return { dispose: vi.fn() }
})

const service = new AuthService(mockContext as unknown as vscode.ExtensionContext, mockLog)
await service.initialize()

// Simulate credentials change event with scoped key
const newCredentials = { clientToken: "new-token", sessionId: "new-session" }
mockContext.secrets.get.mockResolvedValue(JSON.stringify(newCredentials))

const inactiveSessionSpy = vi.fn()
service.on("inactive-session", inactiveSessionSpy)

onDidChangeCallback!({ key: `clerk-auth-credentials-${customUrl}` })
await new Promise((resolve) => setTimeout(resolve, 0)) // Wait for async handling

expect(inactiveSessionSpy).toHaveBeenCalled()
})

it("should not respond to changes on different scoped keys", async () => {
const customUrl = "https://custom.clerk.com"
vi.mocked(Config.getClerkBaseUrl).mockReturnValue(customUrl)

let onDidChangeCallback: (e: { key: string }) => void

mockContext.secrets.onDidChange.mockImplementation((callback: (e: { key: string }) => void) => {
onDidChangeCallback = callback
return { dispose: vi.fn() }
})

const service = new AuthService(mockContext as unknown as vscode.ExtensionContext, mockLog)
await service.initialize()

const inactiveSessionSpy = vi.fn()
service.on("inactive-session", inactiveSessionSpy)

// Simulate credentials change event with different scoped key
onDidChangeCallback!({ key: "clerk-auth-credentials-https://other.clerk.com" })
await new Promise((resolve) => setTimeout(resolve, 0)) // Wait for async handling

expect(inactiveSessionSpy).not.toHaveBeenCalled()
})

it("should not respond to changes on default key when using scoped key", async () => {
const customUrl = "https://custom.clerk.com"
vi.mocked(Config.getClerkBaseUrl).mockReturnValue(customUrl)

let onDidChangeCallback: (e: { key: string }) => void

mockContext.secrets.onDidChange.mockImplementation((callback: (e: { key: string }) => void) => {
onDidChangeCallback = callback
return { dispose: vi.fn() }
})

const service = new AuthService(mockContext as unknown as vscode.ExtensionContext, mockLog)
await service.initialize()

const inactiveSessionSpy = vi.fn()
service.on("inactive-session", inactiveSessionSpy)

// Simulate credentials change event with default key
onDidChangeCallback!({ key: "clerk-auth-credentials" })
await new Promise((resolve) => setTimeout(resolve, 0)) // Wait for async handling

expect(inactiveSessionSpy).not.toHaveBeenCalled()
})
})
})
Loading