diff --git a/packages/snaps-controllers/coverage.json b/packages/snaps-controllers/coverage.json
index f9a4bdc274..bf7d34a45b 100644
--- a/packages/snaps-controllers/coverage.json
+++ b/packages/snaps-controllers/coverage.json
@@ -1,5 +1,5 @@
{
- "branches": 93.44,
+ "branches": 93.48,
"functions": 97.38,
"lines": 98.34,
"statements": 98.08
diff --git a/packages/snaps-controllers/src/snaps/SnapController.test.tsx b/packages/snaps-controllers/src/snaps/SnapController.test.tsx
index 6cf674e691..2812bd541a 100644
--- a/packages/snaps-controllers/src/snaps/SnapController.test.tsx
+++ b/packages/snaps-controllers/src/snaps/SnapController.test.tsx
@@ -19,8 +19,12 @@ import {
SnapEndowments,
} from '@metamask/snaps-rpc-methods';
import type { SnapId } from '@metamask/snaps-sdk';
-import { AuxiliaryFileEncoding, text } from '@metamask/snaps-sdk';
-import { Text } from '@metamask/snaps-sdk/jsx';
+import {
+ AuxiliaryFileEncoding,
+ text,
+ UserInputEventType,
+} from '@metamask/snaps-sdk';
+import { Text, Box, Button } from '@metamask/snaps-sdk/jsx';
import type { SnapPermissions, RpcOrigins } from '@metamask/snaps-utils';
import {
getPlatformVersion,
@@ -2242,7 +2246,11 @@ describe('SnapController', () => {
snapId: snap.id,
origin: MOCK_ORIGIN,
handler: HandlerType.OnUserInput,
- request: { jsonrpc: '2.0', method: 'test' },
+ request: {
+ jsonrpc: '2.0',
+ method: 'test',
+ params: { id: MOCK_INTERFACE_ID },
+ },
}),
).toBeUndefined();
@@ -2770,6 +2778,77 @@ describe('SnapController', () => {
snapController.destroy();
});
+ it('injects context into onUserInput', async () => {
+ const rootMessenger = getControllerMessenger();
+ const messenger = getSnapControllerMessenger(rootMessenger);
+ const snapController = getSnapController(
+ getSnapControllerOptions({
+ messenger,
+ state: {
+ snaps: getPersistedSnapsState(),
+ },
+ }),
+ );
+
+ rootMessenger.registerActionHandler(
+ 'SnapInterfaceController:getInterface',
+ () => ({
+ id: MOCK_INTERFACE_ID,
+ snapId: MOCK_SNAP_ID,
+ content: (
+
+
+
+ ),
+ state: {},
+ context: { foo: 'bar' },
+ contentType: null,
+ }),
+ );
+
+ await snapController.handleRequest({
+ snapId: MOCK_SNAP_ID,
+ origin: MOCK_ORIGIN,
+ handler: HandlerType.OnUserInput,
+ request: {
+ jsonrpc: '2.0',
+ method: ' ',
+ params: {
+ id: MOCK_INTERFACE_ID,
+ event: {
+ type: UserInputEventType.ButtonClickEvent,
+ name: 'button',
+ },
+ },
+ },
+ });
+
+ expect(rootMessenger.call).toHaveBeenNthCalledWith(
+ 4,
+ 'ExecutionService:handleRpcRequest',
+ MOCK_SNAP_ID,
+ {
+ origin: MOCK_ORIGIN,
+ handler: HandlerType.OnUserInput,
+ request: {
+ id: expect.any(String),
+ method: ' ',
+ jsonrpc: '2.0',
+ params: {
+ id: MOCK_INTERFACE_ID,
+ event: {
+ type: UserInputEventType.ButtonClickEvent,
+ name: 'button',
+ },
+ context: { foo: 'bar' },
+ },
+ },
+ },
+ );
+
+ snapController.destroy();
+ });
+
it('throws if onTransaction handler returns a phishing link', async () => {
const rootMessenger = getControllerMessenger();
const messenger = getSnapControllerMessenger(rootMessenger);
diff --git a/packages/snaps-controllers/src/snaps/SnapController.ts b/packages/snaps-controllers/src/snaps/SnapController.ts
index 49c7cdf5e1..c36bd8fb71 100644
--- a/packages/snaps-controllers/src/snaps/SnapController.ts
+++ b/packages/snaps-controllers/src/snaps/SnapController.ts
@@ -111,6 +111,7 @@ import type {
NonEmptyArray,
SemVerRange,
CaipAssetType,
+ JsonRpcRequest,
} from '@metamask/utils';
import {
assert,
@@ -213,7 +214,7 @@ export type PreinstalledSnap = {
};
type SnapRpcHandler = (
- options: SnapRpcHookArgs & { timeout: number },
+ options: SnapRpcHookArgs & { timeout: number; request: JsonRpcRequest },
) => Promise;
/**
@@ -3513,7 +3514,7 @@ export class SnapController extends BaseController<
handler: handlerType,
request,
timeout,
- }: SnapRpcHookArgs & { timeout: number }) => {
+ }: SnapRpcHookArgs & { timeout: number; request: JsonRpcRequest }) => {
if (!this.state.snaps[snapId].enabled) {
throw new Error(`Snap "${snapId}" is disabled.`);
}
@@ -3547,13 +3548,19 @@ export class SnapController extends BaseController<
}
}
+ const transformedRequest = this.#transformSnapRpcRequest(
+ snapId,
+ handlerType,
+ request,
+ );
+
const timer = new Timer(timeout);
- this.#recordSnapRpcRequestStart(snapId, request.id, timer);
+ this.#recordSnapRpcRequestStart(snapId, transformedRequest.id, timer);
const handleRpcRequestPromise = this.messagingSystem.call(
'ExecutionService:handleRpcRequest',
snapId,
- { origin, handler: handlerType, request },
+ { origin, handler: handlerType, request: transformedRequest },
);
// This will either get the result or reject due to the timeout.
@@ -3566,21 +3573,21 @@ export class SnapController extends BaseController<
);
}
- await this.#assertSnapRpcRequestResult(snapId, handlerType, result);
+ await this.#assertSnapRpcResponse(snapId, handlerType, result);
- const transformedResult = await this.#transformSnapRpcRequestResult(
+ const transformedResult = await this.#transformSnapRpcResponse(
snapId,
handlerType,
- request,
+ transformedRequest,
result,
);
- this.#recordSnapRpcRequestFinish(snapId, request.id);
+ this.#recordSnapRpcRequestFinish(snapId, transformedRequest.id);
return transformedResult;
} catch (error) {
// We flag the RPC request as finished early since termination may affect pending requests
- this.#recordSnapRpcRequestFinish(snapId, request.id);
+ this.#recordSnapRpcRequestFinish(snapId, transformedRequest.id);
const [jsonRpcError, handled] = unwrapError(error);
if (!handled) {
@@ -3629,15 +3636,15 @@ export class SnapController extends BaseController<
}
/**
- * Transform a RPC request result if necessary.
+ * Transform a RPC response if necessary.
*
* @param snapId - The snap ID of the snap that produced the result.
* @param handlerType - The handler type that produced the result.
* @param request - The request that returned the result.
- * @param result - The result.
+ * @param result - The response.
* @returns The transformed result if applicable, otherwise the original result.
*/
- async #transformSnapRpcRequestResult(
+ async #transformSnapRpcResponse(
snapId: SnapId,
handlerType: HandlerType,
request: Record,
@@ -3763,6 +3770,42 @@ export class SnapController extends BaseController<
return { conversionRates: filteredConversionRates };
}
+ /**
+ * Transforms a JSON-RPC request before sending it to the Snap, if required for a given handler.
+ *
+ * @param snapId - The Snap ID.
+ * @param handlerType - The handler being called.
+ * @param request - The JSON-RPC request.
+ * @returns The potentially transformed JSON-RPC request.
+ */
+ #transformSnapRpcRequest(
+ snapId: SnapId,
+ handlerType: HandlerType,
+ request: JsonRpcRequest,
+ ) {
+ switch (handlerType) {
+ // For onUserInput we inject context, so the client doesn't have to worry about keeping it in sync.
+ case HandlerType.OnUserInput: {
+ assert(request.params && hasProperty(request.params, 'id'));
+
+ const interfaceId = request.params.id as string;
+ const { context } = this.messagingSystem.call(
+ 'SnapInterfaceController:getInterface',
+ snapId,
+ interfaceId,
+ );
+
+ return {
+ ...request,
+ params: { ...request.params, context },
+ };
+ }
+
+ default:
+ return request;
+ }
+ }
+
/**
* Assert that the returned result of a Snap RPC call is the expected shape.
*
@@ -3770,7 +3813,7 @@ export class SnapController extends BaseController<
* @param handlerType - The handler type of the RPC Request.
* @param result - The result of the RPC request.
*/
- async #assertSnapRpcRequestResult(
+ async #assertSnapRpcResponse(
snapId: SnapId,
handlerType: HandlerType,
result: unknown,
diff --git a/packages/snaps-controllers/src/test-utils/controller.ts b/packages/snaps-controllers/src/test-utils/controller.ts
index f6c162cbb7..dfbfc27f2f 100644
--- a/packages/snaps-controllers/src/test-utils/controller.ts
+++ b/packages/snaps-controllers/src/test-utils/controller.ts
@@ -22,7 +22,8 @@ import {
SnapEndowments,
WALLET_SNAP_PERMISSION_KEY,
} from '@metamask/snaps-rpc-methods';
-import type { SnapId, text } from '@metamask/snaps-sdk';
+import type { SnapId } from '@metamask/snaps-sdk';
+import { text } from '@metamask/snaps-sdk';
import { SnapCaveatType } from '@metamask/snaps-utils';
import {
MockControllerMessenger,
@@ -440,7 +441,12 @@ export const getControllerMessenger = (registry = new MockSnapsRegistry()) => {
if (id !== MOCK_INTERFACE_ID) {
throw new Error(`Interface with id '${id}' not found.`);
}
- return { snapId, content: text('foo bar'), state: {} } as StoredInterface;
+ return {
+ snapId,
+ content: text('foo bar'),
+ state: {},
+ context: null,
+ } as StoredInterface;
},
);