Skip to content
61 changes: 55 additions & 6 deletions src/SignalR/clients/ts/signalr/src/AccessTokenHttpClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@
import { HeaderNames } from "./HeaderNames";
import { HttpClient, HttpRequest, HttpResponse } from "./HttpClient";

// Internal helpers (not exported) for narrowing and status normalization
function isError(u: unknown): u is Error {
return u instanceof Error;
}
function getStatus(u: unknown): number | undefined {
if (typeof u !== "object" || u === null) { return undefined; }
const rec = u as Record<string, unknown>;
const raw = rec["statusCode"] ?? rec["status"];
if (typeof raw === "number") { return raw; }
if (typeof raw === "string") {
const n = parseInt(raw, 10);
return Number.isNaN(n) ? undefined : n;
}
return undefined;
}

/** @private */
export class AccessTokenHttpClient extends HttpClient {
private _innerClient: HttpClient;
Expand All @@ -25,14 +41,47 @@ export class AccessTokenHttpClient extends HttpClient {
this._accessToken = await this._accessTokenFactory();
}
this._setAuthorizationHeader(request);
const response = await this._innerClient.send(request);

if (allowRetry && response.statusCode === 401 && this._accessTokenFactory) {
this._accessToken = await this._accessTokenFactory();
this._setAuthorizationHeader(request);
return await this._innerClient.send(request);
try {
const response = await this._innerClient.send(request);

if (allowRetry && this._accessTokenFactory && response.statusCode === 401) {
return await this._refreshTokenAndRetry(request, response);
}
return response;
} catch (err: unknown) {
if (!allowRetry || !this._accessTokenFactory) {
throw err;
}
if (!isError(err)) {
throw err;
}
const status = getStatus(err);
if (status === 401) {
return await this._refreshTokenAndRetry(request, err);
}
throw err;
}
return response;
}

private async _refreshTokenAndRetry(request: HttpRequest, original: HttpResponse | Error): Promise<HttpResponse> {
if (!this._accessTokenFactory) {
if (original instanceof HttpResponse) {
return original;
}
throw original;
}

const newToken = await this._accessTokenFactory();
if (!newToken) {
if (original instanceof HttpResponse) {
return original;
}
throw original;
}
this._accessToken = newToken;
this._setAuthorizationHeader(request);
return await this._innerClient.send(request);
}

private _setAuthorizationHeader(request: HttpRequest) {
Expand Down
248 changes: 248 additions & 0 deletions src/SignalR/clients/ts/signalr/tests/AccessTokenHttpClient.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

import { AccessTokenHttpClient } from "../src/AccessTokenHttpClient";
import { HttpError } from "../src/Errors";
import { HttpRequest, HttpResponse } from "../src/HttpClient";
import { TestHttpClient } from "./TestHttpClient";
import { registerUnhandledRejectionHandler } from "./Utils";
import { VerifyLogger } from "./Common";
import { LongPollingTransport } from "../src/LongPollingTransport";
import { TransferFormat } from "../src/ITransport";

describe("AccessTokenHttpClient", () => {
beforeAll(() => {
registerUnhandledRejectionHandler();
});

afterAll(() => {
// Optional cleanup could go here.
});

it("retries exactly once on 401 HttpError when accessTokenFactory provided", async () => {
let call = 0;
let primed = false;
const inner = new TestHttpClient();
inner.on(() => {
if (!primed) {
primed = true; // prime request returns 200 and sets initial token
return new HttpResponse(200, "OK", "prime");
}
call++;
if (call === 1) {
throw new HttpError("Unauthorized", 401);
}
return new HttpResponse(200, "OK", "done");
});

let factoryCalls = 0;
const client = new AccessTokenHttpClient(inner, () => {
factoryCalls++;
return `token${factoryCalls}`;
});

// Prime token via public API
await client.get("http://example.com/prime");

const response = await client.get("http://example.com/resource");
expect(response.statusCode).toBe(200);
expect(factoryCalls).toBe(2); // prime + retry refresh
expect(call).toBe(2); // failing attempt + successful retry
});

[403, 500].forEach(status => {
it(`does not retry on status ${status} HttpError`, async () => {
let primed = false;
let failingCalls = 0;
const inner = new TestHttpClient();
inner.on(() => {
if (!primed) {
primed = true;
return new HttpResponse(200, "OK", "prime");
}
failingCalls++;
throw new HttpError("Error", status);
});

let factoryCalls = 0;
const client = new AccessTokenHttpClient(inner, () => {
factoryCalls++;
return `token${factoryCalls}`;
});

await client.get("http://example.com/prime");
try {
await client.get("http://example.com/resource");
fail("expected to throw");
} catch (e: any) {
expect(e).toBeInstanceOf(HttpError);
expect(e.statusCode ?? e.status).toBe(status);
}
expect(factoryCalls).toBe(1);
expect(failingCalls).toBe(1);
});
});

it("LongPollingTransport continues running after 401 during poll and refreshes token", async () => {
await VerifyLogger.run(async (logger) => {
let pollIteration = 0;
let primed = false;
const tokens: string[] = [];
const accessTokenFactory = () => {
const t = `tok${tokens.length + 1}`;
tokens.push(t);
return t;
};
const httpClient = new AccessTokenHttpClient(new TestHttpClient()
.on("GET", (r: HttpRequest) => {
// Prime request separate from polling loop
if (!primed && r.url!.includes("/prime")) {
primed = true;
return new HttpResponse(200, "OK", "prime");
}
pollIteration++;
if (pollIteration === 1) { // initial connect poll
return new HttpResponse(200, "OK", "");
}
if (pollIteration === 2) { // trigger 401 -> retry
return new HttpResponse(401);
}
if (pollIteration === 3) { // post-refresh poll
expect(r.headers).toBeDefined();
expect(r.headers?.Authorization).toBeDefined();
expect(r.headers?.Authorization).toContain(tokens[tokens.length - 1]);
return new HttpResponse(204);
}
return new HttpResponse(204);
}), accessTokenFactory);

// Prime token using public API
await httpClient.get("http://example.com/prime");

const transport = new LongPollingTransport(httpClient, logger, { withCredentials: true, headers: {}, logMessageContent: false });
await transport.connect("http://example.com?connectionId=abc", TransferFormat.Text);
await transport.stop();

expect(tokens.length).toBe(2); // primed + refreshed
expect(pollIteration).toBeGreaterThanOrEqual(3);
});
});

it("retries once on 401 HttpResponse status (non-throwing path)", async () => {
let primed = false;
let attempts = 0;
let retryAuthHeader: string | undefined;
const inner = new TestHttpClient();
inner.on((r: HttpRequest) => {
if (!primed && r.url!.includes("/prime")) {
primed = true;
return new HttpResponse(200, "OK", "prime");
}
attempts++;
if (attempts === 1) {
return new HttpResponse(401);
}
// second attempt after refresh
retryAuthHeader = r.headers?.Authorization;
return new HttpResponse(200, "OK", "after-retry");
});

let factoryCalls = 0;
const client = new AccessTokenHttpClient(inner, () => {
factoryCalls++;
return `token${factoryCalls}`;
});

await client.get("http://example.com/prime");
const resp = await client.get("http://example.com/resource");
expect(resp.statusCode).toBe(200);
expect(factoryCalls).toBe(2); // prime + refresh
expect(attempts).toBe(2); // original 401 + retry 200
expect(retryAuthHeader).toContain("token2");
});

it("does not retry when allowRetry is false (initial token acquisition)", async () => {
let sends = 0;
const inner = new TestHttpClient();
inner.on(() => {
sends++;
return new HttpResponse(401);
});

let factoryCalls = 0;
const client = new AccessTokenHttpClient(inner, () => {
factoryCalls++;
return `token${factoryCalls}`; // Explicitly call send with allowRetry=false to ensure no retry is attempted.
});

const request: HttpRequest = { method: "GET", url: "http://example.com/resource" };
const resp = await client.send(request); // send path with existing logic; allowRetry=false triggered by initial token acquisition above
Copy link

Copilot AI Sep 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment mentions 'allowRetry=false triggered by initial token acquisition above' but there's no token acquisition visible in this test method above this line. The comment appears to be inaccurate or misleading.

Suggested change
const resp = await client.send(request); // send path with existing logic; allowRetry=false triggered by initial token acquisition above
const resp = await client.send(request); // Verifies that allowRetry=false prevents retry on initial token acquisition (401 response).

Copilot uses AI. Check for mistakes.

expect(resp.statusCode).toBe(401);
expect(factoryCalls).toBe(1);
expect(sends).toBe(1);
});

it("does not retry when refreshed token is empty", async () => {
let primed = false;
let attempts = 0;
const inner = new TestHttpClient();
inner.on((r: HttpRequest) => {
if (!primed && r.url!.includes("/prime")) {
primed = true;
return new HttpResponse(200, "OK", "prime");
}
attempts++;
return new HttpResponse(401); // cause retry path
});

let factoryCalls = 0;
const client = new AccessTokenHttpClient(inner, () => {
factoryCalls++;
if (factoryCalls === 1) {
return "tok1"; // prime
}
return ""; // refresh returns empty -> should not retry send again
});

await client.get("http://example.com/prime");
const resp = await client.get("http://example.com/resource");
expect(resp.statusCode).toBe(401); // original response returned
expect(factoryCalls).toBe(2); // prime + attempted refresh
expect(attempts).toBe(1); // no second send
});

it("retries once when HttpError.status is string '401'", async () => {
let primed = false;
let attempt = 0;
let retryAuth: string | undefined;
const inner = new TestHttpClient();
inner.on((r: HttpRequest) => {
if (!primed && r.url!.includes("/prime")) {
primed = true;
return new HttpResponse(200, "OK", "prime");
}
attempt++;
if (attempt === 1) {
const err: any = new Error("Unauthorized: Status code '401'");
err.name = "HttpError"; // mimic HttpError shape without statusCode
err.status = "401"; // string status to trigger normalization path
throw err;
}
retryAuth = r.headers?.Authorization;
return new HttpResponse(200, "OK", "ok");
});

let factoryCalls = 0;
const client = new AccessTokenHttpClient(inner, () => {
factoryCalls++;
return `token${factoryCalls}`;
});

await client.get("http://example.com/prime");
const resp = await client.get("http://example.com/resource");
expect(resp.statusCode).toBe(200);
expect(factoryCalls).toBe(2); // prime + refresh after string status retry
expect(attempt).toBe(2); // original throw + retry
expect(retryAuth).toContain("token2");
});
});
Loading