Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/snaps-controllers/coverage.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"branches": 93.44,
"branches": 93.48,
"functions": 97.38,
"lines": 98.34,
"statements": 98.08
Expand Down
85 changes: 82 additions & 3 deletions packages/snaps-controllers/src/snaps/SnapController.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ import {
SnapEndowments,
} from '@metamask/snaps-rpc-methods';
import type { SnapId } from '@metamask/snaps-sdk';
import { AuxiliaryFileEncoding, text } from '@metamask/snaps-sdk';
import { Text } from '@metamask/snaps-sdk/jsx';
import {
AuxiliaryFileEncoding,
text,
UserInputEventType,
} from '@metamask/snaps-sdk';
import { Text, Box, Button } from '@metamask/snaps-sdk/jsx';
import type { SnapPermissions, RpcOrigins } from '@metamask/snaps-utils';
import {
getPlatformVersion,
Expand Down Expand Up @@ -2242,7 +2246,11 @@ describe('SnapController', () => {
snapId: snap.id,
origin: MOCK_ORIGIN,
handler: HandlerType.OnUserInput,
request: { jsonrpc: '2.0', method: 'test' },
request: {
jsonrpc: '2.0',
method: 'test',
params: { id: MOCK_INTERFACE_ID },
},
}),
).toBeUndefined();

Expand Down Expand Up @@ -2770,6 +2778,77 @@ describe('SnapController', () => {
snapController.destroy();
});

it('injects context into onUserInput', async () => {
const rootMessenger = getControllerMessenger();
const messenger = getSnapControllerMessenger(rootMessenger);
const snapController = getSnapController(
getSnapControllerOptions({
messenger,
state: {
snaps: getPersistedSnapsState(),
},
}),
);

rootMessenger.registerActionHandler(
'SnapInterfaceController:getInterface',
() => ({
id: MOCK_INTERFACE_ID,
snapId: MOCK_SNAP_ID,
content: (
<Box>
<Button name="button">Click me</Button>
</Box>
),
state: {},
context: { foo: 'bar' },
contentType: null,
}),
);

await snapController.handleRequest({
snapId: MOCK_SNAP_ID,
origin: MOCK_ORIGIN,
handler: HandlerType.OnUserInput,
request: {
jsonrpc: '2.0',
method: ' ',
params: {
id: MOCK_INTERFACE_ID,
event: {
type: UserInputEventType.ButtonClickEvent,
name: 'button',
},
},
},
});

expect(rootMessenger.call).toHaveBeenNthCalledWith(
4,
'ExecutionService:handleRpcRequest',
MOCK_SNAP_ID,
{
origin: MOCK_ORIGIN,
handler: HandlerType.OnUserInput,
request: {
id: expect.any(String),
method: ' ',
jsonrpc: '2.0',
params: {
id: MOCK_INTERFACE_ID,
event: {
type: UserInputEventType.ButtonClickEvent,
name: 'button',
},
context: { foo: 'bar' },
},
},
},
);

snapController.destroy();
});

it('throws if onTransaction handler returns a phishing link', async () => {
const rootMessenger = getControllerMessenger();
const messenger = getSnapControllerMessenger(rootMessenger);
Expand Down
69 changes: 56 additions & 13 deletions packages/snaps-controllers/src/snaps/SnapController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ import type {
NonEmptyArray,
SemVerRange,
CaipAssetType,
JsonRpcRequest,
} from '@metamask/utils';
import {
assert,
Expand Down Expand Up @@ -213,7 +214,7 @@ export type PreinstalledSnap = {
};

type SnapRpcHandler = (
options: SnapRpcHookArgs & { timeout: number },
options: SnapRpcHookArgs & { timeout: number; request: JsonRpcRequest },
) => Promise<unknown>;

/**
Expand Down Expand Up @@ -3513,7 +3514,7 @@ export class SnapController extends BaseController<
handler: handlerType,
request,
timeout,
}: SnapRpcHookArgs & { timeout: number }) => {
}: SnapRpcHookArgs & { timeout: number; request: JsonRpcRequest }) => {
if (!this.state.snaps[snapId].enabled) {
throw new Error(`Snap "${snapId}" is disabled.`);
}
Expand Down Expand Up @@ -3547,13 +3548,19 @@ export class SnapController extends BaseController<
}
}

const transformedRequest = this.#transformSnapRpcRequest(
snapId,
handlerType,
request,
);

const timer = new Timer(timeout);
this.#recordSnapRpcRequestStart(snapId, request.id, timer);
this.#recordSnapRpcRequestStart(snapId, transformedRequest.id, timer);

const handleRpcRequestPromise = this.messagingSystem.call(
'ExecutionService:handleRpcRequest',
snapId,
{ origin, handler: handlerType, request },
{ origin, handler: handlerType, request: transformedRequest },
);

// This will either get the result or reject due to the timeout.
Expand All @@ -3566,21 +3573,21 @@ export class SnapController extends BaseController<
);
}

await this.#assertSnapRpcRequestResult(snapId, handlerType, result);
await this.#assertSnapRpcResponse(snapId, handlerType, result);

const transformedResult = await this.#transformSnapRpcRequestResult(
const transformedResult = await this.#transformSnapRpcResponse(
snapId,
handlerType,
request,
transformedRequest,
result,
);

this.#recordSnapRpcRequestFinish(snapId, request.id);
this.#recordSnapRpcRequestFinish(snapId, transformedRequest.id);

return transformedResult;
} catch (error) {
// We flag the RPC request as finished early since termination may affect pending requests
this.#recordSnapRpcRequestFinish(snapId, request.id);
this.#recordSnapRpcRequestFinish(snapId, transformedRequest.id);
const [jsonRpcError, handled] = unwrapError(error);

if (!handled) {
Expand Down Expand Up @@ -3629,15 +3636,15 @@ export class SnapController extends BaseController<
}

/**
* Transform a RPC request result if necessary.
* Transform a RPC response if necessary.
*
* @param snapId - The snap ID of the snap that produced the result.
* @param handlerType - The handler type that produced the result.
* @param request - The request that returned the result.
* @param result - The result.
* @param result - The response.
* @returns The transformed result if applicable, otherwise the original result.
*/
async #transformSnapRpcRequestResult(
async #transformSnapRpcResponse(
snapId: SnapId,
handlerType: HandlerType,
request: Record<string, unknown>,
Expand Down Expand Up @@ -3763,14 +3770,50 @@ export class SnapController extends BaseController<
return { conversionRates: filteredConversionRates };
}

/**
* Transforms a JSON-RPC request before sending it to the Snap, if required for a given handler.
*
* @param snapId - The Snap ID.
* @param handlerType - The handler being called.
* @param request - The JSON-RPC request.
* @returns The potentially transformed JSON-RPC request.
*/
#transformSnapRpcRequest(
snapId: SnapId,
handlerType: HandlerType,
request: JsonRpcRequest,
) {
switch (handlerType) {
// For onUserInput we inject context, so the client doesn't have to worry about keeping it in sync.
case HandlerType.OnUserInput: {
assert(request.params && hasProperty(request.params, 'id'));

const interfaceId = request.params.id as string;
const { context } = this.messagingSystem.call(
'SnapInterfaceController:getInterface',
snapId,
interfaceId,
);

return {
...request,
params: { ...request.params, context },
};
}

default:
return request;
}
}

/**
* Assert that the returned result of a Snap RPC call is the expected shape.
*
* @param snapId - The snap ID.
* @param handlerType - The handler type of the RPC Request.
* @param result - The result of the RPC request.
*/
async #assertSnapRpcRequestResult(
async #assertSnapRpcResponse(
snapId: SnapId,
handlerType: HandlerType,
result: unknown,
Expand Down
10 changes: 8 additions & 2 deletions packages/snaps-controllers/src/test-utils/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import {
SnapEndowments,
WALLET_SNAP_PERMISSION_KEY,
} from '@metamask/snaps-rpc-methods';
import type { SnapId, text } from '@metamask/snaps-sdk';
import type { SnapId } from '@metamask/snaps-sdk';
import { text } from '@metamask/snaps-sdk';
import { SnapCaveatType } from '@metamask/snaps-utils';
import {
MockControllerMessenger,
Expand Down Expand Up @@ -440,7 +441,12 @@ export const getControllerMessenger = (registry = new MockSnapsRegistry()) => {
if (id !== MOCK_INTERFACE_ID) {
throw new Error(`Interface with id '${id}' not found.`);
}
return { snapId, content: text('foo bar'), state: {} } as StoredInterface;
return {
snapId,
content: text('foo bar'),
state: {},
context: null,
} as StoredInterface;
},
);

Expand Down
Loading