diff --git a/contracts/account/modules/ERC7579SocialRecovery.sol b/contracts/account/modules/ERC7579SocialRecovery.sol new file mode 100644 index 00000000..bc49b915 --- /dev/null +++ b/contracts/account/modules/ERC7579SocialRecovery.sol @@ -0,0 +1,320 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.27; + +import {IERC7579Module, MODULE_TYPE_EXECUTOR, MODULE_TYPE_FALLBACK} from "@openzeppelin/contracts/interfaces/draft-IERC7579.sol"; +import {EIP712} from "@openzeppelin/contracts/utils/cryptography/EIP712.sol"; +import {SignatureChecker} from "@openzeppelin/contracts/utils/cryptography/SignatureChecker.sol"; +import {Checkpoints} from "@openzeppelin/contracts/utils/structs/Checkpoints.sol"; +import {Math} from "@openzeppelin/contracts/utils/math/Math.sol"; +import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol"; +import {Address} from "@openzeppelin/contracts/utils/Address.sol"; +import {Nonces} from "@openzeppelin/contracts/utils/Nonces.sol"; +import {ERC7913Utils} from "../../utils/cryptography/ERC7913Utils.sol"; +import {EnumerableMapExtended} from "../../utils/structs/EnumerableMapExtended.sol"; + +abstract contract ERC7579SocialRecovery is EIP712, Nonces, IERC7579Module { + using Checkpoints for *; + using EnumerableMapExtended for *; + using SafeCast for *; + + bytes32 private constant START_RECOVERY_TYPEHASH = + keccak256("StartRecovery(address account,bytes recovery,uint256 nonce)"); + bytes32 private constant CANCEL_RECOVERY_TYPEHASH = keccak256("CancelRecovery(address account,uint256 nonce)"); + + struct Permission { + bytes signer; + bytes signature; + } + + struct ThresholdConfig { + uint64 threshold; // Threshold value + uint48 lockPeriod; // Lock period for the threshold + } + + struct AccountConfig { + EnumerableMapExtended.BytesToUintMap guardians; + Checkpoints.Trace160 thresholds; + uint48 expiryTime; + bytes recoveryCall; + } + + mapping(address account => AccountConfig) private _configs; + + event RecoveryConfigSet(address indexed account, bytes[] guardians, ThresholdConfig[] thresholds); + event RecoveryConfigCleared(address indexed account); + event RecoveryStarted(address indexed account, bytes recoveryCall, uint48 expiryTime); + event RecoveryExecuted(address indexed account, bytes recoveryCall); + event RecoveryCanceled(address indexed account); + error InvalidGuardian(address account, bytes signer); + error InvalidSignature(address account, bytes signer); + error PolicyVerificationFailed(address account); + error ThresholdNotReached(address account, uint64 weight); + error AccountNotInRecovery(address account); + error AccountRecoveryPending(address account); + error AccountRecoveryNotReady(address account); + + modifier onlyNotRecovering(address account) { + require(_configs[account].expiryTime == 0, AccountRecoveryPending(account)); + _; + } + + modifier onlyRecovering(address account) { + require(_configs[account].expiryTime != 0, AccountNotInRecovery(account)); + _; + } + + modifier onlyRecoveryReady(address account) { + uint48 expiryTime = _configs[account].expiryTime; + require(expiryTime != 0 && expiryTime <= block.timestamp, AccountRecoveryNotReady(account)); + _; + } + + /**************************************************************************************************************** + * IERC7579Module * + ****************************************************************************************************************/ + function onInstall(bytes calldata data) public virtual { + if (data.length > 0) { + Address.functionDelegateCall(address(this), data); + } + } + + function onUninstall(bytes calldata /*data*/) public virtual { + _clearConfig(msg.sender); + } + + function isModuleType(uint256 moduleTypeId) public view virtual returns (bool) { + return moduleTypeId == MODULE_TYPE_EXECUTOR || moduleTypeId == MODULE_TYPE_FALLBACK; + } + + /**************************************************************************************************************** + * Social recovery - Core * + ****************************************************************************************************************/ + function updateGuardians(bytes[] calldata guardians, ThresholdConfig[] calldata thresholds) public virtual { + _overrideConfig(msg.sender, guardians, thresholds); + } + + function startRecovery(bytes calldata recoveryCall, Permission[] calldata permissions) external { + startRecovery(msg.sender, recoveryCall, permissions); + } + + function startRecovery( + address account, + bytes calldata recoveryCall, + Permission[] calldata permissions + ) public virtual { + uint48 lockPeriod = _checkPermissions( + account, + _hashTypedDataV4( + keccak256(abi.encode(START_RECOVERY_TYPEHASH, account, keccak256(recoveryCall), _useNonce(account))) + ), + permissions + ); + _startRecovery(account, recoveryCall, lockPeriod); + } + + function executeRecovery() external { + executeRecovery(msg.sender); + } + + function executeRecovery(address account) public virtual { + _executeRecovery(account); + } + + function cancelRecovery() public virtual { + _cancelRecovery(msg.sender); + } + + function cancelRecoveryByGuardians(Permission[] calldata permissions) external { + cancelRecoveryByGuardians(msg.sender, permissions); + } + + function cancelRecoveryByGuardians(address account, Permission[] calldata permissions) public virtual { + _checkPermissions( + account, + _hashTypedDataV4(keccak256(abi.encode(CANCEL_RECOVERY_TYPEHASH, account, _useNonce(account)))), + permissions + ); + _cancelRecovery(account); + } + + function isGuardian(bytes calldata guardian) external view returns (bool exist, uint64 weight) { + return isGuardian(msg.sender, guardian); + } + + function isGuardian( + address account, + bytes calldata guardian + ) public view virtual returns (bool exist, uint64 weight) { + (bool exist_, uint256 weight_) = _configs[account].guardians.tryGet(guardian); + return (exist_, weight_.toUint48()); + } + + function getAccountConfigs() external view returns (bytes[] memory guardians, ThresholdConfig[] memory thresholds) { + return getAccountConfigs(msg.sender); + } + + function getAccountConfigs( + address account + ) public view virtual returns (bytes[] memory guardians, ThresholdConfig[] memory thresholds) { + AccountConfig storage config = _configs[account]; + + guardians = new bytes[](config.guardians.length()); + for (uint256 i = 0; i < guardians.length; ++i) { + (bytes memory signer, uint256 weight) = config.guardians.at(i); + guardians[i] = _formatGuardian(weight.toUint64(), signer); + } + + thresholds = new ThresholdConfig[](config.thresholds.length()); + for (uint256 i = 0; i < thresholds.length; ++i) { + Checkpoints.Checkpoint160 memory ckpt = config.thresholds.at(i.toUint32()); + thresholds[i].threshold = ckpt._key.toUint64(); + thresholds[i].lockPeriod = ckpt._value.toUint48(); + } + } + + function getRecoveryNonce() external view returns (uint256 nonce) { + return nonces(msg.sender); + } + + function getRecoveryStatus() external view returns (bool isRecovering, uint48 expiryTime) { + return getRecoveryStatus(msg.sender); + } + + function getRecoveryStatus(address account) public view virtual returns (bool isRecovering, uint48 expiryTime) { + expiryTime = _configs[account].expiryTime; + isRecovering = expiryTime != 0; + } + + /**************************************************************************************************************** + * Social recovery - internal * + ****************************************************************************************************************/ + function _overrideConfig( + address account, + bytes[] calldata guardians, + ThresholdConfig[] calldata thresholds + ) internal virtual { + _clearConfig(account); + + AccountConfig storage config = _configs[account]; + + // guardians + for (uint256 i = 0; i < guardians.length; ++i) { + (uint64 weight, bytes calldata signer) = _parseGuardian(guardians[i]); + config.guardians.set(signer, weight); + } + + // threshold + for (uint256 i = 0; i < thresholds.length; ++i) { + config.thresholds.push(thresholds[i].threshold, thresholds[i].lockPeriod); + } + + emit RecoveryConfigSet(account, guardians, thresholds); + } + + function _clearConfig(address account) internal virtual { + AccountConfig storage config = _configs[account]; + + // clear enumerable map + config.guardians.clear(); + + // clear threshold + Checkpoints.Checkpoint160[] storage ckpts = config.thresholds._checkpoints; + assembly ("memory-safe") { + sstore(ckpts.slot, 0) + } + + // clear remaining + delete config.expiryTime; + delete config.recoveryCall; + + emit RecoveryConfigCleared(account); + } + + function _checkPermissions( + address account, + bytes32 hash, + Permission[] calldata permissions + ) internal virtual returns (uint48) { + AccountConfig storage config = _configs[account]; + + // verify signature and get properties + uint64 totalWeight = 0; + bytes32 previousHash = bytes32(0); + for (uint256 i = 0; i < permissions.length; ++i) { + // uniqueness + bytes32 newHash = keccak256(permissions[i].signer); + require(previousHash < newHash, PolicyVerificationFailed(account)); + previousHash = newHash; + + // validity of signer and signature + (bool validIdentity, uint64 weight) = isGuardian(account, permissions[i].signer); + require(validIdentity, InvalidGuardian(account, permissions[i].signer)); + require( + ERC7913Utils.isValidSignatureNow(permissions[i].signer, hash, permissions[i].signature), + InvalidSignature(account, permissions[i].signer) + ); + + // total weight + totalWeight += weight; + } + + // get lock period + uint48 lockPeriod = config.thresholds.upperLookup(totalWeight).toUint48(); // uint160 -> uint48 + + // TODO: case where the delay really is zero vs case where there is no delay? + require(lockPeriod > 0, ThresholdNotReached(account, totalWeight)); + + return lockPeriod; + } + + function _startRecovery( + address account, + bytes calldata recoveryCall, + uint48 lockPeriod + ) internal virtual onlyNotRecovering(account) { + uint48 expiryTime = SafeCast.toUint48(block.timestamp + lockPeriod); + + // set recovery details + _configs[account].expiryTime = expiryTime; + _configs[account].recoveryCall = recoveryCall; + + emit RecoveryStarted(account, recoveryCall, expiryTime); + } + + function _executeRecovery(address account) internal virtual onlyRecoveryReady(account) { + // cache + bytes memory recoveryCall = _configs[account].recoveryCall; + + // clean (prevents reentry) + delete _configs[account].expiryTime; + delete _configs[account].recoveryCall; + + // perform call + Address.functionCall(account, recoveryCall); + + emit RecoveryExecuted(account, recoveryCall); + } + + function _cancelRecovery(address account) internal virtual onlyRecovering(account) { + // clean + delete _configs[account].expiryTime; + delete _configs[account].recoveryCall; + + emit RecoveryCanceled(account); + } + + /**************************************************************************************************************** + * Helpers * + ****************************************************************************************************************/ + function _formatGuardian(uint64 weight, bytes memory signer) internal pure virtual returns (bytes memory guardian) { + return abi.encodePacked(weight, signer); + } + + function _parseGuardian( + bytes calldata guardian + ) internal pure virtual returns (uint64 weight, bytes calldata signer) { + weight = uint64(bytes8(guardian[0:8])); + signer = guardian[8:]; + } +} diff --git a/contracts/mocks/account/modules/ERC7579ValidatorMock.sol b/contracts/mocks/account/modules/ERC7579ValidatorMock.sol index 53040b3f..69ae7218 100644 --- a/contracts/mocks/account/modules/ERC7579ValidatorMock.sol +++ b/contracts/mocks/account/modules/ERC7579ValidatorMock.sol @@ -42,4 +42,13 @@ abstract contract ERC7579ValidatorMock is ERC7579ModuleMock(MODULE_TYPE_VALIDATO ? IERC1271.isValidSignature.selector : bytes4(0xffffffff); } + + function getSigner(address account) public view virtual returns (address) { + return _associatedSigners[account]; + } + + function updateSigner(address newSigner) public virtual { + require(newSigner != address(0) && _associatedSigners[msg.sender] != address(0)); + _associatedSigners[msg.sender] = newSigner; + } } diff --git a/test/account/modules/ERC7579SocialRecovery.test.js b/test/account/modules/ERC7579SocialRecovery.test.js new file mode 100644 index 00000000..d7bda891 --- /dev/null +++ b/test/account/modules/ERC7579SocialRecovery.test.js @@ -0,0 +1,166 @@ +const { ethers, entrypoint } = require('hardhat'); +const { expect } = require('chai'); +const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); +const { anyValue } = require('@nomicfoundation/hardhat-chai-matchers/withArgs'); + +const { getDomain } = require('@openzeppelin/contracts/test/helpers/eip712'); +const { + encodeMode, + encodeSingle, + CALL_TYPE_CALL, + MODULE_TYPE_EXECUTOR, +} = require('@openzeppelin/contracts/test/helpers/erc7579'); +const time = require('@openzeppelin/contracts/test/helpers/time'); + +const { ERC4337Helper } = require('../../helpers/erc4337'); +const { PackedUserOperation, StartRecovery } = require('../../helpers/eip712-types'); + +const comp = (a, b) => a > b || -(a < b); + +async function fixture() { + // EOAs and environment + const [other, ...accounts] = await ethers.getSigners(); + + // ERC-7579 validator + const validator = await ethers.deployContract('$ERC7579ValidatorMock'); + const socialRecovery = await ethers.deployContract('$ERC7579SocialRecovery', ['SocialRecovery', '1']); + + // ERC-4337 signer + const signer = ethers.Wallet.createRandom(); + + // ERC-4337 account + const helper = new ERC4337Helper(); + const env = await helper.wait(); + const mock = await helper.newAccount('$AccountERC7579Mock', [ + validator, + ethers.solidityPacked(['address'], [signer.address]), + ]); + + // ERC-4337 Entrypoint domain + const entrypointDomain = await getDomain(entrypoint.v08); + + return { ...env, validator, socialRecovery, entrypointDomain, mock, signer, other, accounts }; +} + +describe('ERC7579SocialRecovery', function () { + beforeEach(async function () { + Object.assign(this, await loadFixture(fixture)); + + this.userOp = { nonce: ethers.zeroPadBytes(ethers.hexlify(this.validator.target), 32) }; + this.signUserOp = (userOp, signer = this.signer) => + signer + .signTypedData(this.entrypointDomain, { PackedUserOperation }, userOp.packed) + .then(signature => Object.assign(userOp, { signature })); + }); + + describe('with guardians', function () { + beforeEach(async function () { + await this.other.sendTransaction({ to: this.mock.target, value: ethers.parseEther('1') }); + await this.mock.deploy(); + + this.guardians = [ + { signer: this.accounts[0], weight: 1n }, + { signer: this.accounts[1], weight: 1n }, + { signer: this.accounts[2], weight: 1n }, + ] + .map(entry => + Object.assign(entry, { + erc7913signer: ethers.solidityPacked( + ['address', 'bytes'], + [entry.signer.target ?? entry.signer.address ?? entry.signer, entry.key ?? '0x'], + ), + guardian: ethers.solidityPacked( + ['uint64', 'address', 'bytes'], + [entry.weight, entry.signer.target ?? entry.signer.address ?? entry.signer, entry.key ?? '0x'], + ), + }), + ) + .sort((g1, g2) => + comp( + ethers.toBigInt(ethers.keccak256(ethers.getBytes(g1.erc7913signer))), + ethers.toBigInt(ethers.keccak256(ethers.getBytes(g2.erc7913signer))), + ), + ); + + this.thresholds = [ + { threshold: 2n, lockPeriod: time.duration.days(7n) }, + { threshold: 3n, lockPeriod: time.duration.hours(1n) }, + ]; + }); + + it('workflow', async function () { + // install + await this.mock + .createUserOp({ + ...this.userOp, + target: this.mock.target, + callData: this.mock.interface.encodeFunctionData('installModule', [ + MODULE_TYPE_EXECUTOR, + this.socialRecovery.target, + this.socialRecovery.interface.encodeFunctionData('updateGuardians', [ + this.guardians.map(({ guardian }) => guardian), + this.thresholds, + ]), + ]), + callGas: 1_000_000n, // TODO: estimate ? + }) + .then(op => this.signUserOp(op)) + .then(op => entrypoint.v08.handleOps([op.packed], this.other)); + + // check config + await expect(this.socialRecovery.getAccountConfigs(ethers.Typed.address(this.mock))).to.eventually.deep.equal([ + this.guardians.map(({ guardian }) => guardian), + this.thresholds.map(Object.values), + ]); + + // prepare recovery + const socialRecoveryMessage = { + account: this.mock.target, + recovery: this.mock.interface.encodeFunctionData('executeFromExecutor', [ + encodeMode({ callType: CALL_TYPE_CALL }), + encodeSingle( + this.validator, + 0n, + this.validator.interface.encodeFunctionData('updateSigner', [this.other.address]), + ), + ]), + nonce: await this.socialRecovery.nonces(this.mock), + }; + + const signatures = await getDomain(this.socialRecovery).then(domain => + Promise.all( + this.guardians.map(({ signer, erc7913signer }) => + signer + .signTypedData(domain, { StartRecovery }, socialRecoveryMessage) + .then(signature => ({ signer: erc7913signer, signature })), + ), + ), + ); + + // start recovery + await expect( + this.socialRecovery.startRecovery( + ethers.Typed.address(socialRecoveryMessage.account), + socialRecoveryMessage.recovery, + signatures, + ), + ) + .to.emit(this.socialRecovery, 'RecoveryStarted') + .withArgs(this.mock, socialRecoveryMessage.recovery, anyValue); + + // wait + await time.increaseBy.timestamp(time.duration.hours(1n)); + + // signer before the recovery + await expect(this.validator.getSigner(this.mock)).to.eventually.equal(this.signer); + + // execute + await expect(this.socialRecovery.executeRecovery(ethers.Typed.address(socialRecoveryMessage.account))) + .to.emit(this.socialRecovery, 'RecoveryExecuted') + .withArgs(this.mock, socialRecoveryMessage.recovery); + + // signer after the recovery + await expect(this.validator.getSigner(this.mock)).to.eventually.equal(this.other); + }); + }); +}); diff --git a/test/helpers/eip712-types.js b/test/helpers/eip712-types.js index ddcfa303..e90dca9c 100644 --- a/test/helpers/eip712-types.js +++ b/test/helpers/eip712-types.js @@ -26,6 +26,16 @@ module.exports = mapValues( validAfter: 'uint48', validUntil: 'uint48', }, + StartRecovery: { + account: 'address', + recovery: 'bytes', + nonce: 'uint256', + }, + CancelRecovery: { + account: 'address', + recovery: 'bytes', + nonce: 'uint256', + }, }, formatType, );