diff --git a/packages/snaps-controllers/coverage.json b/packages/snaps-controllers/coverage.json index f9a4bdc274..bf7d34a45b 100644 --- a/packages/snaps-controllers/coverage.json +++ b/packages/snaps-controllers/coverage.json @@ -1,5 +1,5 @@ { - "branches": 93.44, + "branches": 93.48, "functions": 97.38, "lines": 98.34, "statements": 98.08 diff --git a/packages/snaps-controllers/src/snaps/SnapController.test.tsx b/packages/snaps-controllers/src/snaps/SnapController.test.tsx index 6cf674e691..2812bd541a 100644 --- a/packages/snaps-controllers/src/snaps/SnapController.test.tsx +++ b/packages/snaps-controllers/src/snaps/SnapController.test.tsx @@ -19,8 +19,12 @@ import { SnapEndowments, } from '@metamask/snaps-rpc-methods'; import type { SnapId } from '@metamask/snaps-sdk'; -import { AuxiliaryFileEncoding, text } from '@metamask/snaps-sdk'; -import { Text } from '@metamask/snaps-sdk/jsx'; +import { + AuxiliaryFileEncoding, + text, + UserInputEventType, +} from '@metamask/snaps-sdk'; +import { Text, Box, Button } from '@metamask/snaps-sdk/jsx'; import type { SnapPermissions, RpcOrigins } from '@metamask/snaps-utils'; import { getPlatformVersion, @@ -2242,7 +2246,11 @@ describe('SnapController', () => { snapId: snap.id, origin: MOCK_ORIGIN, handler: HandlerType.OnUserInput, - request: { jsonrpc: '2.0', method: 'test' }, + request: { + jsonrpc: '2.0', + method: 'test', + params: { id: MOCK_INTERFACE_ID }, + }, }), ).toBeUndefined(); @@ -2770,6 +2778,77 @@ describe('SnapController', () => { snapController.destroy(); }); + it('injects context into onUserInput', async () => { + const rootMessenger = getControllerMessenger(); + const messenger = getSnapControllerMessenger(rootMessenger); + const snapController = getSnapController( + getSnapControllerOptions({ + messenger, + state: { + snaps: getPersistedSnapsState(), + }, + }), + ); + + rootMessenger.registerActionHandler( + 'SnapInterfaceController:getInterface', + () => ({ + id: MOCK_INTERFACE_ID, + snapId: MOCK_SNAP_ID, + content: ( + + + + ), + state: {}, + context: { foo: 'bar' }, + contentType: null, + }), + ); + + await snapController.handleRequest({ + snapId: MOCK_SNAP_ID, + origin: MOCK_ORIGIN, + handler: HandlerType.OnUserInput, + request: { + jsonrpc: '2.0', + method: ' ', + params: { + id: MOCK_INTERFACE_ID, + event: { + type: UserInputEventType.ButtonClickEvent, + name: 'button', + }, + }, + }, + }); + + expect(rootMessenger.call).toHaveBeenNthCalledWith( + 4, + 'ExecutionService:handleRpcRequest', + MOCK_SNAP_ID, + { + origin: MOCK_ORIGIN, + handler: HandlerType.OnUserInput, + request: { + id: expect.any(String), + method: ' ', + jsonrpc: '2.0', + params: { + id: MOCK_INTERFACE_ID, + event: { + type: UserInputEventType.ButtonClickEvent, + name: 'button', + }, + context: { foo: 'bar' }, + }, + }, + }, + ); + + snapController.destroy(); + }); + it('throws if onTransaction handler returns a phishing link', async () => { const rootMessenger = getControllerMessenger(); const messenger = getSnapControllerMessenger(rootMessenger); diff --git a/packages/snaps-controllers/src/snaps/SnapController.ts b/packages/snaps-controllers/src/snaps/SnapController.ts index 49c7cdf5e1..c36bd8fb71 100644 --- a/packages/snaps-controllers/src/snaps/SnapController.ts +++ b/packages/snaps-controllers/src/snaps/SnapController.ts @@ -111,6 +111,7 @@ import type { NonEmptyArray, SemVerRange, CaipAssetType, + JsonRpcRequest, } from '@metamask/utils'; import { assert, @@ -213,7 +214,7 @@ export type PreinstalledSnap = { }; type SnapRpcHandler = ( - options: SnapRpcHookArgs & { timeout: number }, + options: SnapRpcHookArgs & { timeout: number; request: JsonRpcRequest }, ) => Promise; /** @@ -3513,7 +3514,7 @@ export class SnapController extends BaseController< handler: handlerType, request, timeout, - }: SnapRpcHookArgs & { timeout: number }) => { + }: SnapRpcHookArgs & { timeout: number; request: JsonRpcRequest }) => { if (!this.state.snaps[snapId].enabled) { throw new Error(`Snap "${snapId}" is disabled.`); } @@ -3547,13 +3548,19 @@ export class SnapController extends BaseController< } } + const transformedRequest = this.#transformSnapRpcRequest( + snapId, + handlerType, + request, + ); + const timer = new Timer(timeout); - this.#recordSnapRpcRequestStart(snapId, request.id, timer); + this.#recordSnapRpcRequestStart(snapId, transformedRequest.id, timer); const handleRpcRequestPromise = this.messagingSystem.call( 'ExecutionService:handleRpcRequest', snapId, - { origin, handler: handlerType, request }, + { origin, handler: handlerType, request: transformedRequest }, ); // This will either get the result or reject due to the timeout. @@ -3566,21 +3573,21 @@ export class SnapController extends BaseController< ); } - await this.#assertSnapRpcRequestResult(snapId, handlerType, result); + await this.#assertSnapRpcResponse(snapId, handlerType, result); - const transformedResult = await this.#transformSnapRpcRequestResult( + const transformedResult = await this.#transformSnapRpcResponse( snapId, handlerType, - request, + transformedRequest, result, ); - this.#recordSnapRpcRequestFinish(snapId, request.id); + this.#recordSnapRpcRequestFinish(snapId, transformedRequest.id); return transformedResult; } catch (error) { // We flag the RPC request as finished early since termination may affect pending requests - this.#recordSnapRpcRequestFinish(snapId, request.id); + this.#recordSnapRpcRequestFinish(snapId, transformedRequest.id); const [jsonRpcError, handled] = unwrapError(error); if (!handled) { @@ -3629,15 +3636,15 @@ export class SnapController extends BaseController< } /** - * Transform a RPC request result if necessary. + * Transform a RPC response if necessary. * * @param snapId - The snap ID of the snap that produced the result. * @param handlerType - The handler type that produced the result. * @param request - The request that returned the result. - * @param result - The result. + * @param result - The response. * @returns The transformed result if applicable, otherwise the original result. */ - async #transformSnapRpcRequestResult( + async #transformSnapRpcResponse( snapId: SnapId, handlerType: HandlerType, request: Record, @@ -3763,6 +3770,42 @@ export class SnapController extends BaseController< return { conversionRates: filteredConversionRates }; } + /** + * Transforms a JSON-RPC request before sending it to the Snap, if required for a given handler. + * + * @param snapId - The Snap ID. + * @param handlerType - The handler being called. + * @param request - The JSON-RPC request. + * @returns The potentially transformed JSON-RPC request. + */ + #transformSnapRpcRequest( + snapId: SnapId, + handlerType: HandlerType, + request: JsonRpcRequest, + ) { + switch (handlerType) { + // For onUserInput we inject context, so the client doesn't have to worry about keeping it in sync. + case HandlerType.OnUserInput: { + assert(request.params && hasProperty(request.params, 'id')); + + const interfaceId = request.params.id as string; + const { context } = this.messagingSystem.call( + 'SnapInterfaceController:getInterface', + snapId, + interfaceId, + ); + + return { + ...request, + params: { ...request.params, context }, + }; + } + + default: + return request; + } + } + /** * Assert that the returned result of a Snap RPC call is the expected shape. * @@ -3770,7 +3813,7 @@ export class SnapController extends BaseController< * @param handlerType - The handler type of the RPC Request. * @param result - The result of the RPC request. */ - async #assertSnapRpcRequestResult( + async #assertSnapRpcResponse( snapId: SnapId, handlerType: HandlerType, result: unknown, diff --git a/packages/snaps-controllers/src/test-utils/controller.ts b/packages/snaps-controllers/src/test-utils/controller.ts index f6c162cbb7..dfbfc27f2f 100644 --- a/packages/snaps-controllers/src/test-utils/controller.ts +++ b/packages/snaps-controllers/src/test-utils/controller.ts @@ -22,7 +22,8 @@ import { SnapEndowments, WALLET_SNAP_PERMISSION_KEY, } from '@metamask/snaps-rpc-methods'; -import type { SnapId, text } from '@metamask/snaps-sdk'; +import type { SnapId } from '@metamask/snaps-sdk'; +import { text } from '@metamask/snaps-sdk'; import { SnapCaveatType } from '@metamask/snaps-utils'; import { MockControllerMessenger, @@ -440,7 +441,12 @@ export const getControllerMessenger = (registry = new MockSnapsRegistry()) => { if (id !== MOCK_INTERFACE_ID) { throw new Error(`Interface with id '${id}' not found.`); } - return { snapId, content: text('foo bar'), state: {} } as StoredInterface; + return { + snapId, + content: text('foo bar'), + state: {}, + context: null, + } as StoredInterface; }, );