Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions packages/snaps-controllers/coverage.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"branches": 93.34,
"functions": 97.37,
"lines": 98.33,
"statements": 98.06
"branches": 93.36,
"functions": 97.38,
"lines": 98.34,
"statements": 98.07
}
98 changes: 69 additions & 29 deletions packages/snaps-controllers/src/snaps/SnapController.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ import { pipeline } from 'readable-stream';
import type { Duplex } from 'readable-stream';
import { inc } from 'semver';

import { LEGACY_ENCRYPTION_KEY_DERIVATION_OPTIONS } from './constants';
import {
LEGACY_ENCRYPTION_KEY_DERIVATION_OPTIONS,
STATE_DEBOUNCE_TIMEOUT,
} from './constants';
import { SnapsRegistryStatus } from './registry';
import type { SnapControllerState } from './SnapController';
import {
Expand Down Expand Up @@ -9228,6 +9231,14 @@ describe('SnapController', () => {
});

describe('SnapController:getSnapState', () => {
beforeAll(() => {
jest.useFakeTimers();
});

afterAll(() => {
jest.useRealTimers();
});

it(`gets the snap's state`, async () => {
const messenger = getSnapControllerMessenger();

Expand Down Expand Up @@ -9316,6 +9327,7 @@ describe('SnapController', () => {
DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS,
);

jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT);
await promise;

const result = await messenger.call(
Expand Down Expand Up @@ -9367,6 +9379,7 @@ describe('SnapController', () => {
true,
);

jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT);
await promise;

const encryptedState1 = await encrypt(
Expand Down Expand Up @@ -9557,6 +9570,14 @@ describe('SnapController', () => {
});

describe('SnapController:updateSnapState', () => {
beforeAll(() => {
jest.useFakeTimers();
});

afterAll(() => {
jest.useRealTimers();
});

it(`updates the snap's state`, async () => {
const messenger = getSnapControllerMessenger();

Expand Down Expand Up @@ -9587,6 +9608,7 @@ describe('SnapController', () => {
true,
);

jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT);
await promise;

expect(updateSnapStateSpy).toHaveBeenCalledTimes(1);
Expand All @@ -9611,6 +9633,8 @@ describe('SnapController', () => {

const updateSnapStateSpy = jest.spyOn(snapController, 'updateSnapState');
const state = { foo: 'bar' };

const promise = waitForStateChange(messenger);
await messenger.call(
'SnapController:updateSnapState',
MOCK_SNAP_ID,
Expand All @@ -9619,6 +9643,10 @@ describe('SnapController', () => {
);

expect(updateSnapStateSpy).toHaveBeenCalledTimes(1);

jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT);
await promise;

expect(
snapController.state.unencryptedSnapStates[MOCK_SNAP_ID],
).toStrictEqual(JSON.stringify(state));
Expand Down Expand Up @@ -9657,30 +9685,19 @@ describe('SnapController', () => {
true,
);

jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT);
await promise;

expect(hmacSha512).toHaveBeenCalledTimes(10);

snapController.destroy();
});

it('queues multiple state updates', async () => {
it('debounces multiple state updates', async () => {
const messenger = getSnapControllerMessenger();

jest.useFakeTimers();

const encryptor = getSnapControllerEncryptor();
const { promise, resolve } = createDeferredPromise();
const encryptWithKey = jest
.fn<
ReturnType<typeof encryptor.encryptWithKey>,
Parameters<typeof encryptor.encryptWithKey>
>()
.mockImplementation(async (...args) => {
resolve();
await sleep(1);
return await encryptor.encryptWithKey(...args);
});
const encryptWithKey = jest.spyOn(encryptor, 'encryptWithKey');

const snapController = getSnapController(
getSnapControllerOptions({
Expand All @@ -9696,41 +9713,52 @@ describe('SnapController', () => {
}),
);

const firstStateChange = waitForStateChange(messenger);
const promise = waitForStateChange(messenger);
await messenger.call(
'SnapController:updateSnapState',
MOCK_SNAP_ID,
{ foo: 'bar' },
true,
);

expect(
await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true),
).toStrictEqual({ foo: 'bar' });

await messenger.call(
'SnapController:updateSnapState',
MOCK_SNAP_ID,
{ bar: 'baz' },
true,
);

// We await this promise to ensure the timer is queued.
await promise;
jest.advanceTimersByTime(1);
expect(
await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true),
).toStrictEqual({ bar: 'baz' });

// After this point the second update should be queued.
await firstStateChange;
const secondStateChange = waitForStateChange(messenger);
expect(encryptWithKey).not.toHaveBeenCalled();

expect(encryptWithKey).toHaveBeenCalledTimes(1);
jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT);
await promise;

// This is a bit hacky, but we can't simply advance the timer by 1ms
// because the second timer is not running yet.
jest.useRealTimers();
await secondStateChange;
expect(encryptWithKey).toHaveBeenCalledTimes(1);

expect(encryptWithKey).toHaveBeenCalledTimes(2);
const nextStateChange = waitForStateChange(messenger);
await messenger.call(
'SnapController:updateSnapState',
MOCK_SNAP_ID,
{ qux: 'quux' },
true,
);

expect(
await messenger.call('SnapController:getSnapState', MOCK_SNAP_ID, true),
).toStrictEqual({ bar: 'baz' });
).toStrictEqual({ qux: 'quux' });

jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT);
await nextStateChange;

expect(encryptWithKey).toHaveBeenCalledTimes(2);

snapController.destroy();
});
Expand Down Expand Up @@ -9763,14 +9791,24 @@ describe('SnapController', () => {
true,
);

jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT);
await promise;

expect(error).toHaveBeenCalledWith(errorValue);

snapController.destroy();
});
});

describe('SnapController:clearSnapState', () => {
beforeAll(() => {
jest.useFakeTimers();
});

afterAll(() => {
jest.useRealTimers();
});

it('clears the state of a snap', async () => {
const messenger = getSnapControllerMessenger();

Expand Down Expand Up @@ -9859,7 +9897,9 @@ describe('SnapController', () => {
// eslint-disable-next-line @typescript-eslint/await-thenable
await messenger.call('SnapController:clearSnapState', MOCK_SNAP_ID, true);

jest.advanceTimersByTime(STATE_DEBOUNCE_TIMEOUT);
await promise;

expect(error).toHaveBeenCalledWith(errorValue);

snapController.destroy();
Expand Down
99 changes: 53 additions & 46 deletions packages/snaps-controllers/src/snaps/SnapController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ import { gt } from 'semver';
import {
ALLOWED_PERMISSIONS,
LEGACY_ENCRYPTION_KEY_DERIVATION_OPTIONS,
STATE_DEBOUNCE_TIMEOUT,
} from './constants';
import type { SnapLocation } from './location';
import { detectSnapLocation } from './location';
Expand Down Expand Up @@ -167,6 +168,7 @@ import type {
KeyDerivationOptions,
} from '../types';
import {
debouncePersistState,
fetchSnap,
hasTimedOut,
permissionsDiff,
Expand Down Expand Up @@ -1798,6 +1800,23 @@ export class SnapController extends BaseController<
return truncateSnap(this.getExpect(snapId));
}

/**
* Check if a given Snap has a cached encryption key stored in the runtime.
*
* @param snapId - The Snap ID.
* @param runtime - The Snap runtime data.
* @returns True if the Snap has a cached encryption key, otherwise false.
*/
#hasCachedEncryptionKey(
snapId: SnapId,
runtime = this.#getRuntimeExpect(snapId),
): runtime is SnapRuntimeData & {
encryptionKey: string;
encryptionSalt: string;
} {
return runtime.encryptionKey !== null && runtime.encryptionSalt !== null;
}

/**
* Generate an encryption key to be used for state encryption for a given Snap.
*
Expand All @@ -1821,7 +1840,7 @@ export class SnapController extends BaseController<
}): Promise<{ key: unknown; salt: string }> {
const runtime = this.#getRuntimeExpect(snapId);

if (runtime.encryptionKey && runtime.encryptionSalt && useCache) {
if (this.#hasCachedEncryptionKey(snapId, runtime) && useCache) {
return {
key: await this.#encryptor.importKey(runtime.encryptionKey),
salt: runtime.encryptionSalt,
Expand Down Expand Up @@ -1853,17 +1872,6 @@ export class SnapController extends BaseController<
return { key: encryptionKey, salt };
}

/**
* Check if a given Snap has a cached encryption key stored in the runtime.
*
* @param snapId - The Snap ID.
* @returns True if the Snap has a cached encryption key, otherwise false.
*/
#hasCachedEncryptionKey(snapId: SnapId) {
const runtime = this.#getRuntimeExpect(snapId);
return runtime.encryptionKey !== null && runtime.encryptionSalt !== null;
}

/**
* Decrypt the encrypted state for a given Snap.
*
Expand Down Expand Up @@ -1958,38 +1966,45 @@ export class SnapController extends BaseController<
/**
* Persist the state of a Snap.
*
* This is run with a mutex to ensure that only one state update per Snap is
* processed at a time, avoiding possible race conditions.
* This function is debounced per Snap, meaning that multiple calls to this
* function for the same Snap will only result in one state update. It also
* uses a mutex to ensure that only one state update per Snap is processed at
* a time, avoiding possible race conditions.
*
* @param snapId - The Snap ID.
* @param newSnapState - The new state of the Snap.
* @param encrypted - A flag to indicate whether to use encrypted storage or
* not.
*/
async #persistSnapState(
snapId: SnapId,
newSnapState: Record<string, Json> | null,
encrypted: boolean,
) {
const runtime = this.#getRuntimeExpect(snapId);
await runtime.stateMutex.runExclusive(async () => {
const newState = await this.#getStateToPersist(
snapId,
newSnapState,
encrypted,
);
readonly #persistSnapState = debouncePersistState(
(
snapId: SnapId,
newSnapState: Record<string, Json> | null,
encrypted: boolean,
) => {
const runtime = this.#getRuntimeExpect(snapId);
runtime.stateMutex
.runExclusive(async () => {
const newState = await this.#getStateToPersist(
snapId,
newSnapState,
encrypted,
);

if (encrypted) {
return this.update((state) => {
state.snapStates[snapId] = newState;
});
}
if (encrypted) {
return this.update((state) => {
state.snapStates[snapId] = newState;
});
}

return this.update((state) => {
state.unencryptedSnapStates[snapId] = newState;
});
});
}
return this.update((state) => {
state.unencryptedSnapStates[snapId] = newState;
});
})
.catch(logError);
},
STATE_DEBOUNCE_TIMEOUT,
);

/**
* Updates the own state of the snap with the given id.
Expand All @@ -2012,11 +2027,7 @@ export class SnapController extends BaseController<
runtime.unencryptedState = newSnapState;
}

// This is intentionally run asynchronously to avoid blocking the main
// thread.
this.#persistSnapState(snapId, newSnapState, encrypted).catch((error) => {
logError(error);
});
this.#persistSnapState(snapId, newSnapState, encrypted);
}

/**
Expand All @@ -2034,11 +2045,7 @@ export class SnapController extends BaseController<
runtime.unencryptedState = null;
}

// This is intentionally run asynchronously to avoid blocking the main
// thread.
this.#persistSnapState(snapId, null, encrypted).catch((error) => {
logError(error);
});
this.#persistSnapState(snapId, null, encrypted);
}

/**
Expand Down
Loading
Loading