Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 32 additions & 44 deletions packages/common/src/fetchApiAuthMiddleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// session API key based authentication.

import 'cross-fetch/polyfill';
import { FetchMiddleware, ResponseContext } from './fetchUtil';
import { FetchMiddleware, hostMatches, ResponseContext } from './fetchUtil';

export interface SessionAuthDataStore {
get(host: string): Promise<{ authKey: string } | undefined> | { authKey: string } | undefined;
Expand All @@ -27,59 +27,47 @@ export function createApiSessionAuthMiddleware({
authRequestMetadata = {},
authDataStore = createInMemoryAuthDataStore(),
}: ApiSessionAuthMiddlewareOpts): FetchMiddleware {
// Local temporary cache of auth request promises, used so that multiple re-auth requests are
// not running in parallel. Key is the request host.
const pendingAuthRequests = new Map<string, Promise<{ authKey: string }>>();
// Local temporary cache of any previous auth request promise, used so that
// multiple re-auth requests are not running in parallel.
let pendingAuthRequest: Promise<{ authKey: string }> | null = null;

const authMiddleware: FetchMiddleware = {
pre: async context => {
// Skip middleware if host does not match pattern
const reqUrl = new URL(context.url);
let hostMatches = false;
if (typeof host === 'string') {
hostMatches = host === reqUrl.host;
} else {
hostMatches = !!host.exec(reqUrl.host);
}
if (hostMatches) {
const authData = await authDataStore.get(reqUrl.host);
if (authData) {
context.init.headers = setRequestHeader(context.init, httpHeader, authData.authKey);
}
if (!hostMatches(reqUrl.host, host)) return;

const authData = await authDataStore.get(reqUrl.host);
if (authData) {
context.init.headers = setRequestHeader(context.init, httpHeader, authData.authKey);
}
},
post: async context => {
// Skip middleware if response was successful
if (context.response.status !== 401) return;

// Skip middleware if host does not match pattern
const reqUrl = new URL(context.url);
let hostMatches = false;
if (typeof host === 'string') {
hostMatches = host === reqUrl.host;
} else {
hostMatches = !!host.exec(reqUrl.host);
}
if (!hostMatches(reqUrl.host, host)) return;

// If request is for configured host, and response was `401 Unauthorized`,
// then request auth key and retry request.
if (hostMatches && context.response.status === 401) {
// Check if for any currently pending auth requests and re-use it to avoid creating multiple in parallel.
let pendingAuthRequest = pendingAuthRequests.get(reqUrl.host);
if (!pendingAuthRequest) {
pendingAuthRequest = resolveAuthToken(context, authPath, authRequestMetadata)
.then(async result => {
// If the request is successfull, add the key to storage.
await authDataStore.set(reqUrl.host, result);
return result;
})
.finally(() => {
// When the request is completed (either successful or rejected) clear the promise.
pendingAuthRequests.delete(reqUrl.host);
});
}
const { authKey } = await pendingAuthRequest;
// Retry the request using the new API key auth header.
context.init.headers = setRequestHeader(context.init, httpHeader, authKey);
return context.fetch(context.url, context.init);
} else {
return context.response;
// Retry original request after authorization request
if (!pendingAuthRequest) {
// Check if for any currently pending auth requests and re-use it to avoid creating multiple in parallel
pendingAuthRequest = resolveAuthToken(context, authPath, authRequestMetadata)
.then(async result => {
// If the request is successfull, add the key to storage.
await authDataStore.set(reqUrl.host, result);
return result;
})
.finally(() => {
// When the request is completed (either successful or rejected) remove reference
pendingAuthRequest = null;
});
}
const { authKey } = await pendingAuthRequest;
// Retry the request using the new API key auth header.
context.init.headers = setRequestHeader(context.init, httpHeader, authKey);
return context.fetch(context.url, context.init);
},
};
return authMiddleware;
Expand Down
87 changes: 46 additions & 41 deletions packages/common/src/fetchUtil.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ export interface ApiKeyMiddlewareOpts {
apiKey: string;
}

/** @internal */
export function hostMatches(host: string, pattern: string | RegExp) {
if (typeof pattern === 'string') return pattern === host;
return (pattern as RegExp).exec(host);
}

export function createApiKeyMiddleware({
apiKey,
host = /(.*)api(.*)\.stacks\.co$/i,
Expand All @@ -41,17 +47,11 @@ export function createApiKeyMiddleware({
return {
pre: context => {
const reqUrl = new URL(context.url);
let hostMatches = false;
if (typeof host === 'string') {
hostMatches = host === reqUrl.host;
} else {
hostMatches = !!host.exec(reqUrl.host);
}
if (hostMatches) {
const headers = new Headers(context.init.headers);
headers.set(httpHeader, apiKey);
context.init.headers = headers;
}
if (!hostMatches(reqUrl.host, host)) return; // Skip middleware if host does not match pattern

const headers = new Headers(context.init.headers);
headers.set(httpHeader, apiKey);
context.init.headers = headers;
},
};
}
Expand All @@ -67,46 +67,51 @@ function createDefaultMiddleware(): FetchMiddleware[] {
return [setOriginMiddleware];
}

export function createFetchFn(fetchLib: FetchFn, ...middleware: FetchMiddleware[]): FetchFn;
export function createFetchFn(...middleware: FetchMiddleware[]): FetchFn;
export function createFetchFn(...args: any[]): FetchFn {
// Argument helper function for {createFetchFn}
function argsForCreateFetchFn(args: any[]): { middlewares: FetchMiddleware[]; fetchLib: FetchFn } {
let fetchLib: FetchFn = fetch;
let middlewareOpt: FetchMiddleware[] = [];
if (args.length > 0) {
if (typeof args[0] === 'function') {
fetchLib = args.shift();
}
let middlewares: FetchMiddleware[] = [];
if (args.length > 0 && typeof args[0] === 'function') {
fetchLib = args.shift();
}
if (args.length > 0) {
middlewareOpt = args;
middlewares = args;
}
const middlewares = [...createDefaultMiddleware(), ...middlewareOpt];
return { middlewares, fetchLib };
}

export function createFetchFn(fetchLib: FetchFn, ...middleware: FetchMiddleware[]): FetchFn;
export function createFetchFn(...middleware: FetchMiddleware[]): FetchFn;
export function createFetchFn(...args: any[]): FetchFn {
const { middlewares: middlewareArgs, fetchLib } = argsForCreateFetchFn(args);
const middlewares = [...createDefaultMiddleware(), ...middlewareArgs];

const fetchFn = async (url: string, init?: RequestInit | undefined): Promise<Response> => {
let fetchParams = { url, init: init ?? {} };
for (const middleware of middlewares) {
if (middleware.pre) {
const result = await Promise.resolve(
middleware.pre({
fetch: fetchLib,
...fetchParams,
})
);
fetchParams = result ?? fetchParams;
}
if (typeof middleware.pre !== 'function') continue;
const result = await Promise.resolve(
middleware.pre({
fetch: fetchLib,
...fetchParams,
})
);
fetchParams = result ?? fetchParams;
}

let response = await fetchLib(fetchParams.url, fetchParams.init);

for (const middleware of middlewares) {
if (middleware.post) {
const result = await Promise.resolve(
middleware.post({
fetch: fetchLib,
url: fetchParams.url,
init: fetchParams.init,
response: response.clone(),
})
);
response = result ?? response;
}
if (typeof middleware.post !== 'function') continue;
const result = await Promise.resolve(
middleware.post({
fetch: fetchLib,
url: fetchParams.url,
init: fetchParams.init,
response: response.clone(),
})
);
response = result ?? response;
}
return response;
};
Expand Down