-
Notifications
You must be signed in to change notification settings - Fork 10.5k
fix(SignalR): retry access token refresh on 401 in TS client #63740
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
873a113
857315c
f88b82c
0b490cf
372a70a
75ca574
665288d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,12 +1,12 @@ | ||||||||||||||
// Licensed to the .NET Foundation under one or more agreements. | ||||||||||||||
// The .NET Foundation licenses this file to you under the MIT license. | ||||||||||||||
|
||||||||||||||
import { HeaderNames } from "./HeaderNames"; | ||||||||||||||
// Removed HeaderNames import to reduce bundle size; using literal key. | ||||||||||||||
import { HttpClient, HttpRequest, HttpResponse } from "./HttpClient"; | ||||||||||||||
|
||||||||||||||
/** @private */ | ||||||||||||||
export class AccessTokenHttpClient extends HttpClient { | ||||||||||||||
private _innerClient: HttpClient; | ||||||||||||||
private readonly _innerClient: HttpClient; | ||||||||||||||
_accessToken: string | undefined; | ||||||||||||||
_accessTokenFactory: (() => string | Promise<string>) | undefined; | ||||||||||||||
|
||||||||||||||
|
@@ -18,36 +18,38 @@ export class AccessTokenHttpClient extends HttpClient { | |||||||||||||
} | ||||||||||||||
|
||||||||||||||
public async send(request: HttpRequest): Promise<HttpResponse> { | ||||||||||||||
let allowRetry = true; | ||||||||||||||
if (this._accessTokenFactory && (!this._accessToken || (request.url && request.url.indexOf("/negotiate?") > 0))) { | ||||||||||||||
// don't retry if the request is a negotiate or if we just got a potentially new token from the access token factory | ||||||||||||||
allowRetry = false; | ||||||||||||||
this._accessToken = await this._accessTokenFactory(); | ||||||||||||||
} | ||||||||||||||
const needsToken = !!(this._accessTokenFactory && (!this._accessToken || (request.url && request.url.indexOf('/negotiate?') > 0))); | ||||||||||||||
const retry = !needsToken; | ||||||||||||||
if (needsToken) this._accessToken = await this._accessTokenFactory!(); | ||||||||||||||
this._setAuthorizationHeader(request); | ||||||||||||||
const response = await this._innerClient.send(request); | ||||||||||||||
try { | ||||||||||||||
const r = await this._innerClient.send(request); | ||||||||||||||
return (retry && this._accessTokenFactory && r.statusCode === 401) ? this._refreshTokenAndRetry(request, r) : r; | ||||||||||||||
} catch (err: unknown) { | ||||||||||||||
if (!retry || !this._accessTokenFactory) throw err; | ||||||||||||||
const e = err as any, s = +(e.statusCode ?? e.status); | ||||||||||||||
if (s === 401) return this._refreshTokenAndRetry(request, e); | ||||||||||||||
throw err; | ||||||||||||||
} | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
if (allowRetry && response.statusCode === 401 && this._accessTokenFactory) { | ||||||||||||||
this._accessToken = await this._accessTokenFactory(); | ||||||||||||||
this._setAuthorizationHeader(request); | ||||||||||||||
return await this._innerClient.send(request); | ||||||||||||||
private async _refreshTokenAndRetry(request: HttpRequest, o: HttpResponse | Error): Promise<HttpResponse> { | ||||||||||||||
const t = await this._accessTokenFactory!(); | ||||||||||||||
if (!t) { | ||||||||||||||
this._accessToken = undefined; | ||||||||||||||
if (request.headers) delete (request.headers as any).Authorization; | ||||||||||||||
if (request.abortSignal) return this._innerClient.send(request); | ||||||||||||||
if (o instanceof HttpResponse) return o; | ||||||||||||||
return Promise.reject(o); | ||||||||||||||
} | ||||||||||||||
return response; | ||||||||||||||
this._accessToken = t; | ||||||||||||||
this._setAuthorizationHeader(request); | ||||||||||||||
return this._innerClient.send(request); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
private _setAuthorizationHeader(request: HttpRequest) { | ||||||||||||||
if (!request.headers) { | ||||||||||||||
request.headers = {}; | ||||||||||||||
} | ||||||||||||||
if (this._accessToken) { | ||||||||||||||
request.headers[HeaderNames.Authorization] = `Bearer ${this._accessToken}` | ||||||||||||||
} | ||||||||||||||
// don't remove the header if there isn't an access token factory, the user manually added the header in this case | ||||||||||||||
else if (this._accessTokenFactory) { | ||||||||||||||
if (request.headers[HeaderNames.Authorization]) { | ||||||||||||||
delete request.headers[HeaderNames.Authorization]; | ||||||||||||||
} | ||||||||||||||
} | ||||||||||||||
const h = request.headers || (request.headers = {}); | ||||||||||||||
if (this._accessToken) h.Authorization = 'Bearer ' + this._accessToken; else if (this._accessTokenFactory) delete h.Authorization; | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] This complex conditional logic should be split into multiple lines for better readability. Consider using separate if-else blocks to make the authorization header logic clearer.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||
} | ||||||||||||||
|
||||||||||||||
public getCookieString(url: string): string { | ||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,247 @@ | ||||||
// Licensed to the .NET Foundation under one or more agreements. | ||||||
// The .NET Foundation licenses this file to you under the MIT license. | ||||||
|
||||||
import { AccessTokenHttpClient } from "../src/AccessTokenHttpClient"; | ||||||
import { HttpError } from "../src/Errors"; | ||||||
import { HttpRequest, HttpResponse } from "../src/HttpClient"; | ||||||
import { TestHttpClient } from "./TestHttpClient"; | ||||||
import { registerUnhandledRejectionHandler } from "./Utils"; | ||||||
import { VerifyLogger } from "./Common"; | ||||||
import { LongPollingTransport } from "../src/LongPollingTransport"; | ||||||
import { TransferFormat } from "../src/ITransport"; | ||||||
|
||||||
describe("AccessTokenHttpClient", () => { | ||||||
beforeAll(() => { | ||||||
registerUnhandledRejectionHandler(); | ||||||
}); | ||||||
|
||||||
afterAll(() => { | ||||||
// Optional cleanup could go here. | ||||||
}); | ||||||
|
||||||
it("retries exactly once on 401 HttpError when accessTokenFactory provided", async () => { | ||||||
let call = 0; | ||||||
let primed = false; | ||||||
const inner = new TestHttpClient(); | ||||||
inner.on(() => { | ||||||
if (!primed) { | ||||||
primed = true; // prime request returns 200 and sets initial token | ||||||
return new HttpResponse(200, "OK", "prime"); | ||||||
} | ||||||
call++; | ||||||
if (call === 1) { | ||||||
throw new HttpError("Unauthorized", 401); | ||||||
} | ||||||
return new HttpResponse(200, "OK", "done"); | ||||||
}); | ||||||
|
||||||
let factoryCalls = 0; | ||||||
const client = new AccessTokenHttpClient(inner, () => { | ||||||
factoryCalls++; | ||||||
return `token${factoryCalls}`; | ||||||
}); | ||||||
|
||||||
// Prime token via public API | ||||||
await client.get("http://example.com/prime"); | ||||||
|
||||||
const response = await client.get("http://example.com/resource"); | ||||||
expect(response.statusCode).toBe(200); | ||||||
expect(factoryCalls).toBe(2); // prime + retry refresh | ||||||
expect(call).toBe(2); // failing attempt + successful retry | ||||||
}); | ||||||
|
||||||
[403, 500].forEach(status => { | ||||||
it(`does not retry on status ${status} HttpError`, async () => { | ||||||
let primed = false; | ||||||
let failingCalls = 0; | ||||||
const inner = new TestHttpClient(); | ||||||
inner.on(() => { | ||||||
if (!primed) { | ||||||
primed = true; | ||||||
return new HttpResponse(200, "OK", "prime"); | ||||||
} | ||||||
failingCalls++; | ||||||
throw new HttpError("Error", status); | ||||||
}); | ||||||
|
||||||
let factoryCalls = 0; | ||||||
const client = new AccessTokenHttpClient(inner, () => { | ||||||
factoryCalls++; | ||||||
return `token${factoryCalls}`; | ||||||
}); | ||||||
|
||||||
await client.get("http://example.com/prime"); | ||||||
try { | ||||||
await expect(client.get("http://example.com/resource")).rejects.toThrow(HttpError); | ||||||
} catch (e: any) { | ||||||
expect(e).toBeInstanceOf(HttpError); | ||||||
expect(e.statusCode ?? e.status).toBe(status); | ||||||
} | ||||||
expect(factoryCalls).toBe(1); | ||||||
expect(failingCalls).toBe(1); | ||||||
}); | ||||||
}); | ||||||
|
||||||
it("LongPollingTransport continues running after 401 during poll and refreshes token", async () => { | ||||||
await VerifyLogger.run(async (logger) => { | ||||||
let pollIteration = 0; | ||||||
let primed = false; | ||||||
const tokens: string[] = []; | ||||||
const accessTokenFactory = () => { | ||||||
const t = `tok${tokens.length + 1}`; | ||||||
tokens.push(t); | ||||||
return t; | ||||||
}; | ||||||
const httpClient = new AccessTokenHttpClient(new TestHttpClient() | ||||||
.on("GET", (r: HttpRequest) => { | ||||||
// Prime request separate from polling loop | ||||||
if (!primed && r.url!.includes("/prime")) { | ||||||
primed = true; | ||||||
return new HttpResponse(200, "OK", "prime"); | ||||||
} | ||||||
pollIteration++; | ||||||
if (pollIteration === 1) { // initial connect poll | ||||||
return new HttpResponse(200, "OK", ""); | ||||||
} | ||||||
if (pollIteration === 2) { // trigger 401 -> retry | ||||||
return new HttpResponse(401); | ||||||
} | ||||||
if (pollIteration === 3) { // post-refresh poll | ||||||
expect(r.headers).toBeDefined(); | ||||||
expect(r.headers?.Authorization).toBeDefined(); | ||||||
expect(r.headers?.Authorization).toContain(tokens[tokens.length - 1]); | ||||||
return new HttpResponse(204); | ||||||
} | ||||||
return new HttpResponse(204); | ||||||
}), accessTokenFactory); | ||||||
|
||||||
// Prime token using public API | ||||||
await httpClient.get("http://example.com/prime"); | ||||||
|
||||||
const transport = new LongPollingTransport(httpClient, logger, { withCredentials: true, headers: {}, logMessageContent: false }); | ||||||
await transport.connect("http://example.com?connectionId=abc", TransferFormat.Text); | ||||||
await transport.stop(); | ||||||
|
||||||
expect(tokens.length).toBe(2); // primed + refreshed | ||||||
expect(pollIteration).toBeGreaterThanOrEqual(3); | ||||||
}); | ||||||
}); | ||||||
|
||||||
it("retries once on 401 HttpResponse status (non-throwing path)", async () => { | ||||||
let primed = false; | ||||||
let attempts = 0; | ||||||
let retryAuthHeader: string | undefined; | ||||||
const inner = new TestHttpClient(); | ||||||
inner.on((r: HttpRequest) => { | ||||||
if (!primed && r.url!.includes("/prime")) { | ||||||
primed = true; | ||||||
return new HttpResponse(200, "OK", "prime"); | ||||||
} | ||||||
attempts++; | ||||||
if (attempts === 1) { | ||||||
return new HttpResponse(401); | ||||||
} | ||||||
// second attempt after refresh | ||||||
retryAuthHeader = r.headers?.Authorization; | ||||||
return new HttpResponse(200, "OK", "after-retry"); | ||||||
}); | ||||||
|
||||||
let factoryCalls = 0; | ||||||
const client = new AccessTokenHttpClient(inner, () => { | ||||||
factoryCalls++; | ||||||
return `token${factoryCalls}`; | ||||||
}); | ||||||
|
||||||
await client.get("http://example.com/prime"); | ||||||
const resp = await client.get("http://example.com/resource"); | ||||||
expect(resp.statusCode).toBe(200); | ||||||
expect(factoryCalls).toBe(2); // prime + refresh | ||||||
expect(attempts).toBe(2); // original 401 + retry 200 | ||||||
expect(retryAuthHeader).toContain("token2"); | ||||||
}); | ||||||
|
||||||
it("does not retry when allowRetry is false (initial token acquisition)", async () => { | ||||||
let sends = 0; | ||||||
const inner = new TestHttpClient(); | ||||||
inner.on(() => { | ||||||
sends++; | ||||||
return new HttpResponse(401); | ||||||
}); | ||||||
|
||||||
let factoryCalls = 0; | ||||||
const client = new AccessTokenHttpClient(inner, () => { | ||||||
factoryCalls++; | ||||||
return `token${factoryCalls}`; // Token factory returns a token string for each call. | ||||||
}); | ||||||
|
||||||
const request: HttpRequest = { method: "GET", url: "http://example.com/resource" }; | ||||||
const resp = await client.send(request); // send path with existing logic; allowRetry=false triggered by initial token acquisition above | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment mentions 'allowRetry=false triggered by initial token acquisition above' but there's no token acquisition visible in this test method above this line. The comment appears to be inaccurate or misleading.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||
expect(resp.statusCode).toBe(401); | ||||||
expect(factoryCalls).toBe(1); | ||||||
expect(sends).toBe(1); | ||||||
}); | ||||||
|
||||||
it("does not retry when refreshed token is empty", async () => { | ||||||
let primed = false; | ||||||
let attempts = 0; | ||||||
const inner = new TestHttpClient(); | ||||||
inner.on((r: HttpRequest) => { | ||||||
if (!primed && r.url!.includes("/prime")) { | ||||||
primed = true; | ||||||
return new HttpResponse(200, "OK", "prime"); | ||||||
} | ||||||
attempts++; | ||||||
return new HttpResponse(401); // cause retry path | ||||||
}); | ||||||
|
||||||
let factoryCalls = 0; | ||||||
const client = new AccessTokenHttpClient(inner, () => { | ||||||
factoryCalls++; | ||||||
if (factoryCalls === 1) { | ||||||
return "tok1"; // prime | ||||||
} | ||||||
return ""; // refresh returns empty -> should not retry send again | ||||||
}); | ||||||
|
||||||
await client.get("http://example.com/prime"); | ||||||
const resp = await client.get("http://example.com/resource"); | ||||||
expect(resp.statusCode).toBe(401); // original response returned | ||||||
expect(factoryCalls).toBe(2); // prime + attempted refresh | ||||||
expect(attempts).toBe(1); // no second send | ||||||
}); | ||||||
|
||||||
it("retries once when HttpError.status is string '401'", async () => { | ||||||
let primed = false; | ||||||
let attempt = 0; | ||||||
let retryAuth: string | undefined; | ||||||
const inner = new TestHttpClient(); | ||||||
inner.on((r: HttpRequest) => { | ||||||
if (!primed && r.url!.includes("/prime")) { | ||||||
primed = true; | ||||||
return new HttpResponse(200, "OK", "prime"); | ||||||
} | ||||||
attempt++; | ||||||
if (attempt === 1) { | ||||||
const err: any = new Error("Unauthorized: Status code '401'"); | ||||||
err.name = "HttpError"; // mimic HttpError shape without statusCode | ||||||
err.status = "401"; // string status to trigger normalization path | ||||||
throw err; | ||||||
} | ||||||
retryAuth = r.headers?.Authorization; | ||||||
return new HttpResponse(200, "OK", "ok"); | ||||||
}); | ||||||
|
||||||
let factoryCalls = 0; | ||||||
const client = new AccessTokenHttpClient(inner, () => { | ||||||
factoryCalls++; | ||||||
return `token${factoryCalls}`; | ||||||
}); | ||||||
|
||||||
await client.get("http://example.com/prime"); | ||||||
const resp = await client.get("http://example.com/resource"); | ||||||
expect(resp.statusCode).toBe(200); | ||||||
expect(factoryCalls).toBe(2); // prime + refresh after string status retry | ||||||
expect(attempt).toBe(2); // original throw + retry | ||||||
expect(retryAuth).toContain("token2"); | ||||||
}); | ||||||
}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition
request.abortSignal
checks for the existence of the signal, not if it's been aborted. This should likely checkrequest.abortSignal?.aborted
to determine if the request was cancelled, or this logic may be incorrect for the intended behavior.Copilot uses AI. Check for mistakes.