From dd60a65c90dd74cfc9d4638078f0ae779477318a Mon Sep 17 00:00:00 2001 From: Frederik Bolding Date: Tue, 18 Nov 2025 15:28:29 +0100 Subject: [PATCH 1/3] fix: Use mutex when modifying state using snap_setState --- packages/snaps-rpc-methods/package.json | 3 +- .../src/permitted/setState.ts | 64 +++++++++++++------ yarn.lock | 1 + 3 files changed, 48 insertions(+), 20 deletions(-) diff --git a/packages/snaps-rpc-methods/package.json b/packages/snaps-rpc-methods/package.json index 4eee0baae5..192e9e1552 100644 --- a/packages/snaps-rpc-methods/package.json +++ b/packages/snaps-rpc-methods/package.json @@ -62,7 +62,8 @@ "@metamask/snaps-utils": "workspace:^", "@metamask/superstruct": "^3.2.1", "@metamask/utils": "^11.8.1", - "@noble/hashes": "^1.7.1" + "@noble/hashes": "^1.7.1", + "async-mutex": "^0.5.0" }, "devDependencies": { "@lavamoat/allow-scripts": "^3.4.0", diff --git a/packages/snaps-rpc-methods/src/permitted/setState.ts b/packages/snaps-rpc-methods/src/permitted/setState.ts index b03e2c11d8..32b29ce65e 100644 --- a/packages/snaps-rpc-methods/src/permitted/setState.ts +++ b/packages/snaps-rpc-methods/src/permitted/setState.ts @@ -1,7 +1,11 @@ import type { JsonRpcEngineEndCallback } from '@metamask/json-rpc-engine'; import type { PermittedHandlerExport } from '@metamask/permission-controller'; import { providerErrors, rpcErrors } from '@metamask/rpc-errors'; -import type { SetStateParams, SetStateResult } from '@metamask/snaps-sdk'; +import type { + SetStateParams, + SetStateResult, + SnapId, +} from '@metamask/snaps-sdk'; import type { JsonObject } from '@metamask/snaps-sdk/jsx'; import { getJsonSizeUnsafe, @@ -21,6 +25,7 @@ import type { JsonRpcRequest, } from '@metamask/utils'; import { hasProperty, isObject, assert, JsonStruct } from '@metamask/utils'; +import { Mutex } from 'async-mutex'; import { manageStateBuilder, @@ -93,6 +98,21 @@ export type SetStateHooks = { getSnap: (snapId: string) => Snap | undefined; }; +const mutexes = new Map(); + +/** + * Get the corresponding state modification mutex for a given Snap ID. + * + * @param snapId - The Snap ID. + * @returns A mutex for that specific Snap. + */ +function getMutex(snapId: SnapId) { + if (!mutexes.has(snapId)) { + mutexes.set(snapId, new Mutex()); + } + return mutexes.get(snapId); +} + const SetStateParametersStruct = objectStruct({ key: optional(StateKeyStruct), value: JsonStruct, @@ -156,26 +176,32 @@ async function setStateImplementation( await getUnlockPromise(true); } - const newState = await getNewState(key, value, encrypted, getSnapState); - - const snap = getSnap( - (request as JsonRpcRequest & { origin: string }).origin, - ); - - if (!snap?.preinstalled) { - // We know that the state is valid JSON as per previous validation. - const size = getJsonSizeUnsafe(newState, true); - if (size > STORAGE_SIZE_LIMIT) { - throw rpcErrors.invalidParams({ - message: `Invalid params: The new state must not exceed ${ - STORAGE_SIZE_LIMIT / 1_000_000 - } MB in size.`, - }); + const snapId = ( + request as JsonRpcRequest & { origin: string } + ).origin as SnapId; + + const mutex = getMutex(snapId); + + await mutex.runExclusive(async () => { + const newState = await getNewState(key, value, encrypted, getSnapState); + + const snap = getSnap(snapId); + + if (!snap?.preinstalled) { + // We know that the state is valid JSON as per previous validation. + const size = getJsonSizeUnsafe(newState, true); + if (size > STORAGE_SIZE_LIMIT) { + throw rpcErrors.invalidParams({ + message: `Invalid params: The new state must not exceed ${ + STORAGE_SIZE_LIMIT / 1_000_000 + } MB in size.`, + }); + } } - } - await updateSnapState(newState, encrypted); - response.result = null; + await updateSnapState(newState, encrypted); + response.result = null; + }); } catch (error) { return end(error); } diff --git a/yarn.lock b/yarn.lock index 78eda15f36..a582e97edc 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4470,6 +4470,7 @@ __metadata: "@swc/jest": "npm:^0.2.38" "@ts-bridge/cli": "npm:^0.6.1" "@types/node": "npm:18.14.2" + async-mutex: "npm:^0.5.0" deepmerge: "npm:^4.2.2" depcheck: "npm:^1.4.7" eslint: "npm:^9.11.0" From ac4fa3fd5c92935bdbdd7a321e4e603486f9a9c4 Mon Sep 17 00:00:00 2001 From: Frederik Bolding Date: Tue, 18 Nov 2025 15:34:01 +0100 Subject: [PATCH 2/3] Add comment --- packages/snaps-rpc-methods/src/permitted/setState.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/snaps-rpc-methods/src/permitted/setState.ts b/packages/snaps-rpc-methods/src/permitted/setState.ts index 32b29ce65e..86cf37c15b 100644 --- a/packages/snaps-rpc-methods/src/permitted/setState.ts +++ b/packages/snaps-rpc-methods/src/permitted/setState.ts @@ -182,6 +182,9 @@ async function setStateImplementation( const mutex = getMutex(snapId); + // The expectation when using `snap_setState` is for the operation to safe + // to do in parallel. The mutex ensures that and prevents a bug that was + // mostly prevalent on mobile and caused data loss. await mutex.runExclusive(async () => { const newState = await getNewState(key, value, encrypted, getSnapState); From 358aad00d44d8737a3a58fb19b69553c3f1a5842 Mon Sep 17 00:00:00 2001 From: Frederik Bolding Date: Tue, 18 Nov 2025 15:55:03 +0100 Subject: [PATCH 3/3] Add test --- packages/snaps-rpc-methods/jest.config.js | 6 +- .../src/permitted/setState.test.ts | 96 ++++++++++++++++++- 2 files changed, 95 insertions(+), 7 deletions(-) diff --git a/packages/snaps-rpc-methods/jest.config.js b/packages/snaps-rpc-methods/jest.config.js index 6cadef39ec..f4af002854 100644 --- a/packages/snaps-rpc-methods/jest.config.js +++ b/packages/snaps-rpc-methods/jest.config.js @@ -10,10 +10,10 @@ module.exports = deepmerge(baseConfig, { ], coverageThreshold: { global: { - branches: 95.68, - functions: 98.75, + branches: 95.7, + functions: 98.76, lines: 98.99, - statements: 98.69, + statements: 98.7, }, }, }); diff --git a/packages/snaps-rpc-methods/src/permitted/setState.test.ts b/packages/snaps-rpc-methods/src/permitted/setState.test.ts index c993f54be1..59a4e484b6 100644 --- a/packages/snaps-rpc-methods/src/permitted/setState.test.ts +++ b/packages/snaps-rpc-methods/src/permitted/setState.test.ts @@ -1,10 +1,11 @@ import { JsonRpcEngine } from '@metamask/json-rpc-engine'; import { errorCodes } from '@metamask/rpc-errors'; import type { SetStateResult } from '@metamask/snaps-sdk'; -import type { - Json, - JsonRpcRequest, - PendingJsonRpcResponse, +import { + createDeferredPromise, + type Json, + type JsonRpcRequest, + type PendingJsonRpcResponse, } from '@metamask/utils'; import { setStateHandler, type SetStateParameters, set } from './setState'; @@ -196,6 +197,93 @@ describe('snap_setState', () => { }); }); + it('uses a mutex to protect state updates', async () => { + const { implementation } = setStateHandler; + + const { promise: getStateCalled, resolve: resolveGetStateCalled } = + createDeferredPromise(); + const getSnapState = jest.fn().mockImplementation(() => { + resolveGetStateCalled(); + return {}; + }); + + const { promise: updateSnapStatePromise, resolve } = + createDeferredPromise(); + + const updateSnapState = jest.fn().mockReturnValue(updateSnapStatePromise); + const getUnlockPromise = jest.fn().mockResolvedValue(undefined); + const hasPermission = jest.fn().mockReturnValue(true); + const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + + const hooks = { + getSnapState, + updateSnapState, + getUnlockPromise, + hasPermission, + getSnap, + }; + + const engine = new JsonRpcEngine(); + + engine.push((request, response, next, end) => { + const result = implementation( + request as JsonRpcRequest, + response as PendingJsonRpcResponse, + next, + end, + hooks, + ); + + result?.catch(end); + }); + + const responsePromise1 = engine.handle({ + jsonrpc: '2.0', + id: 1, + method: 'snap_setState', + params: { + key: 'foo', + value: 'baz', + encrypted: false, + }, + }); + + const responsePromise2 = engine.handle({ + jsonrpc: '2.0', + id: 2, + method: 'snap_setState', + params: { + key: 'foo', + value: 'bar', + encrypted: false, + }, + }); + + await getStateCalled; + + expect(getSnapState).toHaveBeenCalledTimes(1); + + resolve(); + + const response1 = await responsePromise1; + const response2 = await responsePromise2; + + expect(getSnapState).toHaveBeenCalledTimes(2); + expect(updateSnapState).toHaveBeenNthCalledWith(2, { foo: 'bar' }, false); + + expect(response1).toStrictEqual({ + jsonrpc: '2.0', + id: 1, + result: null, + }); + + expect(response2).toStrictEqual({ + jsonrpc: '2.0', + id: 2, + result: null, + }); + }); + it('throws if the requesting origin does not have the required permission', async () => { const { implementation } = setStateHandler;