|
1 | 1 | import OAuthProvider from '@cloudflare/workers-oauth-provider' |
2 | 2 | import { McpAgent } from 'agents/mcp' |
3 | | -import { env } from 'cloudflare:workers' |
4 | 3 |
|
5 | 4 | import { |
6 | 5 | createAuthHandlers, |
7 | 6 | handleTokenExchangeCallback, |
8 | 7 | } from '@repo/mcp-common/src/cloudflare-oauth-handler' |
| 8 | +import { getUserDetails, UserDetails } from '@repo/mcp-common/src/durable-objects/user_details' |
| 9 | +import { getEnv } from '@repo/mcp-common/src/env' |
9 | 10 | import { CloudflareMCPServer } from '@repo/mcp-common/src/server' |
10 | 11 | import { registerAccountTools } from '@repo/mcp-common/src/tools/account' |
11 | 12 |
|
12 | 13 | import { MetricsTracker } from '../../../packages/mcp-observability/src' |
13 | 14 | import { registerIntegrationsTools } from './tools/integrations' |
14 | 15 |
|
15 | 16 | import type { AccountSchema, UserSchema } from '@repo/mcp-common/src/cloudflare-oauth-handler' |
16 | | -import type { CloudflareMcpAgent } from '@repo/mcp-common/src/types/cloudflare-mcp-agent' |
| 17 | +import type { Env } from './context' |
| 18 | + |
| 19 | +export { UserDetails } |
| 20 | + |
| 21 | +const env = getEnv<Env>() |
17 | 22 |
|
18 | 23 | const metrics = new MetricsTracker(env.MCP_METRICS, { |
19 | 24 | name: env.MCP_SERVER_NAME, |
@@ -46,32 +51,37 @@ export class CASBMCP extends McpAgent<Env, State, Props> { |
46 | 51 | } |
47 | 52 |
|
48 | 53 | async init() { |
49 | | - this.server = new CloudflareMCPServer(this.props.user.id, this.env.MCP_METRICS, { |
50 | | - name: this.env.MCP_SERVER_NAME, |
51 | | - version: this.env.MCP_SERVER_VERSION, |
| 54 | + this.server = new CloudflareMCPServer({ |
| 55 | + userId: this.props.user.id, |
| 56 | + wae: this.env.MCP_METRICS, |
| 57 | + serverInfo: { |
| 58 | + name: this.env.MCP_SERVER_NAME, |
| 59 | + version: this.env.MCP_SERVER_VERSION, |
| 60 | + }, |
52 | 61 | }) |
53 | 62 |
|
54 | 63 | registerAccountTools(this) |
55 | 64 | registerIntegrationsTools(this) |
56 | 65 | } |
57 | 66 |
|
58 | | - getActiveAccountId() { |
| 67 | + async getActiveAccountId() { |
59 | 68 | try { |
60 | | - return this.state.activeAccountId ?? null |
| 69 | + // Get UserDetails Durable Object based off the userId and retrieve the activeAccountId from it |
| 70 | + // we do this so we can persist activeAccountId across sessions |
| 71 | + const userDetails = getUserDetails(env, this.props.user.id) |
| 72 | + return await userDetails.getActiveAccountId() |
61 | 73 | } catch (e) { |
62 | | - console.error('getActiveAccountId failured: ', e) |
| 74 | + this.server.recordError(e) |
63 | 75 | return null |
64 | 76 | } |
65 | 77 | } |
66 | 78 |
|
67 | | - setActiveAccountId(accountId: string) { |
| 79 | + async setActiveAccountId(accountId: string) { |
68 | 80 | try { |
69 | | - this.setState({ |
70 | | - ...this.state, |
71 | | - activeAccountId: accountId, |
72 | | - }) |
| 81 | + const userDetails = getUserDetails(env, this.props.user.id) |
| 82 | + await userDetails.setActiveAccountId(accountId) |
73 | 83 | } catch (e) { |
74 | | - return null |
| 84 | + this.server.recordError(e) |
75 | 85 | } |
76 | 86 | } |
77 | 87 | } |
|
0 commit comments