Skip to content

Commit 9a6b73e

Browse files
committed
Merge branch 'main' into cte/move-evals
2 parents b979278 + 927e210 commit 9a6b73e

File tree

17 files changed

+311
-53
lines changed

17 files changed

+311
-53
lines changed

packages/cloud/src/AuthService.ts

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { getClerkBaseUrl, getRooCodeApiUrl } from "./Config"
1111
import { RefreshTimer } from "./RefreshTimer"
1212

1313
export interface AuthServiceEvents {
14+
"inactive-session": [data: { previousState: AuthState }]
1415
"active-session": [data: { previousState: AuthState }]
1516
"logged-out": [data: { previousState: AuthState }]
1617
"user-info": [data: { userInfo: CloudUserInfo }]
@@ -32,15 +33,17 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
3233
private context: vscode.ExtensionContext
3334
private timer: RefreshTimer
3435
private state: AuthState = "initializing"
36+
private log: (...args: unknown[]) => void
3537

3638
private credentials: AuthCredentials | null = null
3739
private sessionToken: string | null = null
3840
private userInfo: CloudUserInfo | null = null
3941

40-
constructor(context: vscode.ExtensionContext) {
42+
constructor(context: vscode.ExtensionContext, log?: (...args: unknown[]) => void) {
4143
super()
4244

4345
this.context = context
46+
this.log = log || console.log
4447

4548
this.timer = new RefreshTimer({
4649
callback: async () => {
@@ -71,7 +74,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
7174
}
7275
}
7376
} catch (error) {
74-
console.error("[auth] Error handling credentials change:", error)
77+
this.log("[auth] Error handling credentials change:", error)
7578
}
7679
}
7780

@@ -87,19 +90,23 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
8790

8891
this.emit("logged-out", { previousState })
8992

90-
console.log("[auth] Transitioned to logged-out state")
93+
this.log("[auth] Transitioned to logged-out state")
9194
}
9295

9396
private transitionToInactiveSession(credentials: AuthCredentials): void {
9497
this.credentials = credentials
98+
99+
const previousState = this.state
95100
this.state = "inactive-session"
96101

97102
this.sessionToken = null
98103
this.userInfo = null
99104

105+
this.emit("inactive-session", { previousState })
106+
100107
this.timer.start()
101108

102-
console.log("[auth] Transitioned to inactive-session state")
109+
this.log("[auth] Transitioned to inactive-session state")
103110
}
104111

105112
/**
@@ -110,7 +117,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
110117
*/
111118
public async initialize(): Promise<void> {
112119
if (this.state !== "initializing") {
113-
console.log("[auth] initialize() called after already initialized")
120+
this.log("[auth] initialize() called after already initialized")
114121
return
115122
}
116123

@@ -138,9 +145,9 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
138145
return authCredentialsSchema.parse(parsedJson)
139146
} catch (error) {
140147
if (error instanceof z.ZodError) {
141-
console.error("[auth] Invalid credentials format:", error.errors)
148+
this.log("[auth] Invalid credentials format:", error.errors)
142149
} else {
143-
console.error("[auth] Failed to parse stored credentials:", error)
150+
this.log("[auth] Failed to parse stored credentials:", error)
144151
}
145152
return null
146153
}
@@ -171,7 +178,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
171178
const url = `${getRooCodeApiUrl()}/extension/sign-in?${params.toString()}`
172179
await vscode.env.openExternal(vscode.Uri.parse(url))
173180
} catch (error) {
174-
console.error(`[auth] Error initiating Roo Code Cloud auth: ${error}`)
181+
this.log(`[auth] Error initiating Roo Code Cloud auth: ${error}`)
175182
throw new Error(`Failed to initiate Roo Code Cloud authentication: ${error}`)
176183
}
177184
}
@@ -196,7 +203,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
196203
const storedState = this.context.globalState.get(AUTH_STATE_KEY)
197204

198205
if (state !== storedState) {
199-
console.log("[auth] State mismatch in callback")
206+
this.log("[auth] State mismatch in callback")
200207
throw new Error("Invalid state parameter. Authentication request may have been tampered with.")
201208
}
202209

@@ -205,9 +212,9 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
205212
await this.storeCredentials(credentials)
206213

207214
vscode.window.showInformationMessage("Successfully authenticated with Roo Code Cloud")
208-
console.log("[auth] Successfully authenticated with Roo Code Cloud")
215+
this.log("[auth] Successfully authenticated with Roo Code Cloud")
209216
} catch (error) {
210-
console.log(`[auth] Error handling Roo Code Cloud callback: ${error}`)
217+
this.log(`[auth] Error handling Roo Code Cloud callback: ${error}`)
211218
const previousState = this.state
212219
this.state = "logged-out"
213220
this.emit("logged-out", { previousState })
@@ -232,14 +239,14 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
232239
try {
233240
await this.clerkLogout(oldCredentials)
234241
} catch (error) {
235-
console.error("[auth] Error calling clerkLogout:", error)
242+
this.log("[auth] Error calling clerkLogout:", error)
236243
}
237244
}
238245

239246
vscode.window.showInformationMessage("Logged out from Roo Code Cloud")
240-
console.log("[auth] Logged out from Roo Code Cloud")
247+
this.log("[auth] Logged out from Roo Code Cloud")
241248
} catch (error) {
242-
console.log(`[auth] Error logging out from Roo Code Cloud: ${error}`)
249+
this.log(`[auth] Error logging out from Roo Code Cloud: ${error}`)
243250
throw new Error(`Failed to log out from Roo Code Cloud: ${error}`)
244251
}
245252
}
@@ -276,19 +283,24 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
276283
*/
277284
private async refreshSession(): Promise<void> {
278285
if (!this.credentials) {
279-
console.log("[auth] Cannot refresh session: missing credentials")
286+
this.log("[auth] Cannot refresh session: missing credentials")
280287
this.state = "inactive-session"
281288
return
282289
}
283290

284-
const previousState = this.state
285-
this.sessionToken = await this.clerkCreateSessionToken()
286-
this.state = "active-session"
291+
try {
292+
const previousState = this.state
293+
this.sessionToken = await this.clerkCreateSessionToken()
294+
this.state = "active-session"
287295

288-
if (previousState !== "active-session") {
289-
console.log("[auth] Transitioned to active-session state")
290-
this.emit("active-session", { previousState })
291-
this.fetchUserInfo()
296+
if (previousState !== "active-session") {
297+
this.log("[auth] Transitioned to active-session state")
298+
this.emit("active-session", { previousState })
299+
this.fetchUserInfo()
300+
}
301+
} catch (error) {
302+
this.log("[auth] Failed to refresh session", error)
303+
throw error
292304
}
293305
}
294306

@@ -436,12 +448,12 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
436448
return this._instance
437449
}
438450

439-
static async createInstance(context: vscode.ExtensionContext) {
451+
static async createInstance(context: vscode.ExtensionContext, log?: (...args: unknown[]) => void) {
440452
if (this._instance) {
441453
throw new Error("AuthService instance already created")
442454
}
443455

444-
this._instance = new AuthService(context)
456+
this._instance = new AuthService(context, log)
445457
await this._instance.initialize()
446458
return this._instance
447459
}

packages/cloud/src/CloudService.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ export class CloudService {
1818
private settingsService: SettingsService | null = null
1919
private telemetryClient: TelemetryClient | null = null
2020
private isInitialized = false
21+
private log: (...args: unknown[]) => void
2122

2223
private constructor(context: vscode.ExtensionContext, callbacks: CloudServiceCallbacks) {
2324
this.context = context
2425
this.callbacks = callbacks
26+
this.log = callbacks.log || console.log
2527
this.authListener = () => {
2628
this.callbacks.stateChanged?.()
2729
}
@@ -33,8 +35,9 @@ export class CloudService {
3335
}
3436

3537
try {
36-
this.authService = await AuthService.createInstance(this.context)
38+
this.authService = await AuthService.createInstance(this.context, this.log)
3739

40+
this.authService.on("inactive-session", this.authListener)
3841
this.authService.on("active-session", this.authListener)
3942
this.authService.on("logged-out", this.authListener)
4043
this.authService.on("user-info", this.authListener)
@@ -48,12 +51,12 @@ export class CloudService {
4851
try {
4952
TelemetryService.instance.register(this.telemetryClient)
5053
} catch (error) {
51-
console.warn("[CloudService] Failed to register TelemetryClient:", error)
54+
this.log("[CloudService] Failed to register TelemetryClient:", error)
5255
}
5356

5457
this.isInitialized = true
5558
} catch (error) {
56-
console.error("[CloudService] Failed to initialize:", error)
59+
this.log("[CloudService] Failed to initialize:", error)
5760
throw new Error(`Failed to initialize CloudService: ${error}`)
5861
}
5962
}

packages/cloud/src/__tests__/CloudService.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ describe("CloudService", () => {
135135
const cloudService = await CloudService.createInstance(mockContext, callbacks)
136136

137137
expect(cloudService).toBeInstanceOf(CloudService)
138-
expect(AuthService.createInstance).toHaveBeenCalledWith(mockContext)
138+
expect(AuthService.createInstance).toHaveBeenCalledWith(mockContext, expect.any(Function))
139139
expect(SettingsService.createInstance).toHaveBeenCalledWith(mockContext, expect.any(Function))
140140
})
141141

packages/cloud/src/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
export interface CloudServiceCallbacks {
22
stateChanged?: () => void
3+
log?: (...args: unknown[]) => void
34
}

packages/types/src/provider-settings.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ const bedrockSchema = apiModelIdProviderModelSchema.extend({
100100
awsProfile: z.string().optional(),
101101
awsUseProfile: z.boolean().optional(),
102102
awsCustomArn: z.string().optional(),
103+
awsModelContextWindow: z.number().optional(),
103104
awsBedrockEndpointEnabled: z.boolean().optional(),
104105
awsBedrockEndpoint: z.string().optional(),
105106
})
@@ -285,6 +286,7 @@ export const PROVIDER_SETTINGS_KEYS = keysOf<ProviderSettings>()([
285286
"awsProfile",
286287
"awsUseProfile",
287288
"awsCustomArn",
289+
"awsModelContextWindow",
288290
"awsBedrockEndpointEnabled",
289291
"awsBedrockEndpoint",
290292
// Google Vertex

packages/types/src/providers/bedrock.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,8 @@ export const BEDROCK_DEFAULT_TEMPERATURE = 0.3
355355

356356
export const BEDROCK_MAX_TOKENS = 4096
357357

358+
export const BEDROCK_DEFAULT_CONTEXT = 128_000
359+
358360
export const BEDROCK_REGION_INFO: Record<
359361
string,
360362
{

src/api/providers/bedrock.ts

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import {
1919
bedrockDefaultPromptRouterModelId,
2020
BEDROCK_DEFAULT_TEMPERATURE,
2121
BEDROCK_MAX_TOKENS,
22+
BEDROCK_DEFAULT_CONTEXT,
2223
BEDROCK_REGION_INFO,
2324
} from "@roo-code/types"
2425

@@ -192,6 +193,65 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
192193
this.client = new BedrockRuntimeClient(clientConfig)
193194
}
194195

196+
// Helper to guess model info from custom modelId string if not in bedrockModels
197+
private guessModelInfoFromId(modelId: string): Partial<ModelInfo> {
198+
// Define a mapping for model ID patterns and their configurations
199+
const modelConfigMap: Record<string, Partial<ModelInfo>> = {
200+
"claude-4": {
201+
maxTokens: 8192,
202+
contextWindow: 200_000,
203+
supportsImages: true,
204+
supportsPromptCache: true,
205+
},
206+
"claude-3-7": {
207+
maxTokens: 8192,
208+
contextWindow: 200_000,
209+
supportsImages: true,
210+
supportsPromptCache: true,
211+
},
212+
"claude-3-5": {
213+
maxTokens: 8192,
214+
contextWindow: 200_000,
215+
supportsImages: true,
216+
supportsPromptCache: true,
217+
},
218+
"claude-4-opus": {
219+
maxTokens: 4096,
220+
contextWindow: 200_000,
221+
supportsImages: true,
222+
supportsPromptCache: true,
223+
},
224+
"claude-3-opus": {
225+
maxTokens: 4096,
226+
contextWindow: 200_000,
227+
supportsImages: true,
228+
supportsPromptCache: true,
229+
},
230+
"claude-3-haiku": {
231+
maxTokens: 4096,
232+
contextWindow: 200_000,
233+
supportsImages: true,
234+
supportsPromptCache: true,
235+
},
236+
}
237+
238+
// Match the model ID to a configuration
239+
const id = modelId.toLowerCase()
240+
for (const [pattern, config] of Object.entries(modelConfigMap)) {
241+
if (id.includes(pattern)) {
242+
return config
243+
}
244+
}
245+
246+
// Default fallback
247+
return {
248+
maxTokens: BEDROCK_MAX_TOKENS,
249+
contextWindow: BEDROCK_DEFAULT_CONTEXT,
250+
supportsImages: false,
251+
supportsPromptCache: false,
252+
}
253+
}
254+
195255
override async *createMessage(
196256
systemPrompt: string,
197257
messages: Anthropic.Messages.MessageParam[],
@@ -640,16 +700,24 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
640700
info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultPromptRouterModelId])),
641701
}
642702
} else {
703+
// Use heuristics for model info, then allow overrides from ProviderSettings
704+
const guessed = this.guessModelInfoFromId(modelId)
643705
model = {
644706
id: bedrockDefaultModelId,
645-
info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
707+
info: {
708+
...JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
709+
...guessed,
710+
},
646711
}
647712
}
648713

649-
// If modelMaxTokens is explicitly set in options, override the default
714+
// Always allow user to override detected/guessed maxTokens and contextWindow
650715
if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
651716
model.info.maxTokens = this.options.modelMaxTokens
652717
}
718+
if (this.options.awsModelContextWindow && this.options.awsModelContextWindow > 0) {
719+
model.info.contextWindow = this.options.awsModelContextWindow
720+
}
653721

654722
return model
655723
}
@@ -684,8 +752,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
684752
}
685753
}
686754

687-
modelConfig.info.maxTokens = modelConfig.info.maxTokens || BEDROCK_MAX_TOKENS
688-
755+
// Don't override maxTokens/contextWindow here; handled in getModelById (and includes user overrides)
689756
return modelConfig as { id: BedrockModelId | string; info: ModelInfo }
690757
}
691758

0 commit comments

Comments
 (0)