diff --git a/src/configuration-requestor.spec.ts b/src/configuration-requestor.spec.ts index b3bb05f..31d7f76 100644 --- a/src/configuration-requestor.spec.ts +++ b/src/configuration-requestor.spec.ts @@ -9,8 +9,12 @@ import ApiEndpoints from './api-endpoints'; import ConfigurationRequestor from './configuration-requestor'; import { IConfigurationStore } from './configuration-store/configuration-store'; import { MemoryOnlyConfigurationStore } from './configuration-store/memory.store'; -import FetchHttpClient, { IHttpClient } from './http-client'; -import { BanditVariation, BanditParameters, Flag } from './interfaces'; +import FetchHttpClient, { + IBanditParametersResponse, + IHttpClient, + IUniversalFlagConfigResponse, +} from './http-client'; +import { BanditParameters, BanditVariation, Flag } from './interfaces'; describe('ConfigurationRequestor', () => { let flagStore: IConfigurationStore; @@ -111,13 +115,13 @@ describe('ConfigurationRequestor', () => { describe('Flags with bandits', () => { let fetchSpy: jest.Mock; - beforeAll(() => { + function initiateFetchSpy( + responseMockGenerator: ( + url: string, + ) => IUniversalFlagConfigResponse | IBanditParametersResponse, + ) { fetchSpy = jest.fn((url: string) => { - const responseFile = url.includes('bandits') - ? MOCK_BANDIT_MODELS_RESPONSE_FILE - : MOCK_FLAGS_WITH_BANDITS_RESPONSE_FILE; - const response = readMockUFCResponse(responseFile); - + const response = responseMockGenerator(url); return Promise.resolve({ ok: true, status: 200, @@ -125,83 +129,301 @@ describe('ConfigurationRequestor', () => { }); }) as jest.Mock; global.fetch = fetchSpy; - }); + } - it('Fetches and populates bandit parameters', async () => { - await configurationRequestor.fetchAndStoreConfigurations(); + function defaultResponseMockGenerator(url: string) { + const responseFile = url.includes('bandits') + ? MOCK_BANDIT_MODELS_RESPONSE_FILE + : MOCK_FLAGS_WITH_BANDITS_RESPONSE_FILE; + return readMockUFCResponse(responseFile); + } - expect(fetchSpy).toHaveBeenCalledTimes(2); // Once for UFC, another for bandits - - expect(flagStore.getKeys().length).toBeGreaterThanOrEqual(2); - expect(flagStore.get('banner_bandit_flag')).toBeDefined(); - expect(flagStore.get('cold_start_bandit')).toBeDefined(); - - expect(banditModelStore.getKeys().length).toBeGreaterThanOrEqual(2); - - const bannerBandit = banditModelStore.get('banner_bandit'); - expect(bannerBandit?.banditKey).toBe('banner_bandit'); - expect(bannerBandit?.modelName).toBe('falcon'); - expect(bannerBandit?.modelVersion).toBe('123'); - const bannerModelData = bannerBandit?.modelData; - expect(bannerModelData?.gamma).toBe(1); - expect(bannerModelData?.defaultActionScore).toBe(0); - expect(bannerModelData?.actionProbabilityFloor).toBe(0); - const bannerCoefficients = bannerModelData?.coefficients || {}; - expect(Object.keys(bannerCoefficients).length).toBe(2); - - // Deep dive for the nike action - const nikeCoefficients = bannerCoefficients['nike']; - expect(nikeCoefficients.actionKey).toBe('nike'); - expect(nikeCoefficients.intercept).toBe(1); - expect(nikeCoefficients.actionNumericCoefficients).toHaveLength(1); - const nikeBrandAffinityCoefficient = nikeCoefficients.actionNumericCoefficients[0]; - expect(nikeBrandAffinityCoefficient.attributeKey).toBe('brand_affinity'); - expect(nikeBrandAffinityCoefficient.coefficient).toBe(1); - expect(nikeBrandAffinityCoefficient.missingValueCoefficient).toBe(-0.1); - expect(nikeCoefficients.actionCategoricalCoefficients).toHaveLength(2); - const nikeLoyaltyTierCoefficient = nikeCoefficients.actionCategoricalCoefficients[0]; - expect(nikeLoyaltyTierCoefficient.attributeKey).toBe('loyalty_tier'); - expect(nikeLoyaltyTierCoefficient.missingValueCoefficient).toBe(0); - expect(nikeLoyaltyTierCoefficient.valueCoefficients).toStrictEqual({ - gold: 4.5, - silver: 3.2, - bronze: 1.9, + describe('Fetching bandits', () => { + beforeAll(() => { + initiateFetchSpy(defaultResponseMockGenerator); }); - expect(nikeCoefficients.subjectNumericCoefficients).toHaveLength(1); - const nikeAccountAgeCoefficient = nikeCoefficients.subjectNumericCoefficients[0]; - expect(nikeAccountAgeCoefficient.attributeKey).toBe('account_age'); - expect(nikeAccountAgeCoefficient.coefficient).toBe(0.3); - expect(nikeAccountAgeCoefficient.missingValueCoefficient).toBe(0); - expect(nikeCoefficients.subjectCategoricalCoefficients).toHaveLength(1); - const nikeGenderIdentityCoefficient = nikeCoefficients.subjectCategoricalCoefficients[0]; - expect(nikeGenderIdentityCoefficient.attributeKey).toBe('gender_identity'); - expect(nikeGenderIdentityCoefficient.missingValueCoefficient).toBe(2.3); - expect(nikeGenderIdentityCoefficient.valueCoefficients).toStrictEqual({ - female: 0.5, - male: -0.5, + + it('Fetches and populates bandit parameters', async () => { + await configurationRequestor.fetchAndStoreConfigurations(); + + expect(fetchSpy).toHaveBeenCalledTimes(2); // Once for UFC, another for bandits + + expect(flagStore.getKeys().length).toBeGreaterThanOrEqual(2); + expect(flagStore.get('banner_bandit_flag')).toBeDefined(); + expect(flagStore.get('cold_start_bandit')).toBeDefined(); + + expect(banditModelStore.getKeys().length).toBeGreaterThanOrEqual(2); + + const bannerBandit = banditModelStore.get('banner_bandit'); + expect(bannerBandit?.banditKey).toBe('banner_bandit'); + expect(bannerBandit?.modelName).toBe('falcon'); + expect(bannerBandit?.modelVersion).toBe('123'); + const bannerModelData = bannerBandit?.modelData; + expect(bannerModelData?.gamma).toBe(1); + expect(bannerModelData?.defaultActionScore).toBe(0); + expect(bannerModelData?.actionProbabilityFloor).toBe(0); + const bannerCoefficients = bannerModelData?.coefficients || {}; + expect(Object.keys(bannerCoefficients).length).toBe(2); + + // Deep dive for the nike action + const nikeCoefficients = bannerCoefficients['nike']; + expect(nikeCoefficients.actionKey).toBe('nike'); + expect(nikeCoefficients.intercept).toBe(1); + expect(nikeCoefficients.actionNumericCoefficients).toHaveLength(1); + const nikeBrandAffinityCoefficient = nikeCoefficients.actionNumericCoefficients[0]; + expect(nikeBrandAffinityCoefficient.attributeKey).toBe('brand_affinity'); + expect(nikeBrandAffinityCoefficient.coefficient).toBe(1); + expect(nikeBrandAffinityCoefficient.missingValueCoefficient).toBe(-0.1); + expect(nikeCoefficients.actionCategoricalCoefficients).toHaveLength(2); + const nikeLoyaltyTierCoefficient = nikeCoefficients.actionCategoricalCoefficients[0]; + expect(nikeLoyaltyTierCoefficient.attributeKey).toBe('loyalty_tier'); + expect(nikeLoyaltyTierCoefficient.missingValueCoefficient).toBe(0); + expect(nikeLoyaltyTierCoefficient.valueCoefficients).toStrictEqual({ + gold: 4.5, + silver: 3.2, + bronze: 1.9, + }); + expect(nikeCoefficients.subjectNumericCoefficients).toHaveLength(1); + const nikeAccountAgeCoefficient = nikeCoefficients.subjectNumericCoefficients[0]; + expect(nikeAccountAgeCoefficient.attributeKey).toBe('account_age'); + expect(nikeAccountAgeCoefficient.coefficient).toBe(0.3); + expect(nikeAccountAgeCoefficient.missingValueCoefficient).toBe(0); + expect(nikeCoefficients.subjectCategoricalCoefficients).toHaveLength(1); + const nikeGenderIdentityCoefficient = nikeCoefficients.subjectCategoricalCoefficients[0]; + expect(nikeGenderIdentityCoefficient.attributeKey).toBe('gender_identity'); + expect(nikeGenderIdentityCoefficient.missingValueCoefficient).toBe(2.3); + expect(nikeGenderIdentityCoefficient.valueCoefficients).toStrictEqual({ + female: 0.5, + male: -0.5, + }); + + // Just spot check the adidas parameters + expect(bannerCoefficients['adidas'].subjectNumericCoefficients).toHaveLength(0); + expect( + bannerCoefficients['adidas'].subjectCategoricalCoefficients[0].valueCoefficients[ + 'female' + ], + ).toBe(0); + + const coldStartBandit = banditModelStore.get('cold_start_bandit'); + expect(coldStartBandit?.banditKey).toBe('cold_start_bandit'); + expect(coldStartBandit?.modelName).toBe('falcon'); + expect(coldStartBandit?.modelVersion).toBe('cold start'); + const coldStartModelData = coldStartBandit?.modelData; + expect(coldStartModelData?.gamma).toBe(1); + expect(coldStartModelData?.defaultActionScore).toBe(0); + expect(coldStartModelData?.actionProbabilityFloor).toBe(0); + expect(coldStartModelData?.coefficients).toStrictEqual({}); }); - // Just spot check the adidas parameters - expect(bannerCoefficients['adidas'].subjectNumericCoefficients).toHaveLength(0); - expect( - bannerCoefficients['adidas'].subjectCategoricalCoefficients[0].valueCoefficients['female'], - ).toBe(0); - - const coldStartBandit = banditModelStore.get('cold_start_bandit'); - expect(coldStartBandit?.banditKey).toBe('cold_start_bandit'); - expect(coldStartBandit?.modelName).toBe('falcon'); - expect(coldStartBandit?.modelVersion).toBe('cold start'); - const coldStartModelData = coldStartBandit?.modelData; - expect(coldStartModelData?.gamma).toBe(1); - expect(coldStartModelData?.defaultActionScore).toBe(0); - expect(coldStartModelData?.actionProbabilityFloor).toBe(0); - expect(coldStartModelData?.coefficients).toStrictEqual({}); - }); + it('Will not fetch bandit parameters if there is no store', async () => { + configurationRequestor = new ConfigurationRequestor(httpClient, flagStore, null, null); + await configurationRequestor.fetchAndStoreConfigurations(); + expect(fetchSpy).toHaveBeenCalledTimes(1); + }); - it('Will not fetch bandit parameters if there is no store', async () => { - configurationRequestor = new ConfigurationRequestor(httpClient, flagStore, null, null); - await configurationRequestor.fetchAndStoreConfigurations(); - expect(fetchSpy).toHaveBeenCalledTimes(1); + it('Should not fetch bandits if model version is un-changed', async () => { + await configurationRequestor.fetchAndStoreConfigurations(); + expect(fetchSpy).toHaveBeenCalledTimes(2); // Once for UFC, another for bandits + + await configurationRequestor.fetchAndStoreConfigurations(); + expect(fetchSpy).toHaveBeenCalledTimes(3); // Once just for UFC, bandits should be skipped + }); + + describe('Bandits polling', () => { + const warmStartBanditReference = { + modelVersion: 'warm start', + flagVariations: [ + { + key: 'warm_start_bandit', + flagKey: 'warm_start_bandit_flag', + variationKey: 'warm_start_bandit', + variationValue: 'warm_start_bandit', + }, + ], + }; + + const warmStartBanditParameters = { + banditKey: 'warm_start_bandit', + modelName: 'pigeon', + modelVersion: 'warm start', + modelData: { + gamma: 1.0, + defaultActionScore: 0.0, + actionProbabilityFloor: 0.0, + coefficients: {}, + }, + }; + + const coldStartBanditParameters = { + banditKey: 'cold_start_bandit', + modelName: 'falcon', + modelVersion: 'cold start', + modelData: { + gamma: 1.0, + defaultActionScore: 0.0, + actionProbabilityFloor: 0.0, + coefficients: {}, + }, + }; + + afterAll(() => { + initiateFetchSpy(defaultResponseMockGenerator); + }); + + function expectBanditToBeInModelStore( + store: IConfigurationStore, + banditKey: string, + expectedBanditParameters: BanditParameters, + ) { + const bandit = store.get(banditKey); + expect(bandit).toBeTruthy(); + expect(bandit?.banditKey).toBe(expectedBanditParameters.banditKey); + expect(bandit?.modelVersion).toBe(expectedBanditParameters.modelVersion); + expect(bandit?.modelName).toBe(expectedBanditParameters.modelName); + expect(bandit?.modelData.gamma).toBe(expectedBanditParameters.modelData.gamma); + expect(bandit?.modelData.defaultActionScore).toBe( + expectedBanditParameters.modelData.defaultActionScore, + ); + expect(bandit?.modelData.actionProbabilityFloor).toBe( + expectedBanditParameters.modelData.actionProbabilityFloor, + ); + expect(bandit?.modelData.coefficients).toStrictEqual( + expectedBanditParameters.modelData.coefficients, + ); + } + + function injectWarmStartBanditToResponseByUrl( + url: string, + response: IUniversalFlagConfigResponse | IBanditParametersResponse, + ) { + if (url.includes('config') && 'banditReferences' in response) { + response.banditReferences.warm_start_bandit = warmStartBanditReference; + } + + if (url.includes('bandits') && 'bandits' in response) { + response.bandits.warm_start_bandit = warmStartBanditParameters; + } + } + + it('Should fetch bandits if new bandit references model versions appeared', async () => { + let updateUFC = false; + await configurationRequestor.fetchAndStoreConfigurations(); + await configurationRequestor.fetchAndStoreConfigurations(); + expect(fetchSpy).toHaveBeenCalledTimes(3); + + const customResponseMockGenerator = (url: string) => { + const responseFile = url.includes('bandits') + ? MOCK_BANDIT_MODELS_RESPONSE_FILE + : MOCK_FLAGS_WITH_BANDITS_RESPONSE_FILE; + + const response = readMockUFCResponse(responseFile); + + if (updateUFC === true) { + injectWarmStartBanditToResponseByUrl(url, response); + } + return response; + }; + updateUFC = true; + initiateFetchSpy(customResponseMockGenerator); + + await configurationRequestor.fetchAndStoreConfigurations(); + expect(fetchSpy).toHaveBeenCalledTimes(2); // 2 because fetchSpy was re-initiated, 1UFC and 1bandits + + // let's check if warm start was hydrated properly! + expectBanditToBeInModelStore( + banditModelStore, + 'warm_start_bandit', + warmStartBanditParameters, + ); + }); + + it('Should not fetch bandits if bandit references model versions shrunk', async () => { + // Initial fetch + await configurationRequestor.fetchAndStoreConfigurations(); + + // Let's mock UFC response so that cold_start is no longer retrieved + const customResponseMockGenerator = (url: string) => { + const responseFile = url.includes('bandits') + ? MOCK_BANDIT_MODELS_RESPONSE_FILE + : MOCK_FLAGS_WITH_BANDITS_RESPONSE_FILE; + + const response = readMockUFCResponse(responseFile); + + if (url.includes('config') && 'banditReferences' in response) { + delete response.banditReferences.cold_start_bandit; + } + return response; + }; + + initiateFetchSpy(customResponseMockGenerator); + await configurationRequestor.fetchAndStoreConfigurations(); + expect(fetchSpy).toHaveBeenCalledTimes(1); // only once for UFC + + // cold start should still be in memory + expectBanditToBeInModelStore( + banditModelStore, + 'cold_start_bandit', + coldStartBanditParameters, + ); + }); + + /** + * 1. initial call - 1 fetch for ufc 1 for bandits + * 2. 2nd call - 1 fetch for ufc only; bandits unchanged + * 3. 3rd call - new bandit ref injected to UFC; 2 fetches, because new bandit appeared + * 4. 4th call - we remove a bandit from ufc; 1 fetch because there is no need to update. + * The bandit removed from UFC should still be in memory. + **/ + it('should fetch bandits based on banditReference change in UFC', async () => { + let injectWarmStart = false; + let removeColdStartBandit = false; + await configurationRequestor.fetchAndStoreConfigurations(); + expect(fetchSpy).toHaveBeenCalledTimes(2); + + await configurationRequestor.fetchAndStoreConfigurations(); + expect(fetchSpy).toHaveBeenCalledTimes(3); + + const customResponseMockGenerator = (url: string) => { + const responseFile = url.includes('bandits') + ? MOCK_BANDIT_MODELS_RESPONSE_FILE + : MOCK_FLAGS_WITH_BANDITS_RESPONSE_FILE; + const response = readMockUFCResponse(responseFile); + if (injectWarmStart === true) { + injectWarmStartBanditToResponseByUrl(url, response); + } else if ( + removeColdStartBandit === true && + 'banditReferences' in response && + url.includes('config') + ) { + delete response.banditReferences.cold_start_bandit; + } + return response; + }; + injectWarmStart = true; + initiateFetchSpy(customResponseMockGenerator); + + await configurationRequestor.fetchAndStoreConfigurations(); + expect(fetchSpy).toHaveBeenCalledTimes(2); + expectBanditToBeInModelStore( + banditModelStore, + 'warm_start_bandit', + warmStartBanditParameters, + ); + + injectWarmStart = false; + removeColdStartBandit = true; + initiateFetchSpy(customResponseMockGenerator); + await configurationRequestor.fetchAndStoreConfigurations(); + expect(fetchSpy).toHaveBeenCalledTimes(1); + + expectBanditToBeInModelStore( + banditModelStore, + 'cold_start_bandit', + coldStartBanditParameters, + ); + }); + }); }); }); }); diff --git a/src/configuration-requestor.ts b/src/configuration-requestor.ts index 763cb06..543c10b 100644 --- a/src/configuration-requestor.ts +++ b/src/configuration-requestor.ts @@ -1,10 +1,19 @@ import { IConfigurationStore } from './configuration-store/configuration-store'; import { hydrateConfigurationStore } from './configuration-store/configuration-store-utils'; import { IHttpClient } from './http-client'; -import { BanditVariation, BanditParameters, Flag } from './interfaces'; +import { + BanditVariation, + BanditParameters, + Flag, + BanditReference, +} from './interfaces'; + +type Entry = Flag | BanditVariation[] | BanditParameters; // Requests AND stores flag configurations export default class ConfigurationRequestor { + private banditModelVersions: string[] = []; + constructor( private readonly httpClient: IHttpClient, private readonly flagConfigurationStore: IConfigurationStore, @@ -27,13 +36,13 @@ export default class ConfigurationRequestor { format: configResponse.format, }); - const flagsHaveBandits = Object.keys(configResponse.bandits ?? {}).length > 0; + const flagsHaveBandits = Object.keys(configResponse.banditReferences ?? {}).length > 0; const banditStoresProvided = Boolean( this.banditVariationConfigurationStore && this.banditModelConfigurationStore, ); if (flagsHaveBandits && banditStoresProvided) { // Map bandit flag associations by flag key for quick lookup (instead of bandit key as provided by the UFC) - const banditVariations = this.indexBanditVariationsByFlagKey(configResponse.bandits); + const banditVariations = this.indexBanditVariationsByFlagKey(configResponse.banditReferences); await hydrateConfigurationStore(this.banditVariationConfigurationStore, { entries: banditVariations, @@ -42,29 +51,58 @@ export default class ConfigurationRequestor { format: configResponse.format, }); - // TODO: different polling intervals for bandit parameters - const banditResponse = await this.httpClient.getBanditParameters(); - if (banditResponse?.bandits) { - if (!this.banditModelConfigurationStore) { - throw new Error('Bandit parameters fetched but no bandit configuration store provided'); - } + if ( + this.requiresBanditModelConfigurationStoreUpdate( + this.banditModelVersions, + configResponse.banditReferences, + ) + ) { + const banditResponse = await this.httpClient.getBanditParameters(); + if (banditResponse?.bandits) { + await hydrateConfigurationStore(this.banditModelConfigurationStore, { + entries: banditResponse.bandits, + environment: configResponse.environment, + createdAt: configResponse.createdAt, + format: configResponse.format,}); - await hydrateConfigurationStore(this.banditModelConfigurationStore, { - entries: banditResponse.bandits, - environment: configResponse.environment, - createdAt: configResponse.createdAt, - format: configResponse.format, - }); + this.banditModelVersions = this.getLoadedBanditModelVersionsFromStore( + this.banditModelConfigurationStore, + ); + } } } } + private getLoadedBanditModelVersionsFromStore( + banditModelConfigurationStore: IConfigurationStore | null, + ): string[] { + if (banditModelConfigurationStore === null) { + return []; + } + return Object.values(banditModelConfigurationStore.entries()).map( + (banditParam: BanditParameters) => banditParam.modelVersion, + ); + } + + private requiresBanditModelConfigurationStoreUpdate( + currentBanditModelVersions: string[], + banditReferences: Record, + ): boolean { + const referencedModelVersions = Object.values(banditReferences).map( + (banditReference: BanditReference) => banditReference.modelVersion, + ); + + return !referencedModelVersions.every((modelVersion) => + currentBanditModelVersions.includes(modelVersion), + ); + } + private indexBanditVariationsByFlagKey( - banditVariationsByBanditKey: Record, + banditVariationsByBanditKey: Record, ): Record { const banditVariationsByFlagKey: Record = {}; - Object.values(banditVariationsByBanditKey).forEach((banditVariations) => { - banditVariations.forEach((banditVariation) => { + Object.values(banditVariationsByBanditKey).forEach((banditReference) => { + banditReference.flagVariations.forEach((banditVariation) => { let banditVariations = banditVariationsByFlagKey[banditVariation.flagKey]; if (!banditVariations) { banditVariations = []; diff --git a/src/http-client.ts b/src/http-client.ts index de3f636..a5186fa 100644 --- a/src/http-client.ts +++ b/src/http-client.ts @@ -2,7 +2,7 @@ import ApiEndpoints from './api-endpoints'; import { IPrecomputedConfigurationResponse } from './configuration'; import { BanditParameters, - BanditVariation, + BanditReference, Environment, Flag, FormatEnum, @@ -35,7 +35,7 @@ export interface IUniversalFlagConfigResponse { format: FormatEnum; environment: Environment; flags: Record; - bandits: Record; + banditReferences: Record; } export interface IBanditParametersResponse { diff --git a/src/interfaces.ts b/src/interfaces.ts index 3d7f161..5b40d98 100644 --- a/src/interfaces.ts +++ b/src/interfaces.ts @@ -102,6 +102,11 @@ export interface BanditVariation { variationValue: string; } +export interface BanditReference { + modelVersion: string; + flagVariations: BanditVariation[]; +} + export interface BanditParameters { banditKey: string; modelName: string;