Skip to content

Commit 9e4e4d8

Browse files
committed
Cache snap state in memory
1 parent 8b2bada commit 9e4e4d8

File tree

7 files changed

+186
-18
lines changed

7 files changed

+186
-18
lines changed

packages/snaps-controllers/src/snaps/SnapController.test.tsx

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ import {
9797
MOCK_WALLET_SNAP_PERMISSION,
9898
MockSnapsRegistry,
9999
sleep,
100+
waitForStateChange,
100101
} from '../test-utils';
101102
import { delay } from '../utils';
102103
import { LEGACY_ENCRYPTION_KEY_DERIVATION_OPTIONS } from './constants';
@@ -8801,6 +8802,7 @@ describe('SnapController', () => {
88018802
);
88028803

88038804
const newState = { myVariable: 2 };
8805+
const promise = waitForStateChange(messenger);
88048806

88058807
await messenger.call(
88068808
'SnapController:updateSnapState',
@@ -8817,6 +8819,8 @@ describe('SnapController', () => {
88178819
DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS,
88188820
);
88198821

8822+
await promise;
8823+
88208824
const result = await messenger.call(
88218825
'SnapController:getSnapState',
88228826
MOCK_SNAP_ID,
@@ -8831,7 +8835,7 @@ describe('SnapController', () => {
88318835
snapController.destroy();
88328836
});
88338837

8834-
it('different snaps use different encryption keys', async () => {
8838+
it('uses different encryption keys for different snaps', async () => {
88358839
const messenger = getSnapControllerMessenger();
88368840

88378841
const state = { foo: 'bar' };
@@ -8857,13 +8861,17 @@ describe('SnapController', () => {
88578861
true,
88588862
);
88598863

8864+
const promise = waitForStateChange(messenger);
8865+
88608866
await messenger.call(
88618867
'SnapController:updateSnapState',
88628868
MOCK_LOCAL_SNAP_ID,
88638869
state,
88648870
true,
88658871
);
88668872

8873+
await promise;
8874+
88678875
const encryptedState1 = await encrypt(
88688876
ENCRYPTION_KEY,
88698877
state,
@@ -9073,13 +9081,17 @@ describe('SnapController', () => {
90739081
undefined,
90749082
DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS,
90759083
);
9084+
9085+
const promise = waitForStateChange(messenger);
90769086
await messenger.call(
90779087
'SnapController:updateSnapState',
90789088
MOCK_SNAP_ID,
90799089
state,
90809090
true,
90819091
);
90829092

9093+
await promise;
9094+
90839095
expect(updateSnapStateSpy).toHaveBeenCalledTimes(1);
90849096
expect(snapController.state.snapStates[MOCK_SNAP_ID]).toStrictEqual(
90859097
mockEncryptedState,
@@ -9137,13 +9149,17 @@ describe('SnapController', () => {
91379149
);
91389150

91399151
const state = { foo: 'bar' };
9152+
9153+
const promise = waitForStateChange(messenger);
91409154
await messenger.call(
91419155
'SnapController:updateSnapState',
91429156
MOCK_SNAP_ID,
91439157
state,
91449158
true,
91459159
);
91469160

9161+
await promise;
9162+
91479163
expect(pbkdf2Sha512).toHaveBeenCalledTimes(1);
91489164

91499165
snapController.destroy();

packages/snaps-controllers/src/snaps/SnapController.ts

Lines changed: 112 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ import type {
6565
TruncatedSnapFields,
6666
} from '@metamask/snaps-utils';
6767
import {
68+
withMutex,
6869
logWarning,
6970
getPlatformVersion,
7071
assertIsSnapManifest,
@@ -252,6 +253,16 @@ export interface SnapRuntimeData {
252253
* A boolean flag to determine whether the Snap is currently being stopped.
253254
*/
254255
stopping: boolean;
256+
257+
/**
258+
* Cached encrypted state of the Snap.
259+
*/
260+
state: Record<string, Json> | null;
261+
262+
/**
263+
* Cached unencrypted state of the Snap.
264+
*/
265+
unencryptedState: Record<string, Json> | null;
255266
}
256267

257268
export type SnapError = {
@@ -906,6 +917,7 @@ export class SnapController extends BaseController<
906917
this._onOutboundResponse = this._onOutboundResponse.bind(this);
907918
this.#rollbackSnapshots = new Map();
908919
this.#snapsRuntimeData = new Map();
920+
909921
this.#pollForLastRequestStatus();
910922

911923
/* eslint-disable @typescript-eslint/unbound-method */
@@ -1860,6 +1872,68 @@ export class SnapController extends BaseController<
18601872
return JSON.stringify(encryptedState);
18611873
}
18621874

1875+
/**
1876+
* Get the new Snap state to persist based on the given state and encryption
1877+
* flag.
1878+
*
1879+
* - If the state is null, return null.
1880+
* - If the state should not be encrypted, return the JSON stringified state.
1881+
* - Otherwise if the state should be encrypted, return the encrypted state.
1882+
*
1883+
* @param snapId - The Snap ID.
1884+
* @param state - The state to persist.
1885+
* @param encrypted - A flag to indicate whether to use encrypted storage or
1886+
* not.
1887+
* @returns The state to persist.
1888+
*/
1889+
async #getStateToPersist(
1890+
snapId: SnapId,
1891+
state: Record<string, Json> | null,
1892+
encrypted: boolean,
1893+
) {
1894+
if (state === null) {
1895+
return null;
1896+
}
1897+
1898+
if (encrypted) {
1899+
return await this.#encryptSnapState(snapId, state);
1900+
}
1901+
1902+
return JSON.stringify(state);
1903+
}
1904+
1905+
/**
1906+
* Persist the state of a Snap.
1907+
*
1908+
* @param snapId - The Snap ID.
1909+
* @param newSnapState - The new state of the Snap.
1910+
* @param encrypted - A flag to indicate whether to use encrypted storage or
1911+
* not.
1912+
*/
1913+
#persistSnapState = withMutex(
1914+
async (
1915+
snapId: SnapId,
1916+
newSnapState: Record<string, Json> | null,
1917+
encrypted: boolean,
1918+
) => {
1919+
const newState = await this.#getStateToPersist(
1920+
snapId,
1921+
newSnapState,
1922+
encrypted,
1923+
);
1924+
1925+
if (encrypted) {
1926+
return this.update((state) => {
1927+
state.snapStates[snapId] = newState;
1928+
});
1929+
}
1930+
1931+
return this.update((state) => {
1932+
state.unencryptedSnapStates[snapId] = newState;
1933+
});
1934+
},
1935+
);
1936+
18631937
/**
18641938
* Updates the own state of the snap with the given id.
18651939
* This is distinct from the state MetaMask uses to manage snaps.
@@ -1873,17 +1947,19 @@ export class SnapController extends BaseController<
18731947
newSnapState: Record<string, Json>,
18741948
encrypted: boolean,
18751949
) {
1876-
if (encrypted) {
1877-
const encryptedState = await this.#encryptSnapState(snapId, newSnapState);
1950+
const runtime = this.#getRuntimeExpect(snapId);
18781951

1879-
this.update((state) => {
1880-
state.snapStates[snapId] = encryptedState;
1881-
});
1952+
if (encrypted) {
1953+
runtime.state = newSnapState;
18821954
} else {
1883-
this.update((state) => {
1884-
state.unencryptedSnapStates[snapId] = JSON.stringify(newSnapState);
1885-
});
1955+
runtime.unencryptedState = newSnapState;
18861956
}
1957+
1958+
// This is intentionally run asynchronously to avoid blocking the main
1959+
// thread.
1960+
this.#persistSnapState(snapId, newSnapState, encrypted).catch((error) => {
1961+
logError(error);
1962+
});
18871963
}
18881964

18891965
/**
@@ -1894,12 +1970,17 @@ export class SnapController extends BaseController<
18941970
* @param encrypted - A flag to indicate whether to use encrypted storage or not.
18951971
*/
18961972
clearSnapState(snapId: SnapId, encrypted: boolean) {
1897-
this.update((state) => {
1898-
if (encrypted) {
1899-
state.snapStates[snapId] = null;
1900-
} else {
1901-
state.unencryptedSnapStates[snapId] = null;
1902-
}
1973+
const runtime = this.#getRuntimeExpect(snapId);
1974+
if (encrypted) {
1975+
runtime.state = null;
1976+
} else {
1977+
runtime.unencryptedState = null;
1978+
}
1979+
1980+
// This is intentionally run asynchronously to avoid blocking the main
1981+
// thread.
1982+
this.#persistSnapState(snapId, null, encrypted).catch((error) => {
1983+
logError(error);
19031984
});
19041985
}
19051986

@@ -1912,6 +1993,13 @@ export class SnapController extends BaseController<
19121993
* @returns The requested snap state or null if no state exists.
19131994
*/
19141995
async getSnapState(snapId: SnapId, encrypted: boolean): Promise<Json> {
1996+
const runtime = this.#getRuntimeExpect(snapId);
1997+
const cachedState = encrypted ? runtime.state : runtime.unencryptedState;
1998+
1999+
if (cachedState !== undefined) {
2000+
return cachedState;
2001+
}
2002+
19152003
const state = encrypted
19162004
? this.state.snapStates[snapId]
19172005
: this.state.unencryptedSnapStates[snapId];
@@ -1921,11 +2009,17 @@ export class SnapController extends BaseController<
19212009
}
19222010

19232011
if (!encrypted) {
1924-
// For performance reasons, we do not validate that the state is JSON, since we control serialization.
1925-
return JSON.parse(state);
2012+
// For performance reasons, we do not validate that the state is JSON,
2013+
// since we control serialization.
2014+
const json = JSON.parse(state);
2015+
runtime.unencryptedState = json;
2016+
2017+
return json;
19262018
}
19272019

19282020
const decrypted = await this.#decryptSnapState(snapId, state);
2021+
runtime.state = decrypted;
2022+
19292023
return decrypted;
19302024
}
19312025

@@ -3706,6 +3800,8 @@ export class SnapController extends BaseController<
37063800
pendingOutboundRequests: 0,
37073801
interpreter,
37083802
stopping: false,
3803+
state: undefined,
3804+
unencryptedState: undefined,
37093805
});
37103806
}
37113807

packages/snaps-controllers/src/test-utils/controller.ts

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import type { ApprovalRequest } from '@metamask/approval-controller';
2+
import type {
3+
ControllerMessenger,
4+
RestrictedControllerMessenger,
5+
} from '@metamask/base-controller';
26
import {
37
encryptWithKey,
48
decryptWithKey,
@@ -48,16 +52,17 @@ import type {
4852
SnapInterfaceControllerEvents,
4953
StoredInterface,
5054
} from '../interface/SnapInterfaceController';
55+
import { SnapController } from '../snaps';
5156
import type {
5257
AllowedActions,
5358
AllowedEvents,
5459
PersistedSnapControllerState,
5560
SnapControllerActions,
5661
SnapControllerEvents,
62+
SnapControllerStateChangeEvent,
5763
SnapsRegistryActions,
5864
SnapsRegistryEvents,
5965
} from '../snaps';
60-
import { SnapController } from '../snaps';
6166
import type { KeyDerivationOptions } from '../types';
6267
import { MOCK_CRONJOB_PERMISSION } from './cronjob';
6368
import { getNodeEES, getNodeEESMessenger } from './execution-environment';
@@ -830,3 +835,27 @@ export const getRestrictedSnapInsightsControllerMessenger = (
830835

831836
return controllerMessenger;
832837
};
838+
839+
/**
840+
* Wait for the state change event to be emitted by the messenger.
841+
*
842+
* @param messenger - The messenger to listen to.
843+
* @returns A promise that resolves when the state change event is emitted.
844+
*/
845+
export async function waitForStateChange(
846+
messenger:
847+
| ControllerMessenger<any, SnapControllerStateChangeEvent>
848+
| RestrictedControllerMessenger<
849+
'SnapController',
850+
any,
851+
SnapControllerStateChangeEvent,
852+
any,
853+
'SnapController:stateChange'
854+
>,
855+
) {
856+
return new Promise<void>((resolve) => {
857+
messenger.subscribe('SnapController:stateChange', () => {
858+
resolve();
859+
});
860+
});
861+
}

packages/snaps-utils/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
"@metamask/utils": "^10.0.0",
9191
"@noble/hashes": "^1.3.1",
9292
"@scure/base": "^1.1.1",
93+
"async-mutex": "^0.4.0",
9394
"chalk": "^4.1.2",
9495
"cron-parser": "^4.5.0",
9596
"fast-deep-equal": "^3.1.3",

packages/snaps-utils/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ export * from './json-rpc';
2020
export * from './localization';
2121
export * from './logging';
2222
export * from './manifest';
23+
export * from './mutex';
2324
export * from './namespace';
2425
export * from './path';
2526
export * from './platform-version';

packages/snaps-utils/src/mutex.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import { Mutex } from 'async-mutex';
2+
3+
/**
4+
* Run a function with an async mutex, ensuring that only one instance of the
5+
* function can run at a time.
6+
*
7+
* @param fn - The function to run with a mutex.
8+
* @returns The wrapped function.
9+
* @template OriginalFunction - The original function type. This is inferred
10+
* from the `fn` argument, and used to determine the return type of the
11+
* wrapped function.
12+
*/
13+
export function withMutex<
14+
OriginalFunction extends (...args: any[]) => Promise<Type>,
15+
Type,
16+
>(
17+
fn: OriginalFunction,
18+
): (...args: Parameters<OriginalFunction>) => Promise<Type> {
19+
const mutex = new Mutex();
20+
21+
return async (...args: Parameters<OriginalFunction>) => {
22+
return await mutex.runExclusive(async () => await fn(...args));
23+
};
24+
}

yarn.lock

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6280,6 +6280,7 @@ __metadata:
62806280
"@wdio/spec-reporter": "npm:^8.19.0"
62816281
"@wdio/static-server-service": "npm:^8.19.0"
62826282
"@wdio/types": "npm:^8.19.0"
6283+
async-mutex: "npm:^0.4.0"
62836284
chalk: "npm:^4.1.2"
62846285
cron-parser: "npm:^4.5.0"
62856286
deepmerge: "npm:^4.2.2"

0 commit comments

Comments
 (0)