Skip to content

Commit cce637c

Browse files
fix(proxy): improve redirect proxy logic (#11517)
* refactor(proxy): improve redirect proxy code * match name * correct comment * remove unused type import * improve * more cleanups * simplify * simpler * refactor * simplify * more * forward all params * add/fix tests * fix comment * drop .only * add comment
1 parent f6b7228 commit cce637c

File tree

6 files changed

+287
-283
lines changed

6 files changed

+287
-283
lines changed

packages/core/src/jwt.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ export async function encode<Payload = JWT>(params: JWTEncodeParams<Payload>) {
7171
.encrypt(encryptionSecret)
7272
}
7373

74-
/** Decodes a Auth.js issued JWT. */
74+
/** Decodes an Auth.js issued JWT. */
7575
export async function decode<Payload = JWT>(
7676
params: JWTDecodeParams
7777
): Promise<Payload | null> {

packages/core/src/lib/actions/callback/index.ts

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import {
1010
} from "../../../errors.js"
1111
import { handleLoginOrRegister } from "./handle-login.js"
1212
import { handleOAuth } from "./oauth/callback.js"
13-
import { handleState } from "./oauth/checks.js"
13+
import { state } from "./oauth/checks.js"
1414
import { createHash } from "../../utils/web.js"
1515

1616
import type { AdapterSession } from "../../../adapters.js"
@@ -57,28 +57,33 @@ export async function callback(
5757
try {
5858
if (provider.type === "oauth" || provider.type === "oidc") {
5959
// Use body if the response mode is set to form_post. For all other cases, use query
60-
const payload =
60+
const params =
6161
provider.authorization?.url.searchParams.get("response_mode") ===
6262
"form_post"
6363
? body
6464
: query
6565

66-
const { proxyRedirect, randomState } = handleState(
67-
payload,
68-
provider,
69-
options.isOnRedirectProxy
70-
)
71-
72-
if (proxyRedirect) {
73-
logger.debug("proxy redirect", { proxyRedirect, randomState })
74-
return { redirect: proxyRedirect }
66+
// If we have a state and we are on a redirect proxy, we try to parse it
67+
// and see if it contains a valid origin to redirect to. If it does, we
68+
// redirect the user to that origin with the original state.
69+
if (options.isOnRedirectProxy && params?.state) {
70+
// NOTE: We rely on the state being encrypted using a shared secret
71+
// between the proxy and the original server.
72+
const parsedState = await state.decode(params.state, options)
73+
const shouldRedirect =
74+
parsedState?.origin &&
75+
new URL(parsedState.origin).origin !== options.url.origin
76+
if (shouldRedirect) {
77+
const proxyRedirect = `${parsedState.origin}?${new URLSearchParams(params)}`
78+
logger.debug("Proxy redirecting to", proxyRedirect)
79+
return { redirect: proxyRedirect, cookies }
80+
}
7581
}
7682

7783
const authorizationResult = await handleOAuth(
78-
payload,
84+
params,
7985
request.cookies,
80-
options,
81-
randomState
86+
options
8287
)
8388

8489
if (authorizationResult.cookies.length) {

packages/core/src/lib/actions/callback/oauth/callback.ts

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@ import { isOIDCProvider } from "../../../utils/providers.js"
2828
* we fetch it anyway. This is because we always want a user profile.
2929
*/
3030
export async function handleOAuth(
31-
query: RequestInternal["query"],
31+
params: RequestInternal["query"],
3232
cookies: RequestInternal["cookies"],
33-
options: InternalOptions<"oauth" | "oidc">,
34-
randomState?: string
33+
options: InternalOptions<"oauth" | "oidc">
3534
) {
3635
const { logger, provider } = options
3736
let as: o.AuthorizationServer
@@ -78,17 +77,12 @@ export async function handleOAuth(
7877

7978
const resCookies: Cookie[] = []
8079

81-
const state = await checks.state.use(
82-
cookies,
83-
resCookies,
84-
options,
85-
randomState
86-
)
80+
const state = await checks.state.use(cookies, resCookies, options)
8781

8882
const codeGrantParams = o.validateAuthResponse(
8983
as,
9084
client,
91-
new URLSearchParams(query),
85+
new URLSearchParams(params),
9286
provider.checks.includes("state") ? state : o.skipStateCheck
9387
)
9488

0 commit comments

Comments
 (0)