Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/mcp-common/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"test:coverage": "run-vitest-coverage"
},
"dependencies": {
"@cloudflare/workers-oauth-provider": "0.0.5",
"@cloudflare/workers-oauth-provider": "0.0.12",
"@fast-csv/format": "5.0.2",
"@hono/zod-validator": "0.4.3",
"@modelcontextprotocol/sdk": "1.18.2",
Expand Down
4 changes: 2 additions & 2 deletions packages/mcp-common/src/cloudflare-auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ export async function getAuthorizationURL({
}: {
client_id: string
redirect_uri: string
state: AuthRequest
state: string
scopes: Record<string, string>
}): Promise<{ authUrl: string; codeVerifier: string }> {
const { codeChallenge, codeVerifier } = await generatePKCECodes()
Expand All @@ -92,7 +92,7 @@ export async function getAuthorizationURL({
authUrl: generateAuthUrl({
client_id,
redirect_uri,
state: btoa(JSON.stringify({ ...state, codeVerifier })),
state,
code_challenge: codeChallenge,
scopes,
}),
Expand Down
69 changes: 44 additions & 25 deletions packages/mcp-common/src/cloudflare-oauth-handler.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { zValidator } from '@hono/zod-validator'
import { Hono } from 'hono'
import { deleteCookie, getCookie, setCookie } from 'hono/cookie'
import { z } from 'zod'

import { AuthUser } from '../../mcp-observability/src'
Expand All @@ -9,6 +10,7 @@ import { useSentry } from './sentry'
import { V4Schema } from './v4-api'

import type {
AuthRequest,
OAuthHelpers,
TokenExchangeCallbackOptions,
TokenExchangeCallbackResult,
Expand All @@ -25,21 +27,6 @@ type AuthContext = {
}
} & BaseHonoContext

const AuthRequestSchema = z.object({
responseType: z.string(),
clientId: z.string(),
redirectUri: z.string(),
scope: z.array(z.string()),
state: z.string(),
codeChallenge: z.string().optional(),
codeChallengeMethod: z.string().optional(),
})

// AuthRequest but with extra params that we use in our authentication logic
const AuthRequestSchemaWithExtraParams = AuthRequestSchema.merge(
z.object({ codeVerifier: z.string() })
)

const AuthQuery = z.object({
code: z.string().describe('OAuth code from CF dash'),
state: z.string().describe('Value of the OAuth state'),
Expand Down Expand Up @@ -220,14 +207,28 @@ export function createAuthHandlers({
if (!oauthReqInfo.clientId) {
return c.text('Invalid request', 400)
}
const res = await getAuthorizationURL({

if (!oauthReqInfo.state) {
return c.text('Invalid request, missing state', 400)
}

const { authUrl, codeVerifier } = await getAuthorizationURL({
client_id: c.env.CLOUDFLARE_CLIENT_ID,
redirect_uri: new URL('/oauth/callback', c.req.url).href,
state: oauthReqInfo,
state: oauthReqInfo.state,
scopes,
})

return Response.redirect(res.authUrl, 302)
// Store the entire auth request and code verifier in a secure, http-only cookie
const cookiePayload = JSON.stringify({ ...oauthReqInfo, codeVerifier })
setCookie(c, 'cloudflare_oauth_request', cookiePayload, {
path: '/',
secure: true,
httpOnly: true,
sameSite: 'Lax',
})

return c.redirect(authUrl, 302)
} catch (e) {
c.var.sentry?.recordError(e)
let message: string | undefined
Expand Down Expand Up @@ -255,21 +256,35 @@ export function createAuthHandlers({
* OAuth Callback Endpoint
*
* This route handles the callback from Cloudflare after user authentication.
* It exchanges the temporary code for an access token, then stores some
* user metadata & the auth token as part of the 'props' on the token passed
* It reads the AuthRequest object from the cookie, validates the state for CSRF protection,
* and then uses the code_verifier to exchange the temporary code for an access token.
* It then stores some user metadata & the auth token as part of the 'props' on the token passed
* down to the client. It ends by redirecting the client back to _its_ callback URL
*/
app.get(`/oauth/callback`, zValidator('query', AuthQuery), async (c) => {
try {
const { state, code } = c.req.valid('query')
const oauthReqInfo = AuthRequestSchemaWithExtraParams.parse(JSON.parse(atob(state)))
// Get the oathReqInfo out of KV
if (!oauthReqInfo.clientId) {
const cookiePayload = getCookie(c, 'cloudflare_oauth_request')

if (!cookiePayload) {
throw new McpError('Missing auth request cookie', 400)
}

const { codeVerifier, ...oauthReqInfo } = JSON.parse(cookiePayload) as AuthRequest & {
codeVerifier: string
}

// Validate the state to prevent CSRF attacks
if (!oauthReqInfo.state || oauthReqInfo.state !== state) {
throw new McpError('Invalid State', 400)
}

if (!codeVerifier) {
throw new McpError('Missing PKCE code verifier in cookie', 400)
}

const [{ accessToken, refreshToken, user, accounts }] = await Promise.all([
getTokenAndUserDetails(c, code, oauthReqInfo.codeVerifier),
getTokenAndUserDetails(c, code, codeVerifier),
c.env.OAUTH_PROVIDER.createClient({
clientId: oauthReqInfo.clientId,
tokenEndpointAuthMethod: 'none',
Expand Down Expand Up @@ -310,7 +325,9 @@ export function createAuthHandlers({
})
)

return Response.redirect(redirectTo, 302)
// Clear the cookie on success
deleteCookie(c, 'cloudflare_oauth_request', { path: '/' })
return c.redirect(redirectTo, 302)
} catch (e) {
c.var.sentry?.recordError(e)
let message: string | undefined
Expand All @@ -327,6 +344,8 @@ export function createAuthHandlers({
errorMessage: `Callback Error: ${message}`,
})
)
// Clear the cookie on error
deleteCookie(c, 'cloudflare_oauth_request', { path: '/' })
if (e instanceof McpError) {
return c.text(e.message, { status: e.code })
}
Expand Down
63 changes: 34 additions & 29 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.