Skip to content

Commit c327bad

Browse files
committed
Handle initial session refresh when checking compliance
1 parent e95c1e8 commit c327bad

File tree

6 files changed

+226
-8
lines changed

6 files changed

+226
-8
lines changed

packages/cloud/src/AuthService.ts

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import { getUserAgent } from "./utils"
1414
export interface AuthServiceEvents {
1515
"inactive-session": [data: { previousState: AuthState }]
1616
"active-session": [data: { previousState: AuthState }]
17+
"refreshing-session": [data: { previousState: AuthState }]
1718
"logged-out": [data: { previousState: AuthState }]
1819
"user-info": [data: { userInfo: CloudUserInfo }]
1920
}
@@ -28,7 +29,7 @@ type AuthCredentials = z.infer<typeof authCredentialsSchema>
2829
const AUTH_CREDENTIALS_KEY = "clerk-auth-credentials"
2930
const AUTH_STATE_KEY = "clerk-auth-state"
3031

31-
type AuthState = "initializing" | "logged-out" | "active-session" | "inactive-session"
32+
type AuthState = "initializing" | "logged-out" | "active-session" | "inactive-session" | "refreshing-session"
3233

3334
export class AuthService extends EventEmitter<AuthServiceEvents> {
3435
private context: vscode.ExtensionContext
@@ -277,6 +278,10 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
277278
return this.state === "active-session"
278279
}
279280

281+
public isRefreshingSession(): boolean {
282+
return this.state === "refreshing-session"
283+
}
284+
280285
/**
281286
* Refresh the session
282287
*
@@ -291,14 +296,20 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
291296

292297
try {
293298
const previousState = this.state
299+
300+
// Transition to refreshing state
301+
if (this.state !== "refreshing-session") {
302+
this.state = "refreshing-session"
303+
this.emit("refreshing-session", { previousState })
304+
this.log("[auth] Transitioned to refreshing-session state")
305+
}
306+
294307
this.sessionToken = await this.clerkCreateSessionToken()
295308
this.state = "active-session"
296309

297-
if (previousState !== "active-session") {
298-
this.log("[auth] Transitioned to active-session state")
299-
this.emit("active-session", { previousState })
300-
this.fetchUserInfo()
301-
}
310+
this.log("[auth] Transitioned to active-session state")
311+
this.emit("active-session", { previousState: "refreshing-session" })
312+
this.fetchUserInfo()
302313
} catch (error) {
303314
this.log("[auth] Failed to refresh session", error)
304315
throw error

packages/cloud/src/CloudService.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ export class CloudService {
4141

4242
this.authService.on("inactive-session", this.authListener)
4343
this.authService.on("active-session", this.authListener)
44+
this.authService.on("refreshing-session", this.authListener)
4445
this.authService.on("logged-out", this.authListener)
4546
this.authService.on("user-info", this.authListener)
4647

@@ -87,6 +88,11 @@ export class CloudService {
8788
return this.authService!.hasActiveSession()
8889
}
8990

91+
public isRefreshingSession(): boolean {
92+
this.ensureInitialized()
93+
return this.authService!.isRefreshingSession()
94+
}
95+
9096
public getUserInfo(): CloudUserInfo | null {
9197
this.ensureInitialized()
9298
return this.authService!.getUserInfo()
@@ -150,7 +156,9 @@ export class CloudService {
150156

151157
public dispose(): void {
152158
if (this.authService) {
159+
this.authService.off("inactive-session", this.authListener)
153160
this.authService.off("active-session", this.authListener)
161+
this.authService.off("refreshing-session", this.authListener)
154162
this.authService.off("logged-out", this.authListener)
155163
this.authService.off("user-info", this.authListener)
156164
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// npx vitest run src/__tests__/AuthService.test.ts
2+
3+
import { describe, it, expect, vi, beforeEach } from "vitest"
4+
import * as vscode from "vscode"
5+
import { AuthService } from "../AuthService"
6+
7+
// Mock vscode
8+
vi.mock("vscode", () => ({
9+
ExtensionContext: vi.fn(),
10+
window: {
11+
showInformationMessage: vi.fn(),
12+
},
13+
env: {
14+
openExternal: vi.fn(),
15+
uriScheme: "vscode",
16+
},
17+
Uri: {
18+
parse: vi.fn(),
19+
},
20+
}))
21+
22+
// Mock axios
23+
vi.mock("axios", () => ({
24+
default: {
25+
post: vi.fn(),
26+
get: vi.fn(),
27+
},
28+
}))
29+
30+
// Mock other dependencies
31+
vi.mock("../Config", () => ({
32+
getClerkBaseUrl: vi.fn(() => "https://clerk.test"),
33+
getRooCodeApiUrl: vi.fn(() => "https://api.test"),
34+
}))
35+
36+
vi.mock("../RefreshTimer", () => ({
37+
RefreshTimer: vi.fn().mockImplementation(() => ({
38+
start: vi.fn(),
39+
stop: vi.fn(),
40+
})),
41+
}))
42+
43+
vi.mock("../utils", () => ({
44+
getUserAgent: vi.fn(() => "test-agent"),
45+
}))
46+
47+
describe("AuthService", () => {
48+
let mockContext: Partial<vscode.ExtensionContext>
49+
let authService: AuthService
50+
51+
beforeEach(() => {
52+
vi.clearAllMocks()
53+
54+
mockContext = {
55+
secrets: {
56+
store: vi.fn(),
57+
get: vi.fn(),
58+
delete: vi.fn(),
59+
onDidChange: vi.fn(() => ({ dispose: vi.fn() })),
60+
} as Partial<vscode.SecretStorage> as vscode.SecretStorage,
61+
globalState: {
62+
update: vi.fn(),
63+
get: vi.fn(),
64+
keys: vi.fn(() => []),
65+
setKeysForSync: vi.fn(),
66+
} as Partial<vscode.Memento & { setKeysForSync(keys: readonly string[]): void }> as vscode.Memento & {
67+
setKeysForSync(keys: readonly string[]): void
68+
},
69+
subscriptions: [],
70+
extension: {
71+
packageJSON: {
72+
publisher: "test",
73+
name: "test-extension",
74+
},
75+
} as Partial<vscode.Extension<unknown>> as vscode.Extension<unknown>,
76+
}
77+
78+
authService = new AuthService(mockContext as vscode.ExtensionContext)
79+
})
80+
81+
describe("State Management", () => {
82+
it("should initialize with 'initializing' state", () => {
83+
expect(authService.getState()).toBe("initializing")
84+
})
85+
86+
it("should have isRefreshingSession method that returns false initially", () => {
87+
expect(authService.isRefreshingSession()).toBe(false)
88+
})
89+
90+
it("should include refreshing-session in AuthState type", () => {
91+
// This test verifies that the new state is properly typed
92+
// by checking that the method exists and returns a boolean
93+
expect(typeof authService.isRefreshingSession).toBe("function")
94+
expect(typeof authService.isRefreshingSession()).toBe("boolean")
95+
})
96+
})
97+
98+
describe("Event Emission", () => {
99+
it("should emit refreshing-session event when transitioning to refreshing state", async () => {
100+
// Set up the auth service to have credentials
101+
const mockCredentials = {
102+
clientToken: "test-token",
103+
sessionId: "test-session",
104+
}
105+
106+
// Mock the secrets.get to return credentials
107+
vi.mocked(mockContext.secrets!.get).mockResolvedValue(JSON.stringify(mockCredentials))
108+
109+
// Create a promise to wait for the event
110+
const eventPromise = new Promise((resolve) => {
111+
authService.on("refreshing-session", (data) => {
112+
expect(data).toHaveProperty("previousState")
113+
resolve(data)
114+
})
115+
})
116+
117+
// This would trigger the refresh process in a real scenario
118+
// For this test, we're just verifying the event structure exists
119+
// We can manually emit the event to test the interface
120+
authService.emit("refreshing-session", { previousState: "inactive-session" })
121+
122+
await eventPromise
123+
})
124+
})
125+
})

packages/cloud/src/__tests__/CloudService.test.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ describe("CloudService", () => {
3636
logout: ReturnType<typeof vi.fn>
3737
isAuthenticated: ReturnType<typeof vi.fn>
3838
hasActiveSession: ReturnType<typeof vi.fn>
39+
isRefreshingSession: ReturnType<typeof vi.fn>
3940
getUserInfo: ReturnType<typeof vi.fn>
4041
getState: ReturnType<typeof vi.fn>
4142
getSessionToken: ReturnType<typeof vi.fn>
@@ -84,6 +85,7 @@ describe("CloudService", () => {
8485
logout: vi.fn(),
8586
isAuthenticated: vi.fn().mockReturnValue(false),
8687
hasActiveSession: vi.fn().mockReturnValue(false),
88+
isRefreshingSession: vi.fn().mockReturnValue(false),
8789
getUserInfo: vi.fn(),
8890
getState: vi.fn().mockReturnValue("logged-out"),
8991
getSessionToken: vi.fn(),
@@ -179,6 +181,12 @@ describe("CloudService", () => {
179181
expect(result).toBe(false)
180182
})
181183

184+
it("should delegate isRefreshingSession to AuthService", () => {
185+
const result = cloudService.isRefreshingSession()
186+
expect(mockAuthService.isRefreshingSession).toHaveBeenCalled()
187+
expect(result).toBe(false)
188+
})
189+
182190
it("should delegate getUserInfo to AuthService", async () => {
183191
await cloudService.getUserInfo()
184192
expect(mockAuthService.getUserInfo).toHaveBeenCalled()

src/services/mdm/MdmService.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,12 @@ export class MdmService {
8585
return { compliant: true }
8686
}
8787

88-
// Check if cloud service is available and authenticated
89-
if (!CloudService.hasInstance() || !CloudService.instance.hasActiveSession()) {
88+
const cloudService = CloudService.instance
89+
const hasActiveSession = cloudService?.hasActiveSession()
90+
const isRefreshingSession = cloudService?.isRefreshingSession()
91+
92+
// Allow only if user has active session or is refreshing session
93+
if (!hasActiveSession && !isRefreshingSession) {
9094
return {
9195
compliant: false,
9296
reason: "Your organization requires Roo Code Cloud authentication. Please sign in to continue.",

src/services/mdm/__tests__/MdmService.spec.ts

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ vi.mock("@roo-code/cloud", () => ({
1616
hasInstance: vi.fn(),
1717
instance: {
1818
hasActiveSession: vi.fn(),
19+
isRefreshingSession: vi.fn(),
20+
isAuthenticated: vi.fn(),
1921
getOrganizationId: vi.fn(),
2022
},
2123
},
@@ -244,6 +246,8 @@ describe("MdmService", () => {
244246

245247
mockCloudService.hasInstance.mockReturnValue(true)
246248
mockCloudService.instance.hasActiveSession.mockReturnValue(true)
249+
mockCloudService.instance.isRefreshingSession.mockReturnValue(false)
250+
mockCloudService.instance.isAuthenticated.mockReturnValue(true)
247251

248252
const service = await MdmService.createInstance()
249253
const compliance = service.isCompliant()
@@ -279,6 +283,8 @@ describe("MdmService", () => {
279283
// Mock CloudService to have instance and active session but wrong org
280284
mockCloudService.hasInstance.mockReturnValue(true)
281285
mockCloudService.instance.hasActiveSession.mockReturnValue(true)
286+
mockCloudService.instance.isRefreshingSession.mockReturnValue(false)
287+
mockCloudService.instance.isAuthenticated.mockReturnValue(true)
282288
mockCloudService.instance.getOrganizationId.mockReturnValue("different-org-456")
283289

284290
const service = await MdmService.createInstance()
@@ -300,13 +306,69 @@ describe("MdmService", () => {
300306

301307
mockCloudService.hasInstance.mockReturnValue(true)
302308
mockCloudService.instance.hasActiveSession.mockReturnValue(true)
309+
mockCloudService.instance.isRefreshingSession.mockReturnValue(false)
310+
mockCloudService.instance.isAuthenticated.mockReturnValue(true)
303311
mockCloudService.instance.getOrganizationId.mockReturnValue("correct-org-123")
304312

305313
const service = await MdmService.createInstance()
306314
const compliance = service.isCompliant()
307315

308316
expect(compliance.compliant).toBe(true)
309317
})
318+
319+
it("should be compliant when refreshing session", async () => {
320+
const mockConfig = { requireCloudAuth: true }
321+
mockFs.existsSync.mockReturnValue(true)
322+
mockFs.readFileSync.mockReturnValue(JSON.stringify(mockConfig))
323+
324+
mockCloudService.hasInstance.mockReturnValue(true)
325+
mockCloudService.instance.hasActiveSession.mockReturnValue(false)
326+
mockCloudService.instance.isRefreshingSession.mockReturnValue(true)
327+
mockCloudService.instance.isAuthenticated.mockReturnValue(true)
328+
329+
const service = await MdmService.createInstance()
330+
const compliance = service.isCompliant()
331+
332+
expect(compliance.compliant).toBe(true)
333+
})
334+
335+
it("should be non-compliant when authenticated but not active or refreshing", async () => {
336+
const mockConfig = { requireCloudAuth: true }
337+
mockFs.existsSync.mockReturnValue(true)
338+
mockFs.readFileSync.mockReturnValue(JSON.stringify(mockConfig))
339+
340+
mockCloudService.hasInstance.mockReturnValue(true)
341+
mockCloudService.instance.hasActiveSession.mockReturnValue(false)
342+
mockCloudService.instance.isRefreshingSession.mockReturnValue(false)
343+
mockCloudService.instance.isAuthenticated.mockReturnValue(true)
344+
345+
const service = await MdmService.createInstance()
346+
const compliance = service.isCompliant()
347+
348+
expect(compliance.compliant).toBe(false)
349+
if (!compliance.compliant) {
350+
expect(compliance.reason).toContain("requires Roo Code Cloud authentication")
351+
}
352+
})
353+
354+
it("should be non-compliant when not authenticated, not active, and not refreshing", async () => {
355+
const mockConfig = { requireCloudAuth: true }
356+
mockFs.existsSync.mockReturnValue(true)
357+
mockFs.readFileSync.mockReturnValue(JSON.stringify(mockConfig))
358+
359+
mockCloudService.hasInstance.mockReturnValue(true)
360+
mockCloudService.instance.hasActiveSession.mockReturnValue(false)
361+
mockCloudService.instance.isRefreshingSession.mockReturnValue(false)
362+
mockCloudService.instance.isAuthenticated.mockReturnValue(false)
363+
364+
const service = await MdmService.createInstance()
365+
const compliance = service.isCompliant()
366+
367+
expect(compliance.compliant).toBe(false)
368+
if (!compliance.compliant) {
369+
expect(compliance.reason).toContain("requires Roo Code Cloud authentication")
370+
}
371+
})
310372
})
311373

312374
describe("cloud enablement", () => {

0 commit comments

Comments
 (0)