diff --git a/src/server/auth/handlers/token.test.ts b/src/server/auth/handlers/token.test.ts index bf41b5ebd..c165fe7ff 100644 --- a/src/server/auth/handlers/token.test.ts +++ b/src/server/auth/handlers/token.test.ts @@ -322,7 +322,8 @@ describe('Token Handler', () => { client_secret: 'valid-secret', grant_type: 'authorization_code', code: 'valid_code', - code_verifier: 'any_verifier' + code_verifier: 'any_verifier', + redirect_uri: 'https://example.com/callback' }); expect(response.status).toBe(200); @@ -342,6 +343,69 @@ describe('Token Handler', () => { global.fetch = originalFetch; } }); + + it('passes through redirect_uri when using proxy provider', async () => { + const originalFetch = global.fetch; + + try { + global.fetch = jest.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }) + }); + + const proxyProvider = new ProxyOAuthServerProvider({ + endpoints: { + authorizationUrl: 'https://example.com/authorize', + tokenUrl: 'https://example.com/token' + }, + verifyAccessToken: async (token) => ({ + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }), + getClient: async (clientId) => clientId === 'valid-client' ? validClient : undefined + }); + + const proxyApp = express(); + const options: TokenHandlerOptions = { provider: proxyProvider }; + proxyApp.use('/token', tokenHandler(options)); + + const redirectUri = 'https://example.com/callback'; + const response = await supertest(proxyApp) + .post('/token') + .type('form') + .send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'any_verifier', + redirect_uri: redirectUri + }); + + expect(response.status).toBe(200); + expect(response.body.access_token).toBe('mock_access_token'); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining(`redirect_uri=${encodeURIComponent(redirectUri)}`) + }) + ); + } finally { + global.fetch = originalFetch; + } + }); }); describe('Refresh token grant', () => { diff --git a/src/server/auth/handlers/token.ts b/src/server/auth/handlers/token.ts index 28412a014..eadbd7515 100644 --- a/src/server/auth/handlers/token.ts +++ b/src/server/auth/handlers/token.ts @@ -31,6 +31,7 @@ const TokenRequestSchema = z.object({ const AuthorizationCodeGrantSchema = z.object({ code: z.string(), code_verifier: z.string(), + redirect_uri: z.string().optional(), }); const RefreshTokenGrantSchema = z.object({ @@ -88,7 +89,7 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand throw new InvalidRequestError(parseResult.error.message); } - const { code, code_verifier } = parseResult.data; + const { code, code_verifier, redirect_uri } = parseResult.data; const skipLocalPkceValidation = provider.skipLocalPkceValidation; @@ -102,7 +103,12 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand } // Passes the code_verifier to the provider if PKCE validation didn't occur locally - const tokens = await provider.exchangeAuthorizationCode(client, code, skipLocalPkceValidation ? code_verifier : undefined); + const tokens = await provider.exchangeAuthorizationCode( + client, + code, + skipLocalPkceValidation ? code_verifier : undefined, + redirect_uri + ); res.status(200).json(tokens); break; } diff --git a/src/server/auth/provider.ts b/src/server/auth/provider.ts index 8a0bf0f17..7815b713e 100644 --- a/src/server/auth/provider.ts +++ b/src/server/auth/provider.ts @@ -36,7 +36,12 @@ export interface OAuthServerProvider { /** * Exchanges an authorization code for an access token. */ - exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string, codeVerifier?: string): Promise; + exchangeAuthorizationCode( + client: OAuthClientInformationFull, + authorizationCode: string, + codeVerifier?: string, + redirectUri?: string + ): Promise; /** * Exchanges a refresh token for an access token. diff --git a/src/server/auth/providers/proxyProvider.test.ts b/src/server/auth/providers/proxyProvider.test.ts index 6e842ea33..69039c3e0 100644 --- a/src/server/auth/providers/proxyProvider.test.ts +++ b/src/server/auth/providers/proxyProvider.test.ts @@ -142,6 +142,28 @@ describe("Proxy OAuth Server Provider", () => { expect(tokens).toEqual(mockTokenResponse); }); + it("includes redirect_uri in token request when provided", async () => { + const redirectUri = "https://example.com/callback"; + const tokens = await provider.exchangeAuthorizationCode( + validClient, + "test-code", + "test-verifier", + redirectUri + ); + + expect(global.fetch).toHaveBeenCalledWith( + "https://auth.example.com/token", + expect.objectContaining({ + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: expect.stringContaining(`redirect_uri=${encodeURIComponent(redirectUri)}`) + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + it("exchanges refresh token for new tokens", async () => { const tokens = await provider.exchangeRefreshToken( validClient, diff --git a/src/server/auth/providers/proxyProvider.ts b/src/server/auth/providers/proxyProvider.ts index be4503050..db7460e55 100644 --- a/src/server/auth/providers/proxyProvider.ts +++ b/src/server/auth/providers/proxyProvider.ts @@ -151,7 +151,8 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { async exchangeAuthorizationCode( client: OAuthClientInformationFull, authorizationCode: string, - codeVerifier?: string + codeVerifier?: string, + redirectUri?: string ): Promise { const params = new URLSearchParams({ grant_type: "authorization_code", @@ -167,6 +168,10 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { params.append("code_verifier", codeVerifier); } + if (redirectUri) { + params.append("redirect_uri", redirectUri); + } + const response = await fetch(this._endpoints.tokenUrl, { method: "POST", headers: {