Skip to content
126 changes: 120 additions & 6 deletions src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,22 @@ describe("OAuth Authorization", () => {
});

describe("exchangeAuthorization", () => {
const mockProvider: OAuthClientProvider = {
get redirectUrl() { return "http://localhost:3000/callback"; },
get clientMetadata() {
return {
redirect_uris: ["http://localhost:3000/callback"],
client_name: "Test Client",
};
},
clientInformation: jest.fn(),
tokens: jest.fn(),
saveTokens: jest.fn(),
redirectToAuthorization: jest.fn(),
saveCodeVerifier: jest.fn(),
codeVerifier: jest.fn(),
};

const validTokens = {
access_token: "access123",
token_type: "Bearer",
Expand Down Expand Up @@ -449,12 +465,11 @@ describe("OAuth Authorization", () => {
}),
expect.objectContaining({
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
})
);

const headers = mockFetch.mock.calls[0][1].headers as Headers;
expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded");
const body = mockFetch.mock.calls[0][1].body as URLSearchParams;
expect(body.get("grant_type")).toBe("authorization_code");
expect(body.get("code")).toBe("code123");
Expand All @@ -464,6 +479,50 @@ describe("OAuth Authorization", () => {
expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback");
});

it("exchanges code for tokens with auth", async () => {
mockProvider.authToTokenEndpoint = function(url: URL, headers: Headers, params: URLSearchParams) {
headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret));
params.set("example_url", url.toString());
params.set("example_param", "example_value");
};

mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => validTokens,
});

const tokens = await exchangeAuthorization("https://auth.example.com", {
clientInformation: validClientInfo,
authorizationCode: "code123",
codeVerifier: "verifier123",
redirectUri: "http://localhost:3000/callback",
}, mockProvider);

expect(tokens).toEqual(validTokens);
expect(mockFetch).toHaveBeenCalledWith(
expect.objectContaining({
href: "https://auth.example.com/token",
}),
expect.objectContaining({
method: "POST",
})
);

const headers = mockFetch.mock.calls[0][1].headers as Headers;
expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded");
expect(headers.get("Authorization")).toBe("Basic Y2xpZW50MTIzOnNlY3JldDEyMw==");
const body = mockFetch.mock.calls[0][1].body as URLSearchParams;
expect(body.get("grant_type")).toBe("authorization_code");
expect(body.get("code")).toBe("code123");
expect(body.get("code_verifier")).toBe("verifier123");
expect(body.get("client_id")).toBe("client123");
expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback");
expect(body.get("example_url")).toBe("https://auth.example.com/token");
expect(body.get("example_param")).toBe("example_value");
expect(body.get("client_secret")).toBeUndefined;
});

it("validates token response schema", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
Expand Down Expand Up @@ -502,6 +561,22 @@ describe("OAuth Authorization", () => {
});

describe("refreshAuthorization", () => {
const mockProvider: OAuthClientProvider = {
get redirectUrl() { return "http://localhost:3000/callback"; },
get clientMetadata() {
return {
redirect_uris: ["http://localhost:3000/callback"],
client_name: "Test Client",
};
},
clientInformation: jest.fn(),
tokens: jest.fn(),
saveTokens: jest.fn(),
redirectToAuthorization: jest.fn(),
saveCodeVerifier: jest.fn(),
codeVerifier: jest.fn(),
};

const validTokens = {
access_token: "newaccess123",
token_type: "Bearer",
Expand Down Expand Up @@ -538,19 +613,58 @@ describe("OAuth Authorization", () => {
}),
expect.objectContaining({
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
})
);

const headers = mockFetch.mock.calls[0][1].headers as Headers;
expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded");
const body = mockFetch.mock.calls[0][1].body as URLSearchParams;
expect(body.get("grant_type")).toBe("refresh_token");
expect(body.get("refresh_token")).toBe("refresh123");
expect(body.get("client_id")).toBe("client123");
expect(body.get("client_secret")).toBe("secret123");
});

it("exchanges refresh token for new tokens with auth", async () => {
mockProvider.authToTokenEndpoint = function(url: URL, headers: Headers, params: URLSearchParams) {
headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret));
params.set("example_url", url.toString());
params.set("example_param", "example_value");
};

mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => validTokensWithNewRefreshToken,
});

const tokens = await refreshAuthorization("https://auth.example.com", {
clientInformation: validClientInfo,
refreshToken: "refresh123",
}, mockProvider);

expect(tokens).toEqual(validTokensWithNewRefreshToken);
expect(mockFetch).toHaveBeenCalledWith(
expect.objectContaining({
href: "https://auth.example.com/token",
}),
expect.objectContaining({
method: "POST",
})
);

const headers = mockFetch.mock.calls[0][1].headers as Headers;
expect(headers.get("Content-Type")).toBe("application/x-www-form-urlencoded");
expect(headers.get("Authorization")).toBe("Basic Y2xpZW50MTIzOnNlY3JldDEyMw==");
const body = mockFetch.mock.calls[0][1].body as URLSearchParams;
expect(body.get("grant_type")).toBe("refresh_token");
expect(body.get("refresh_token")).toBe("refresh123");
expect(body.get("client_id")).toBe("client123");
expect(body.get("example_url")).toBe("https://auth.example.com/token");
expect(body.get("example_param")).toBe("example_value");
expect(body.get("client_secret")).toBeUndefined;
});

it("exchanges refresh token for new tokens and keep existing refresh token if none is returned", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
Expand Down
30 changes: 20 additions & 10 deletions src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ export interface OAuthClientProvider {
* the authorization result.
*/
codeVerifier(): string | Promise<string>;

authToTokenEndpoint?(url: URL, headers: Headers, params: URLSearchParams): void | Promise<void>;
}

export type AuthResult = "AUTHORIZED" | "REDIRECT";
Expand Down Expand Up @@ -137,7 +139,7 @@ export async function auth(
authorizationCode,
codeVerifier,
redirectUri: provider.redirectUrl,
});
}, provider);

await provider.saveTokens(tokens);
return "AUTHORIZED";
Expand All @@ -153,7 +155,7 @@ export async function auth(
metadata,
clientInformation,
refreshToken: tokens.refresh_token,
});
}, provider);

await provider.saveTokens(newTokens);
return "AUTHORIZED";
Expand Down Expand Up @@ -372,6 +374,7 @@ export async function exchangeAuthorization(
codeVerifier: string;
redirectUri: string | URL;
},
provider?: OAuthClientProvider
): Promise<OAuthTokens> {
const grantType = "authorization_code";

Expand All @@ -392,6 +395,9 @@ export async function exchangeAuthorization(
}

// Exchange code for tokens
const headers = new Headers({
"Content-Type": "application/x-www-form-urlencoded",
});
const params = new URLSearchParams({
grant_type: grantType,
client_id: clientInformation.client_id,
Expand All @@ -400,15 +406,15 @@ export async function exchangeAuthorization(
redirect_uri: String(redirectUri),
});

if (clientInformation.client_secret) {
if (provider?.authToTokenEndpoint) {
provider.authToTokenEndpoint(tokenUrl, headers, params);
} else if (clientInformation.client_secret) {
params.set("client_secret", clientInformation.client_secret);
}

const response = await fetch(tokenUrl, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
headers: headers,
body: params,
});

Expand All @@ -433,6 +439,7 @@ export async function refreshAuthorization(
clientInformation: OAuthClientInformation;
refreshToken: string;
},
provider?: OAuthClientProvider,
): Promise<OAuthTokens> {
const grantType = "refresh_token";

Expand All @@ -453,21 +460,24 @@ export async function refreshAuthorization(
}

// Exchange refresh token
const headers = new Headers({
"Content-Type": "application/x-www-form-urlencoded",
});
const params = new URLSearchParams({
grant_type: grantType,
client_id: clientInformation.client_id,
refresh_token: refreshToken,
});

if (clientInformation.client_secret) {
if (provider?.authToTokenEndpoint) {
provider.authToTokenEndpoint(tokenUrl, headers, params);
} else if (clientInformation.client_secret) {
params.set("client_secret", clientInformation.client_secret);
}

const response = await fetch(tokenUrl, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
headers: headers,
body: params,
});
if (!response.ok) {
Expand Down