diff --git a/packages/plugins/jwt/src/index.ts b/packages/plugins/jwt/src/index.ts index d2a85e21ea..80446010b1 100644 --- a/packages/plugins/jwt/src/index.ts +++ b/packages/plugins/jwt/src/index.ts @@ -70,6 +70,8 @@ export function useJWT(options: JwtPluginOptions): Plugin { const payloadByRequest = new WeakMap(); let jwksClient: JwksClient; + const jwksCache: Map = new Map(); + if (options.jwksUri) { jwksClient = new JwksClient({ cache: true, @@ -83,15 +85,36 @@ export function useJWT(options: JwtPluginOptions): Plugin { async onRequestParse({ request, serverContext, url }) { const token = await getToken({ request, serverContext, url }); if (token != null) { - const signingKey = options.signingKey ?? (await fetchKey(jwksClient, token)); - - const verified = await verify(token, signingKey, options); - - if (!verified) { - throw unauthorizedError(`Unauthenticated`); + try { + const signingKey = + options.signingKey ?? (await fetchKey({ jwksClient, jwksCache, token })); + const verified = await verify(token, signingKey, options); + + if (!verified) { + throw new Error('Initial verification failed.'); + } + + payloadByRequest.set(request, verified); + } catch (error) { + // If error is thrown and signing key was supplied, do not attempt cache refresh + if (options.signingKey) { + throw unauthorizedError(`Unauthenticated`); + } + + // If initial verification fails, attempt to refresh the key and retry verification + const signingKey = await fetchKey({ + jwksClient, + jwksCache, + token, + shouldRefreshCache: true, + }); + const verified = await verify(token, signingKey, options); + if (!verified) { + throw unauthorizedError(`Unauthenticated`); + } + + payloadByRequest.set(request, verified); } - - payloadByRequest.set(request, verified); } }, onContextBuilding({ context, extendContext }) { @@ -142,18 +165,38 @@ function verify( }); } -async function fetchKey(jwksClient: JwksClient, token: string): Promise { +interface FetchKeyOptions { + jwksClient: JwksClient; + jwksCache: Map; + token: string; + shouldRefreshCache?: boolean; +} + +async function fetchKey({ + jwksClient, + jwksCache, + token, + shouldRefreshCache = false, +}: FetchKeyOptions): Promise { const decodedToken = decode(token, { complete: true }); if (decodedToken?.header?.kid == null) { throw unauthorizedError(`Failed to decode authentication token. Missing key id.`); } - const secret = await jwksClient.getSigningKey(decodedToken.header.kid); - const signingKey = secret?.getPublicKey(); - if (!signingKey) { - throw unauthorizedError(`Failed to decode authentication token. Unknown key id.`); + if (shouldRefreshCache) { + jwksCache.delete(decodedToken.header.kid); + } + + if (!jwksCache.has(decodedToken.header.kid)) { + const secret = await jwksClient.getSigningKey(decodedToken.header.kid); + const signingKey = secret?.getPublicKey(); + if (!signingKey) { + throw unauthorizedError(`Unauthenticated`); + } + jwksCache.set(decodedToken.header.kid, signingKey); } - return signingKey; + + return jwksCache.get(decodedToken.header.kid)!; } const defaultGetToken: NonNullable = ({ request }) => {