Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 65 additions & 14 deletions packages/snaps-controllers/src/snaps/SnapController.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ describe('SnapController', () => {
[MOCK_SNAP_ID]: expectedSnapObject,
});

expect(messenger.call).toHaveBeenCalledTimes(10);
expect(messenger.call).toHaveBeenCalledTimes(9);

expect(messenger.call).toHaveBeenNthCalledWith(
1,
Expand Down Expand Up @@ -1869,10 +1869,12 @@ describe('SnapController', () => {
},
});

const results = await Promise.allSettled([
snapController.removeSnap(snap.id),
promise,
]);
const removeSnap = async () => {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a hack, but unsure how else to reproduce the race condition now

await sleep(1);
return snapController.removeSnap(snap.id);
};

const results = await Promise.allSettled([removeSnap(), promise]);

expect(results[0].status).toBe('fulfilled');
expect(results[1].status).toBe('rejected');
Expand Down Expand Up @@ -2681,6 +2683,55 @@ describe('SnapController', () => {
},
);

it('ensures onboarding has completed before processing requests', async () => {
const rootMessenger = getControllerMessenger();
const messenger = getSnapControllerMessenger(rootMessenger);

const callActionSpy = jest.spyOn(messenger, 'call');

const { promise, resolve } = createDeferredPromise();
const ensureOnboardingComplete = jest.fn().mockReturnValue(promise);
const snapController = getSnapController(
getSnapControllerOptions({
messenger,
state: {
snaps: getPersistedSnapsState(),
},
ensureOnboardingComplete,
}),
);

const snap = snapController.getExpect(MOCK_SNAP_ID);

const requestPromise = snapController.handleRequest({
snapId: snap.id,
origin: METAMASK_ORIGIN,
handler: HandlerType.OnRpcRequest,
request: {
jsonrpc: '2.0',
method: 'test',
params: {},
},
});

await sleep(100);

expect(callActionSpy).not.toHaveBeenCalledWith(
'ExecutionService:executeSnap',
expect.objectContaining({ snapId: MOCK_SNAP_ID }),
);

resolve();
expect(await requestPromise).toBeUndefined();

expect(callActionSpy).toHaveBeenCalledWith(
'ExecutionService:executeSnap',
expect.objectContaining({ snapId: MOCK_SNAP_ID }),
);

snapController.destroy();
});

it('throws if the snap does not have permission to handle JSON-RPC requests from dapps', async () => {
const rootMessenger = getControllerMessenger();
const messenger = getSnapControllerMessenger(rootMessenger);
Expand Down Expand Up @@ -5448,7 +5499,7 @@ describe('SnapController', () => {

expect(result).toStrictEqual({ [MOCK_LOCAL_SNAP_ID]: truncatedSnap });

expect(messenger.call).toHaveBeenCalledTimes(12);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These seem to be changes in timing for when lifecycle hooks are called, due to the added async code we don't get as long in processing the side-effect before we reach the assertion.

expect(messenger.call).toHaveBeenCalledTimes(11);

expect(messenger.call).toHaveBeenNthCalledWith(
1,
Expand Down Expand Up @@ -5598,7 +5649,7 @@ describe('SnapController', () => {
[MOCK_LOCAL_SNAP_ID]: truncatedSnap,
});

expect(messenger.call).toHaveBeenCalledTimes(23);
expect(messenger.call).toHaveBeenCalledTimes(22);

expect(messenger.call).toHaveBeenNthCalledWith(
1,
Expand Down Expand Up @@ -6845,7 +6896,7 @@ describe('SnapController', () => {
expect(result).toStrictEqual({
[MOCK_SNAP_ID]: truncatedSnap,
});
expect(messenger.call).toHaveBeenCalledTimes(10);
expect(messenger.call).toHaveBeenCalledTimes(9);

expect(messenger.call).toHaveBeenNthCalledWith(
1,
Expand Down Expand Up @@ -7538,7 +7589,7 @@ describe('SnapController', () => {
[MOCK_SNAP_ID]: { version: newVersionRange },
});

expect(messenger.call).toHaveBeenCalledTimes(21);
expect(messenger.call).toHaveBeenCalledTimes(20);

expect(messenger.call).toHaveBeenNthCalledWith(
3,
Expand Down Expand Up @@ -8414,7 +8465,7 @@ describe('SnapController', () => {
date: expect.any(Number),
},
]);
expect(callActionSpy).toHaveBeenCalledTimes(21);
expect(callActionSpy).toHaveBeenCalledTimes(20);

expect(callActionSpy).toHaveBeenNthCalledWith(
12,
Expand Down Expand Up @@ -8579,7 +8630,7 @@ describe('SnapController', () => {
[MOCK_SNAP_ID]: { version: '1.1.0' },
});

expect(callActionSpy).toHaveBeenCalledTimes(21);
expect(callActionSpy).toHaveBeenCalledTimes(20);
expect(callActionSpy).toHaveBeenNthCalledWith(
12,
'ApprovalController:addRequest',
Expand Down Expand Up @@ -8695,7 +8746,7 @@ describe('SnapController', () => {
[MOCK_SNAP_ID]: { version: '1.1.0' },
});

expect(callActionSpy).toHaveBeenCalledTimes(22);
expect(callActionSpy).toHaveBeenCalledTimes(21);
expect(callActionSpy).toHaveBeenNthCalledWith(
12,
'ApprovalController:addRequest',
Expand Down Expand Up @@ -8836,7 +8887,7 @@ describe('SnapController', () => {

const isRunning = controller.isRunning(MOCK_SNAP_ID);

expect(callActionSpy).toHaveBeenCalledTimes(12);
expect(callActionSpy).toHaveBeenCalledTimes(11);

expect(callActionSpy).toHaveBeenNthCalledWith(
1,
Expand Down Expand Up @@ -9191,7 +9242,7 @@ describe('SnapController', () => {
[MOCK_SNAP_ID]: { version: '1.1.0' },
});

expect(callActionSpy).toHaveBeenCalledTimes(23);
expect(callActionSpy).toHaveBeenCalledTimes(22);

expect(callActionSpy).toHaveBeenNthCalledWith(
12,
Expand Down
28 changes: 21 additions & 7 deletions packages/snaps-controllers/src/snaps/SnapController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,13 @@ type SnapControllerArgs = {
* MetaMetrics event tracking hook.
*/
trackEvent: TrackEventHook;

/**
* A hook that returns a promise that resolves when the onboarding has completed.
*
* @returns A promise that resolves when onboarding is complete.
*/
ensureOnboardingComplete: () => Promise<void>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not get this from the onboarding controller state?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately onboarding controller only exists on extension. On mobile the onboarding state is stored in Redux only, so we'll have to come up with a way to monitor that.

};

type AddSnapArgs = {
Expand Down Expand Up @@ -932,6 +939,8 @@ export class SnapController extends BaseController<

readonly #trackSnapExport: ReturnType<typeof throttleTracking>;

readonly #ensureOnboardingComplete: () => Promise<void>;

constructor({
closeAllConnections,
messenger,
Expand All @@ -951,6 +960,7 @@ export class SnapController extends BaseController<
getFeatureFlags = () => ({}),
clientCryptography,
trackEvent,
ensureOnboardingComplete,
}: SnapControllerArgs) {
super({
messenger,
Expand Down Expand Up @@ -1034,6 +1044,7 @@ export class SnapController extends BaseController<
this.#rollbackSnapshots = new Map();
this.#snapsRuntimeData = new Map();
this.#trackEvent = trackEvent;
this.#ensureOnboardingComplete = ensureOnboardingComplete;

this.#pollForLastRequestStatus();

Expand Down Expand Up @@ -1467,7 +1478,7 @@ export class SnapController extends BaseController<
* Also updates any preinstalled Snaps to the latest allowlisted version.
*/
async updateRegistry(): Promise<void> {
this.#assertCanUsePlatform();
await this.#assertCanUsePlatform();
await this.messenger.call('SnapsRegistry:update');

const blockedSnaps = await this.messenger.call(
Expand Down Expand Up @@ -1645,9 +1656,12 @@ export class SnapController extends BaseController<
}

/**
* Asserts whether the Snaps platform is allowed to run.
* Waits for onboarding and then asserts whether the Snaps platform is allowed to run.
*/
#assertCanUsePlatform() {
async #assertCanUsePlatform() {
// Ensure the user has onboarded before allowing access to Snaps.
await this.#ensureOnboardingComplete();

const flags = this.#getFeatureFlags();
assert(
flags.disableSnaps !== true,
Expand Down Expand Up @@ -1730,7 +1744,7 @@ export class SnapController extends BaseController<
* @param snapId - The id of the Snap to start.
*/
async startSnap(snapId: SnapId): Promise<void> {
this.#assertCanUsePlatform();
await this.#assertCanUsePlatform();
const snap = this.state.snaps[snapId];

if (!snap.enabled) {
Expand Down Expand Up @@ -2630,7 +2644,7 @@ export class SnapController extends BaseController<
origin: string,
requestedSnaps: RequestSnapsParams,
): Promise<RequestSnapsResult> {
this.#assertCanUsePlatform();
await this.#assertCanUsePlatform();

const result: RequestSnapsResult = {};

Expand Down Expand Up @@ -2916,7 +2930,7 @@ export class SnapController extends BaseController<
if (!automaticUpdate) {
this.#assertCanInstallSnaps();
}
this.#assertCanUsePlatform();
await this.#assertCanUsePlatform();

const snap = this.getExpect(snapId);

Expand Down Expand Up @@ -3551,7 +3565,7 @@ export class SnapController extends BaseController<
handler: handlerType,
request: rawRequest,
}: SnapRpcHookArgs & { snapId: SnapId }): Promise<unknown> {
this.#assertCanUsePlatform();
await this.#assertCanUsePlatform();

const snap = this.get(snapId);

Expand Down
4 changes: 2 additions & 2 deletions packages/snaps-controllers/src/test-utils/controller.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,6 @@ export const getSnapControllerMessenger = (
'ExecutionService:executeSnap',
'ExecutionService:terminateSnap',
'ExecutionService:handleRpcRequest',
'NetworkController:getNetworkClientById',
'PermissionController:getEndowments',
'PermissionController:hasPermission',
'PermissionController:hasPermissions',
Expand All @@ -490,7 +489,6 @@ export const getSnapControllerMessenger = (
'PermissionController:revokePermissionForAllSubjects',
'PermissionController:updateCaveat',
'PermissionController:getSubjectNames',
'SelectedNetworkController:getNetworkClientIdForDomain',
'SubjectMetadataController:getSubjectMetadata',
'SubjectMetadataController:addSubjectMetadata',
'SnapsRegistry:get',
Expand Down Expand Up @@ -574,6 +572,7 @@ export const getSnapControllerOptions = (
clientCryptography: {},
encryptor: getSnapControllerEncryptor(),
trackEvent: jest.fn(),
ensureOnboardingComplete: jest.fn().mockResolvedValue(undefined),
...opts,
} as SnapControllerConstructorParams;

Expand Down Expand Up @@ -608,6 +607,7 @@ export const getSnapControllerWithEESOptions = ({
encryptor: getSnapControllerEncryptor(),
fetchFunction: jest.fn(),
trackEvent: jest.fn(),
ensureOnboardingComplete: jest.fn().mockResolvedValue(undefined),
...options,
} as SnapControllerConstructorParams & {
rootMessenger: ReturnType<typeof getControllerMessenger>;
Expand Down