@@ -52,6 +52,7 @@ import {
5252 MOCK_SNAP_NAME ,
5353 DEFAULT_SOURCE_PATH ,
5454 DEFAULT_ICON_PATH ,
55+ TEST_SECRET_RECOVERY_PHRASE_BYTES ,
5556} from '@metamask/snaps-utils/test-utils' ;
5657import type { SemVerRange , SemVerVersion , Json } from '@metamask/utils' ;
5758import {
@@ -60,6 +61,7 @@ import {
6061 AssertionError ,
6162 base64ToBytes ,
6263 stringToBytes ,
64+ createDeferredPromise ,
6365} from '@metamask/utils' ;
6466import { File } from 'buffer' ;
6567import { webcrypto } from 'crypto' ;
@@ -78,6 +80,7 @@ import {
7880 getNodeEESMessenger ,
7981 getPersistedSnapsState ,
8082 getSnapController ,
83+ getSnapControllerEncryptor ,
8184 getSnapControllerMessenger ,
8285 getSnapControllerOptions ,
8386 getSnapControllerWithEES ,
@@ -97,6 +100,7 @@ import {
97100 MOCK_WALLET_SNAP_PERMISSION ,
98101 MockSnapsRegistry ,
99102 sleep ,
103+ waitForStateChange ,
100104} from '../test-utils' ;
101105import { delay } from '../utils' ;
102106import { LEGACY_ENCRYPTION_KEY_DERIVATION_OPTIONS } from './constants' ;
@@ -2117,6 +2121,59 @@ describe('SnapController', () => {
21172121 await service . terminateAllSnaps ( ) ;
21182122 } ) ;
21192123
2124+ it ( 'clears encrypted state of Snaps when the client is locked' , async ( ) => {
2125+ const rootMessenger = getControllerMessenger ( ) ;
2126+ const messenger = getSnapControllerMessenger ( rootMessenger ) ;
2127+
2128+ const state = { myVariable : 1 } ;
2129+
2130+ const mockEncryptedState = await encrypt (
2131+ ENCRYPTION_KEY ,
2132+ state ,
2133+ undefined ,
2134+ undefined ,
2135+ DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS ,
2136+ ) ;
2137+
2138+ const getMnemonic = jest
2139+ . fn ( )
2140+ . mockReturnValue ( TEST_SECRET_RECOVERY_PHRASE_BYTES ) ;
2141+
2142+ const snapController = getSnapController (
2143+ getSnapControllerOptions ( {
2144+ messenger,
2145+ state : {
2146+ snaps : {
2147+ [ MOCK_SNAP_ID ] : getPersistedSnapObject ( ) ,
2148+ } ,
2149+ snapStates : {
2150+ [ MOCK_SNAP_ID ] : mockEncryptedState ,
2151+ } ,
2152+ } ,
2153+ getMnemonic,
2154+ } ) ,
2155+ ) ;
2156+
2157+ expect (
2158+ await messenger . call ( 'SnapController:getSnapState' , MOCK_SNAP_ID , true ) ,
2159+ ) . toStrictEqual ( state ) ;
2160+ expect ( getMnemonic ) . toHaveBeenCalledTimes ( 1 ) ;
2161+
2162+ rootMessenger . publish ( 'KeyringController:lock' ) ;
2163+
2164+ expect (
2165+ await messenger . call ( 'SnapController:getSnapState' , MOCK_SNAP_ID , true ) ,
2166+ ) . toStrictEqual ( state ) ;
2167+
2168+ // We assume `getMnemonic` is called again because the controller needs to
2169+ // decrypt the state again. This is not an ideal way to test this, but it
2170+ // is the easiest to test this without exposing the internal state of the
2171+ // `SnapController`.
2172+ expect ( getMnemonic ) . toHaveBeenCalledTimes ( 2 ) ;
2173+
2174+ snapController . destroy ( ) ;
2175+ } ) ;
2176+
21202177 describe ( 'handleRequest' , ( ) => {
21212178 it . each (
21222179 Object . keys ( handlerEndowments ) . filter (
@@ -8801,6 +8858,7 @@ describe('SnapController', () => {
88018858 ) ;
88028859
88038860 const newState = { myVariable : 2 } ;
8861+ const promise = waitForStateChange ( messenger ) ;
88048862
88058863 await messenger . call (
88068864 'SnapController:updateSnapState' ,
@@ -8817,6 +8875,8 @@ describe('SnapController', () => {
88178875 DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS ,
88188876 ) ;
88198877
8878+ await promise ;
8879+
88208880 const result = await messenger . call (
88218881 'SnapController:getSnapState' ,
88228882 MOCK_SNAP_ID ,
@@ -8831,7 +8891,7 @@ describe('SnapController', () => {
88318891 snapController . destroy ( ) ;
88328892 } ) ;
88338893
8834- it ( 'different snaps use different encryption keys' , async ( ) => {
8894+ it ( 'uses different encryption keys for different snaps ' , async ( ) => {
88358895 const messenger = getSnapControllerMessenger ( ) ;
88368896
88378897 const state = { foo : 'bar' } ;
@@ -8857,13 +8917,17 @@ describe('SnapController', () => {
88578917 true ,
88588918 ) ;
88598919
8920+ const promise = waitForStateChange ( messenger ) ;
8921+
88608922 await messenger . call (
88618923 'SnapController:updateSnapState' ,
88628924 MOCK_LOCAL_SNAP_ID ,
88638925 state ,
88648926 true ,
88658927 ) ;
88668928
8929+ await promise ;
8930+
88678931 const encryptedState1 = await encrypt (
88688932 ENCRYPTION_KEY ,
88698933 state ,
@@ -9073,13 +9137,17 @@ describe('SnapController', () => {
90739137 undefined ,
90749138 DEFAULT_ENCRYPTION_KEY_DERIVATION_OPTIONS ,
90759139 ) ;
9140+
9141+ const promise = waitForStateChange ( messenger ) ;
90769142 await messenger . call (
90779143 'SnapController:updateSnapState' ,
90789144 MOCK_SNAP_ID ,
90799145 state ,
90809146 true ,
90819147 ) ;
90829148
9149+ await promise ;
9150+
90839151 expect ( updateSnapStateSpy ) . toHaveBeenCalledTimes ( 1 ) ;
90849152 expect ( snapController . state . snapStates [ MOCK_SNAP_ID ] ) . toStrictEqual (
90859153 mockEncryptedState ,
@@ -9137,17 +9205,126 @@ describe('SnapController', () => {
91379205 ) ;
91389206
91399207 const state = { foo : 'bar' } ;
9208+
9209+ const promise = waitForStateChange ( messenger ) ;
91409210 await messenger . call (
91419211 'SnapController:updateSnapState' ,
91429212 MOCK_SNAP_ID ,
91439213 state ,
91449214 true ,
91459215 ) ;
91469216
9217+ await promise ;
9218+
91479219 expect ( pbkdf2Sha512 ) . toHaveBeenCalledTimes ( 1 ) ;
91489220
91499221 snapController . destroy ( ) ;
91509222 } ) ;
9223+
9224+ it ( 'queues multiple state updates' , async ( ) => {
9225+ const messenger = getSnapControllerMessenger ( ) ;
9226+
9227+ jest . useFakeTimers ( ) ;
9228+
9229+ const encryptor = getSnapControllerEncryptor ( ) ;
9230+ const { promise, resolve } = createDeferredPromise ( ) ;
9231+ const encryptWithKey = jest
9232+ . fn <
9233+ ReturnType < typeof encryptor . encryptWithKey > ,
9234+ Parameters < typeof encryptor . encryptWithKey >
9235+ > ( )
9236+ . mockImplementation ( async ( ...args ) => {
9237+ resolve ( ) ;
9238+ await sleep ( 1 ) ;
9239+ return await encryptor . encryptWithKey ( ...args ) ;
9240+ } ) ;
9241+
9242+ const snapController = getSnapController (
9243+ getSnapControllerOptions ( {
9244+ messenger,
9245+ state : {
9246+ snaps : getPersistedSnapsState ( ) ,
9247+ } ,
9248+ encryptor : {
9249+ ...getSnapControllerEncryptor ( ) ,
9250+ // @ts -expect-error - Missing required properties.
9251+ encryptWithKey,
9252+ } ,
9253+ } ) ,
9254+ ) ;
9255+
9256+ const firstStateChange = waitForStateChange ( messenger ) ;
9257+ await messenger . call (
9258+ 'SnapController:updateSnapState' ,
9259+ MOCK_SNAP_ID ,
9260+ { foo : 'bar' } ,
9261+ true ,
9262+ ) ;
9263+
9264+ await messenger . call (
9265+ 'SnapController:updateSnapState' ,
9266+ MOCK_SNAP_ID ,
9267+ { bar : 'baz' } ,
9268+ true ,
9269+ ) ;
9270+
9271+ // We await this promise to ensure the timer is queued.
9272+ await promise ;
9273+ jest . advanceTimersByTime ( 1 ) ;
9274+
9275+ // After this point the second update should be queued.
9276+ await firstStateChange ;
9277+ const secondStateChange = waitForStateChange ( messenger ) ;
9278+
9279+ expect ( encryptWithKey ) . toHaveBeenCalledTimes ( 1 ) ;
9280+
9281+ // This is a bit hacky, but we can't simply advance the timer by 1ms
9282+ // because the second timer is not running yet.
9283+ jest . useRealTimers ( ) ;
9284+ await secondStateChange ;
9285+
9286+ expect ( encryptWithKey ) . toHaveBeenCalledTimes ( 2 ) ;
9287+
9288+ expect (
9289+ await messenger . call ( 'SnapController:getSnapState' , MOCK_SNAP_ID , true ) ,
9290+ ) . toStrictEqual ( { bar : 'baz' } ) ;
9291+
9292+ snapController . destroy ( ) ;
9293+ } ) ;
9294+
9295+ it ( 'logs an error message if the state fails to persist' , async ( ) => {
9296+ const messenger = getSnapControllerMessenger ( ) ;
9297+
9298+ const errorValue = new Error ( 'Failed to persist state.' ) ;
9299+ const snapController = getSnapController (
9300+ getSnapControllerOptions ( {
9301+ messenger,
9302+ state : {
9303+ snaps : getPersistedSnapsState ( ) ,
9304+ } ,
9305+ // @ts -expect-error - Missing required properties.
9306+ encryptor : {
9307+ ...getSnapControllerEncryptor ( ) ,
9308+ encryptWithKey : jest . fn ( ) . mockRejectedValue ( errorValue ) ,
9309+ } ,
9310+ } ) ,
9311+ ) ;
9312+
9313+ const { promise, resolve } = createDeferredPromise ( ) ;
9314+ const error = jest . spyOn ( console , 'error' ) . mockImplementation ( resolve ) ;
9315+
9316+ await messenger . call (
9317+ 'SnapController:updateSnapState' ,
9318+ MOCK_SNAP_ID ,
9319+ { foo : 'bar' } ,
9320+ true ,
9321+ ) ;
9322+
9323+ await promise ;
9324+ expect ( error ) . toHaveBeenCalledWith ( errorValue ) ;
9325+
9326+ snapController . destroy ( ) ;
9327+ } ) ;
91519328 } ) ;
91529329
91539330 describe ( 'SnapController:clearSnapState' , ( ) => {
@@ -9206,6 +9383,41 @@ describe('SnapController', () => {
92069383
92079384 snapController . destroy ( ) ;
92089385 } ) ;
9386+
9387+ it ( 'logs an error message if the state fails to persist' , async ( ) => {
9388+ const messenger = getSnapControllerMessenger ( ) ;
9389+
9390+ const errorValue = new Error ( 'Failed to persist state.' ) ;
9391+ const snapController = getSnapController (
9392+ getSnapControllerOptions ( {
9393+ messenger,
9394+ state : {
9395+ snaps : getPersistedSnapsState ( ) ,
9396+ } ,
9397+ // @ts -expect-error - Missing required properties.
9398+ encryptor : {
9399+ ...getSnapControllerEncryptor ( ) ,
9400+ encryptWithKey : jest . fn ( ) . mockRejectedValue ( errorValue ) ,
9401+ } ,
9402+ } ) ,
9403+ ) ;
9404+
9405+ const { promise, resolve } = createDeferredPromise ( ) ;
9406+ const error = jest . spyOn ( console , 'error' ) . mockImplementation ( resolve ) ;
9407+
9408+ // @ts -expect-error - Property `update` is protected.
9409+ // eslint-disable-next-line jest/prefer-spy-on
9410+ snapController . update = jest . fn ( ) . mockImplementation ( ( ) => {
9411+ throw errorValue ;
9412+ } ) ;
9413+
9414+ await messenger . call ( 'SnapController:clearSnapState' , MOCK_SNAP_ID , true ) ;
9415+
9416+ await promise ;
9417+ expect ( error ) . toHaveBeenCalledWith ( errorValue ) ;
9418+
9419+ snapController . destroy ( ) ;
9420+ } ) ;
92099421 } ) ;
92109422
92119423 describe ( 'SnapController:updateBlockedSnaps' , ( ) => {
0 commit comments