diff --git a/packages/snaps-controllers/coverage.json b/packages/snaps-controllers/coverage.json index 6d44c6623f..38d5b176c1 100644 --- a/packages/snaps-controllers/coverage.json +++ b/packages/snaps-controllers/coverage.json @@ -1,6 +1,6 @@ { - "branches": 93.34, - "functions": 97.37, - "lines": 98.33, - "statements": 98.06 + "branches": 93.36, + "functions": 97.38, + "lines": 98.34, + "statements": 98.07 } diff --git a/packages/snaps-controllers/src/snaps/SnapController.test.tsx b/packages/snaps-controllers/src/snaps/SnapController.test.tsx index d0b5e393ea..9242eb02f6 100644 --- a/packages/snaps-controllers/src/snaps/SnapController.test.tsx +++ b/packages/snaps-controllers/src/snaps/SnapController.test.tsx @@ -71,7 +71,10 @@ import { pipeline } from 'readable-stream'; import type { Duplex } from 'readable-stream'; import { inc } from 'semver'; -import { LEGACY_ENCRYPTION_KEY_DERIVATION_OPTIONS } from './constants'; +import { + LEGACY_ENCRYPTION_KEY_DERIVATION_OPTIONS, + STATE_DEBOUNCE_TIMEOUT, +} from './constants'; import { SnapsRegistryStatus } from './registry'; import type { SnapControllerState } from './SnapController'; import { @@ -9228,6 +9231,14 @@ describe('SnapController', () => { }); describe('SnapController:getSnapState', () => { + beforeAll(() => { + jest.useFakeTimers(); + }); + + afterAll(() => { + jest.useRealTimers(); + }); + it(`gets the snap's state`, async () => { const messenger = getSnapControllerMessenger(); @@ -9316,6 +9327,7 @@ describe('SnapController', () => { DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS, ); + jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT); await promise; const result = await messenger.call( @@ -9367,6 +9379,7 @@ describe('SnapController', () => { true, ); + jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT); await promise; const encryptedState1 = await encrypt( @@ -9557,6 +9570,14 @@ describe('SnapController', () => { }); describe('SnapController:updateSnapState', () => { + beforeAll(() => { + jest.useFakeTimers(); + }); + + afterAll(() => { + jest.useRealTimers(); + }); + it(`updates the snap's state`, async () => { const messenger = getSnapControllerMessenger(); @@ -9587,6 +9608,7 @@ describe('SnapController', () => { true, ); + jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT); await promise; expect(updateSnapStateSpy).toHaveBeenCalledTimes(1); @@ -9611,6 +9633,8 @@ describe('SnapController', () => { const updateSnapStateSpy = jest.spyOn(snapController, 'updateSnapState'); const state = { foo: 'bar' }; + + const promise = waitForStateChange(messenger); await messenger.call( 'SnapController:updateSnapState', MOCK_SNAP_ID, @@ -9619,6 +9643,10 @@ describe('SnapController', () => { ); expect(updateSnapStateSpy).toHaveBeenCalledTimes(1); + + jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT); + await promise; + expect( snapController.state.unencryptedSnapStates[MOCK_SNAP_ID], ).toStrictEqual(JSON.stringify(state)); @@ -9657,6 +9685,7 @@ describe('SnapController', () => { true, ); + jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT); await promise; expect(hmacSha512).toHaveBeenCalledTimes(10); @@ -9664,23 +9693,11 @@ describe('SnapController', () => { snapController.destroy(); }); - it('queues multiple state updates', async () => { + it('debounces multiple state updates', async () => { const messenger = getSnapControllerMessenger(); - jest.useFakeTimers(); - const encryptor = getSnapControllerEncryptor(); - const { promise, resolve } = createDeferredPromise(); - const encryptWithKey = jest - .fn< - ReturnType, - Parameters - >() - .mockImplementation(async (...args) => { - resolve(); - await sleep(1); - return await encryptor.encryptWithKey(...args); - }); + const encryptWithKey = jest.spyOn(encryptor, 'encryptWithKey'); const snapController = getSnapController( getSnapControllerOptions({ @@ -9696,7 +9713,7 @@ describe('SnapController', () => { }), ); - const firstStateChange = waitForStateChange(messenger); + const promise = waitForStateChange(messenger); await messenger.call( 'SnapController:updateSnapState', MOCK_SNAP_ID, @@ -9704,6 +9721,10 @@ describe('SnapController', () => { true, ); + expect( + await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true), + ).toStrictEqual({ foo: 'bar' }); + await messenger.call( 'SnapController:updateSnapState', MOCK_SNAP_ID, @@ -9711,26 +9732,33 @@ describe('SnapController', () => { true, ); - // We await this promise to ensure the timer is queued. - await promise; - jest.advanceTimersByTime(1); + expect( + await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true), + ).toStrictEqual({ bar: 'baz' }); - // After this point the second update should be queued. - await firstStateChange; - const secondStateChange = waitForStateChange(messenger); + expect(encryptWithKey).not.toHaveBeenCalled(); - expect(encryptWithKey).toHaveBeenCalledTimes(1); + jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT); + await promise; - // This is a bit hacky, but we can't simply advance the timer by 1ms - // because the second timer is not running yet. - jest.useRealTimers(); - await secondStateChange; + expect(encryptWithKey).toHaveBeenCalledTimes(1); - expect(encryptWithKey).toHaveBeenCalledTimes(2); + const nextStateChange = waitForStateChange(messenger); + await messenger.call( + 'SnapController:updateSnapState', + MOCK_SNAP_ID, + { qux: 'quux' }, + true, + ); expect( await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true), - ).toStrictEqual({ bar: 'baz' }); + ).toStrictEqual({ qux: 'quux' }); + + jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT); + await nextStateChange; + + expect(encryptWithKey).toHaveBeenCalledTimes(2); snapController.destroy(); }); @@ -9763,7 +9791,9 @@ describe('SnapController', () => { true, ); + jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT); await promise; + expect(error).toHaveBeenCalledWith(errorValue); snapController.destroy(); @@ -9771,6 +9801,14 @@ describe('SnapController', () => { }); describe('SnapController:clearSnapState', () => { + beforeAll(() => { + jest.useFakeTimers(); + }); + + afterAll(() => { + jest.useRealTimers(); + }); + it('clears the state of a snap', async () => { const messenger = getSnapControllerMessenger(); @@ -9859,7 +9897,9 @@ describe('SnapController', () => { // eslint-disable-next-line @typescript-eslint/await-thenable await messenger.call('SnapController:clearSnapState', MOCK_SNAP_ID, true); + jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT); await promise; + expect(error).toHaveBeenCalledWith(errorValue); snapController.destroy(); diff --git a/packages/snaps-controllers/src/snaps/SnapController.ts b/packages/snaps-controllers/src/snaps/SnapController.ts index 90cb50cc76..32c26bfd4d 100644 --- a/packages/snaps-controllers/src/snaps/SnapController.ts +++ b/packages/snaps-controllers/src/snaps/SnapController.ts @@ -134,6 +134,7 @@ import { gt } from 'semver'; import { ALLOWED_PERMISSIONS, LEGACY_ENCRYPTION_KEY_DERIVATION_OPTIONS, + STATE_DEBOUNCE_TIMEOUT, } from './constants'; import type { SnapLocation } from './location'; import { detectSnapLocation } from './location'; @@ -167,6 +168,7 @@ import type { KeyDerivationOptions, } from '../types'; import { + debouncePersistState, fetchSnap, hasTimedOut, permissionsDiff, @@ -1798,6 +1800,23 @@ export class SnapController extends BaseController< return truncateSnap(this.getExpect(snapId)); } + /** + * Check if a given Snap has a cached encryption key stored in the runtime. + * + * @param snapId - The Snap ID. + * @param runtime - The Snap runtime data. + * @returns True if the Snap has a cached encryption key, otherwise false. + */ + #hasCachedEncryptionKey( + snapId: SnapId, + runtime = this.#getRuntimeExpect(snapId), + ): runtime is SnapRuntimeData & { + encryptionKey: string; + encryptionSalt: string; + } { + return runtime.encryptionKey !== null && runtime.encryptionSalt !== null; + } + /** * Generate an encryption key to be used for state encryption for a given Snap. * @@ -1821,7 +1840,7 @@ export class SnapController extends BaseController< }): Promise<{ key: unknown; salt: string }> { const runtime = this.#getRuntimeExpect(snapId); - if (runtime.encryptionKey && runtime.encryptionSalt && useCache) { + if (this.#hasCachedEncryptionKey(snapId, runtime) && useCache) { return { key: await this.#encryptor.importKey(runtime.encryptionKey), salt: runtime.encryptionSalt, @@ -1853,17 +1872,6 @@ export class SnapController extends BaseController< return { key: encryptionKey, salt }; } - /** - * Check if a given Snap has a cached encryption key stored in the runtime. - * - * @param snapId - The Snap ID. - * @returns True if the Snap has a cached encryption key, otherwise false. - */ - #hasCachedEncryptionKey(snapId: SnapId) { - const runtime = this.#getRuntimeExpect(snapId); - return runtime.encryptionKey !== null && runtime.encryptionSalt !== null; - } - /** * Decrypt the encrypted state for a given Snap. * @@ -1958,38 +1966,45 @@ export class SnapController extends BaseController< /** * Persist the state of a Snap. * - * This is run with a mutex to ensure that only one state update per Snap is - * processed at a time, avoiding possible race conditions. + * This function is debounced per Snap, meaning that multiple calls to this + * function for the same Snap will only result in one state update. It also + * uses a mutex to ensure that only one state update per Snap is processed at + * a time, avoiding possible race conditions. * * @param snapId - The Snap ID. * @param newSnapState - The new state of the Snap. * @param encrypted - A flag to indicate whether to use encrypted storage or * not. */ - async #persistSnapState( - snapId: SnapId, - newSnapState: Record | null, - encrypted: boolean, - ) { - const runtime = this.#getRuntimeExpect(snapId); - await runtime.stateMutex.runExclusive(async () => { - const newState = await this.#getStateToPersist( - snapId, - newSnapState, - encrypted, - ); + readonly #persistSnapState = debouncePersistState( + ( + snapId: SnapId, + newSnapState: Record | null, + encrypted: boolean, + ) => { + const runtime = this.#getRuntimeExpect(snapId); + runtime.stateMutex + .runExclusive(async () => { + const newState = await this.#getStateToPersist( + snapId, + newSnapState, + encrypted, + ); - if (encrypted) { - return this.update((state) => { - state.snapStates[snapId] = newState; - }); - } + if (encrypted) { + return this.update((state) => { + state.snapStates[snapId] = newState; + }); + } - return this.update((state) => { - state.unencryptedSnapStates[snapId] = newState; - }); - }); - } + return this.update((state) => { + state.unencryptedSnapStates[snapId] = newState; + }); + }) + .catch(logError); + }, + STATE_DEBOUNCE_TIMEOUT, + ); /** * Updates the own state of the snap with the given id. @@ -2012,11 +2027,7 @@ export class SnapController extends BaseController< runtime.unencryptedState = newSnapState; } - // This is intentionally run asynchronously to avoid blocking the main - // thread. - this.#persistSnapState(snapId, newSnapState, encrypted).catch((error) => { - logError(error); - }); + this.#persistSnapState(snapId, newSnapState, encrypted); } /** @@ -2034,11 +2045,7 @@ export class SnapController extends BaseController< runtime.unencryptedState = null; } - // This is intentionally run asynchronously to avoid blocking the main - // thread. - this.#persistSnapState(snapId, null, encrypted).catch((error) => { - logError(error); - }); + this.#persistSnapState(snapId, null, encrypted); } /** diff --git a/packages/snaps-controllers/src/snaps/constants.ts b/packages/snaps-controllers/src/snaps/constants.ts index 2125d9976d..dfaaa88ad3 100644 --- a/packages/snaps-controllers/src/snaps/constants.ts +++ b/packages/snaps-controllers/src/snaps/constants.ts @@ -20,3 +20,8 @@ export const LEGACY_ENCRYPTION_KEY_DERIVATION_OPTIONS = { iterations: 10_000, }, }; + +/** + * The timeout for debouncing state updates. + */ +export const STATE_DEBOUNCE_TIMEOUT = 500; diff --git a/packages/snaps-controllers/src/utils.test.ts b/packages/snaps-controllers/src/utils.test.ts index 26e1a802cf..7899c2a773 100644 --- a/packages/snaps-controllers/src/utils.test.ts +++ b/packages/snaps-controllers/src/utils.test.ts @@ -2,6 +2,8 @@ import { VirtualFile } from '@metamask/snaps-utils'; import { getMockSnapFiles, getSnapManifest, + MOCK_LOCAL_SNAP_ID, + MOCK_SNAP_ID, } from '@metamask/snaps-utils/test-utils'; import { assert } from '@metamask/utils'; @@ -13,7 +15,12 @@ import { MOCK_RPC_ORIGINS_PERMISSION, MOCK_SNAP_DIALOG_PERMISSION, } from './test-utils'; -import { getSnapFiles, permissionsDiff, setDiff } from './utils'; +import { + debouncePersistState, + getSnapFiles, + permissionsDiff, + setDiff, +} from './utils'; import { SnapEndowments } from '../../snaps-rpc-methods/src/endowments'; describe('setDiff', () => { @@ -180,3 +187,37 @@ describe('getSnapFiles', () => { ]); }); }); + +describe('debouncePersistState', () => { + beforeAll(() => { + jest.useFakeTimers(); + }); + + afterAll(() => { + jest.useRealTimers(); + }); + + it('debounces persisting the state based on the Snap ID and whether it should be encrypted or not', () => { + const fn = jest.fn(); + const debounced = debouncePersistState(fn, 100); + + expect(debounced(MOCK_SNAP_ID, {}, true)).toBeUndefined(); + expect(debounced(MOCK_SNAP_ID, {}, true)).toBeUndefined(); + expect(debounced(MOCK_SNAP_ID, {}, false)).toBeUndefined(); + expect(debounced(MOCK_SNAP_ID, {}, false)).toBeUndefined(); + + expect(debounced(MOCK_LOCAL_SNAP_ID, {}, true)).toBeUndefined(); + expect(debounced(MOCK_LOCAL_SNAP_ID, {}, true)).toBeUndefined(); + expect(debounced(MOCK_LOCAL_SNAP_ID, {}, false)).toBeUndefined(); + expect(debounced(MOCK_LOCAL_SNAP_ID, {}, false)).toBeUndefined(); + + expect(fn).toHaveBeenCalledTimes(0); + + jest.advanceTimersByTime(100); + expect(fn).toHaveBeenCalledTimes(4); + expect(fn).toHaveBeenNthCalledWith(1, MOCK_SNAP_ID, {}, true); + expect(fn).toHaveBeenNthCalledWith(2, MOCK_SNAP_ID, {}, false); + expect(fn).toHaveBeenNthCalledWith(3, MOCK_LOCAL_SNAP_ID, {}, true); + expect(fn).toHaveBeenNthCalledWith(4, MOCK_LOCAL_SNAP_ID, {}, false); + }); +}); diff --git a/packages/snaps-controllers/src/utils.ts b/packages/snaps-controllers/src/utils.ts index 0398964e38..2292c0462f 100644 --- a/packages/snaps-controllers/src/utils.ts +++ b/packages/snaps-controllers/src/utils.ts @@ -8,6 +8,7 @@ import { validateAuxiliaryFiles, validateFetchedSnap, } from '@metamask/snaps-utils'; +import type { Json } from '@metamask/utils'; import deepEqual from 'fast-deep-equal'; import type { SnapLocation } from './snaps'; @@ -329,3 +330,48 @@ export async function fetchSnap(snapId: SnapId, location: SnapLocation) { ); } } + +/** + * Debounce persisting Snap state changes. + * + * @param fn - The function to debounce. + * @param timeout - The timeout in milliseconds. Defaults to 1000. + * @returns The debounced function. + * @example + * const originalUpdate = (snapId, newSnapState, encrypted) => { + * console.log(`Called with Snap ID: ${snapId} and state: ${newSnapState}`); + * }; + * + * const debouncedUpdate = debounce(originalUpdate); + * debouncedFunction('npm:foo-snap', { foo: 'bar' }, false); + */ +export function debouncePersistState( + fn: ( + snapId: SnapId, + newSnapState: Record | null, + encrypted: boolean, + ) => void, + timeout = 1000, +) { + const timeouts = new Map(); + + return ( + snapId: SnapId, + newSnapState: Record | null, + encrypted: boolean, + ): void => { + const key = `${snapId}-${encrypted}`; + + if (timeouts.has(key)) { + clearTimeout(timeouts.get(key)); + } + + timeouts.set( + key, + setTimeout(() => { + fn(snapId, newSnapState, encrypted); + timeouts.delete(key); + }, timeout), + ); + }; +}