Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions packages/snaps-rpc-methods/jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ module.exports = deepmerge(baseConfig, {
],
coverageThreshold: {
global: {
branches: 95.68,
functions: 98.75,
branches: 95.7,
functions: 98.76,
lines: 98.99,
statements: 98.69,
statements: 98.7,
},
},
});
3 changes: 2 additions & 1 deletion packages/snaps-rpc-methods/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@
"@metamask/snaps-utils": "workspace:^",
"@metamask/superstruct": "^3.2.1",
"@metamask/utils": "^11.8.1",
"@noble/hashes": "^1.7.1"
"@noble/hashes": "^1.7.1",
"async-mutex": "^0.5.0"
},
"devDependencies": {
"@lavamoat/allow-scripts": "^3.4.0",
Expand Down
96 changes: 92 additions & 4 deletions packages/snaps-rpc-methods/src/permitted/setState.test.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { JsonRpcEngine } from '@metamask/json-rpc-engine';
import { errorCodes } from '@metamask/rpc-errors';
import type { SetStateResult } from '@metamask/snaps-sdk';
import type {
Json,
JsonRpcRequest,
PendingJsonRpcResponse,
import {
createDeferredPromise,
type Json,
type JsonRpcRequest,
type PendingJsonRpcResponse,
} from '@metamask/utils';

import { setStateHandler, type SetStateParameters, set } from './setState';
Expand Down Expand Up @@ -196,6 +197,93 @@ describe('snap_setState', () => {
});
});

it('uses a mutex to protect state updates', async () => {
const { implementation } = setStateHandler;

const { promise: getStateCalled, resolve: resolveGetStateCalled } =
createDeferredPromise();
const getSnapState = jest.fn().mockImplementation(() => {
resolveGetStateCalled();
return {};
});

const { promise: updateSnapStatePromise, resolve } =
createDeferredPromise();

const updateSnapState = jest.fn().mockReturnValue(updateSnapStatePromise);
const getUnlockPromise = jest.fn().mockResolvedValue(undefined);
const hasPermission = jest.fn().mockReturnValue(true);
const getSnap = jest.fn().mockReturnValue({ preinstalled: false });

const hooks = {
getSnapState,
updateSnapState,
getUnlockPromise,
hasPermission,
getSnap,
};

const engine = new JsonRpcEngine();

engine.push((request, response, next, end) => {
const result = implementation(
request as JsonRpcRequest<SetStateParameters>,
response as PendingJsonRpcResponse<SetStateResult>,
next,
end,
hooks,
);

result?.catch(end);
});

const responsePromise1 = engine.handle({
jsonrpc: '2.0',
id: 1,
method: 'snap_setState',
params: {
key: 'foo',
value: 'baz',
encrypted: false,
},
});

const responsePromise2 = engine.handle({
jsonrpc: '2.0',
id: 2,
method: 'snap_setState',
params: {
key: 'foo',
value: 'bar',
encrypted: false,
},
});

await getStateCalled;

expect(getSnapState).toHaveBeenCalledTimes(1);

resolve();

const response1 = await responsePromise1;
const response2 = await responsePromise2;

expect(getSnapState).toHaveBeenCalledTimes(2);
expect(updateSnapState).toHaveBeenNthCalledWith(2, { foo: 'bar' }, false);

expect(response1).toStrictEqual({
jsonrpc: '2.0',
id: 1,
result: null,
});

expect(response2).toStrictEqual({
jsonrpc: '2.0',
id: 2,
result: null,
});
});

it('throws if the requesting origin does not have the required permission', async () => {
const { implementation } = setStateHandler;

Expand Down
67 changes: 48 additions & 19 deletions packages/snaps-rpc-methods/src/permitted/setState.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import type { JsonRpcEngineEndCallback } from '@metamask/json-rpc-engine';
import type { PermittedHandlerExport } from '@metamask/permission-controller';
import { providerErrors, rpcErrors } from '@metamask/rpc-errors';
import type { SetStateParams, SetStateResult } from '@metamask/snaps-sdk';
import type {
SetStateParams,
SetStateResult,
SnapId,
} from '@metamask/snaps-sdk';
import type { JsonObject } from '@metamask/snaps-sdk/jsx';
import {
getJsonSizeUnsafe,
Expand All @@ -21,6 +25,7 @@ import type {
JsonRpcRequest,
} from '@metamask/utils';
import { hasProperty, isObject, assert, JsonStruct } from '@metamask/utils';
import { Mutex } from 'async-mutex';

import {
manageStateBuilder,
Expand Down Expand Up @@ -93,6 +98,21 @@ export type SetStateHooks = {
getSnap: (snapId: string) => Snap | undefined;
};

const mutexes = new Map();

/**
* Get the corresponding state modification mutex for a given Snap ID.
*
* @param snapId - The Snap ID.
* @returns A mutex for that specific Snap.
*/
function getMutex(snapId: SnapId) {
if (!mutexes.has(snapId)) {
mutexes.set(snapId, new Mutex());
}
return mutexes.get(snapId);
}

const SetStateParametersStruct = objectStruct({
key: optional(StateKeyStruct),
value: JsonStruct,
Expand Down Expand Up @@ -156,26 +176,35 @@ async function setStateImplementation(
await getUnlockPromise(true);
}

const newState = await getNewState(key, value, encrypted, getSnapState);

const snap = getSnap(
(request as JsonRpcRequest<SetStateParams> & { origin: string }).origin,
);

if (!snap?.preinstalled) {
// We know that the state is valid JSON as per previous validation.
const size = getJsonSizeUnsafe(newState, true);
if (size > STORAGE_SIZE_LIMIT) {
throw rpcErrors.invalidParams({
message: `Invalid params: The new state must not exceed ${
STORAGE_SIZE_LIMIT / 1_000_000
} MB in size.`,
});
const snapId = (
request as JsonRpcRequest<SetStateParams> & { origin: string }
).origin as SnapId;

const mutex = getMutex(snapId);

// The expectation when using `snap_setState` is for the operation to safe
// to do in parallel. The mutex ensures that and prevents a bug that was
// mostly prevalent on mobile and caused data loss.
await mutex.runExclusive(async () => {
const newState = await getNewState(key, value, encrypted, getSnapState);

const snap = getSnap(snapId);

if (!snap?.preinstalled) {
// We know that the state is valid JSON as per previous validation.
const size = getJsonSizeUnsafe(newState, true);
if (size > STORAGE_SIZE_LIMIT) {
throw rpcErrors.invalidParams({
message: `Invalid params: The new state must not exceed ${
STORAGE_SIZE_LIMIT / 1_000_000
} MB in size.`,
});
}
}
}

await updateSnapState(newState, encrypted);
response.result = null;
await updateSnapState(newState, encrypted);
response.result = null;
});
} catch (error) {
return end(error);
}
Expand Down
1 change: 1 addition & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4470,6 +4470,7 @@ __metadata:
"@swc/jest": "npm:^0.2.38"
"@ts-bridge/cli": "npm:^0.6.1"
"@types/node": "npm:18.14.2"
async-mutex: "npm:^0.5.0"
deepmerge: "npm:^4.2.2"
depcheck: "npm:^1.4.7"
eslint: "npm:^9.11.0"
Expand Down
Loading