diff --git a/contracts/mocks/DispatchModuleMock.sol b/contracts/mocks/DispatchModuleMock.sol new file mode 100644 index 00000000..4c7f70d4 --- /dev/null +++ b/contracts/mocks/DispatchModuleMock.sol @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.20; + +contract DispatchModuleMock { + event Call(address sender, uint256 value, bytes data); + + fallback() external payable { + emit Call(msg.sender, msg.value, msg.data); + } +} diff --git a/contracts/mocks/import.sol b/contracts/mocks/import.sol index 7c5047ed..3cbe0bc6 100644 --- a/contracts/mocks/import.sol +++ b/contracts/mocks/import.sol @@ -14,5 +14,6 @@ import { } from "@openzeppelin/contracts/mocks/account/AccountMock.sol"; import {ERC1271WalletMock} from "@openzeppelin/contracts/mocks/ERC1271WalletMock.sol"; import {CallReceiverMock} from "@openzeppelin/contracts/mocks/CallReceiverMock.sol"; +import {EtherReceiverMock} from "@openzeppelin/contracts/mocks/EtherReceiverMock.sol"; import {ERC7913P256Verifier} from "@openzeppelin/contracts/utils/cryptography/verifiers/ERC7913P256Verifier.sol"; import {ERC7913RSAVerifier} from "@openzeppelin/contracts/utils/cryptography/verifiers/ERC7913RSAVerifier.sol"; diff --git a/contracts/proxy/dispatch/DispatchProxy.sol b/contracts/proxy/dispatch/DispatchProxy.sol new file mode 100644 index 00000000..543b3d6f --- /dev/null +++ b/contracts/proxy/dispatch/DispatchProxy.sol @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +import {Proxy} from "@openzeppelin/contracts/proxy/Proxy.sol"; +import {DispatchUpdateModule} from "./modules/DispatchUpdateModule.sol"; +import {Dispatch} from "./utils/Dispatch.sol"; + +/** + * @title DispatchProxy + * @dev TODO + */ +contract DispatchProxy is Proxy { + using Dispatch for Dispatch.VMT; + + bytes4 private constant _FALLBACK_SIG = 0xffffffff; + + error DispatchProxyMissingImplementation(bytes4 selector); + + constructor(address updateFacet, address initialOwner) { + Dispatch.VMT storage store = Dispatch.instance(); + store.setOwner(initialOwner); + store.setFunction(DispatchUpdateModule.updateDispatchTable.selector, updateFacet); + } + + function _implementation() internal view virtual override returns (address module) { + Dispatch.VMT storage store = Dispatch.instance(); + + module = store.getFunction(msg.sig); + if (module != address(0)) return module; + + module = store.getFunction(_FALLBACK_SIG); + if (module != address(0)) return module; + + revert DispatchProxyMissingImplementation(msg.sig); + } +} diff --git a/contracts/proxy/dispatch/interfaces/IDiamondCut.sol b/contracts/proxy/dispatch/interfaces/IDiamondCut.sol new file mode 100644 index 00000000..a8a9cfbc --- /dev/null +++ b/contracts/proxy/dispatch/interfaces/IDiamondCut.sol @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +interface IDiamondCut { + enum FacetCutAction { + Add, + Replace, + Remove + } + + struct FacetCut { + address facetAddress; + FacetCutAction action; + bytes4[] functionSelectors; + } + + event DiamondCut(FacetCut[] _diamondCut, address _init, bytes _calldata); + + /// @notice Add/replace/remove any number of functions and optionally execute + /// a function with delegatecall + /// @param _diamondCut Contains the facet addresses and function selectors + /// @param _init The address of the contract or facet to execute _calldata + /// @param _calldata A function call, including function selector and arguments + /// _calldata is executed with delegatecall on _init + function diamondCut(FacetCut[] calldata _diamondCut, address _init, bytes calldata _calldata) external; +} diff --git a/contracts/proxy/dispatch/interfaces/IDiamondLoupe.sol b/contracts/proxy/dispatch/interfaces/IDiamondLoupe.sol new file mode 100644 index 00000000..1ca90b4a --- /dev/null +++ b/contracts/proxy/dispatch/interfaces/IDiamondLoupe.sol @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +interface IDiamondLoupe { + struct Facet { + address facetAddress; + bytes4[] functionSelectors; + } + + /// @notice Gets all facet addresses and their four byte function selectors. + /// @return facets_ Facet + function facets() external view returns (Facet[] memory facets_); + + /// @notice Gets all the function selectors supported by a specific facet. + /// @param _facet The facet address. + /// @return facetFunctionSelectors_ + function facetFunctionSelectors(address _facet) external view returns (bytes4[] memory facetFunctionSelectors_); + + /// @notice Get all the facet addresses used by a diamond. + /// @return facetAddresses_ + function facetAddresses() external view returns (address[] memory facetAddresses_); + + /// @notice Gets the facet that supports the given selector. + /// @dev If facet is not found return address(0). + /// @param _functionSelector The function selector. + /// @return facetAddress_ The facet address. + function facetAddress(bytes4 _functionSelector) external view returns (address facetAddress_); +} diff --git a/contracts/proxy/dispatch/modules/DiamondCutFacet.sol b/contracts/proxy/dispatch/modules/DiamondCutFacet.sol new file mode 100644 index 00000000..ba2958b7 --- /dev/null +++ b/contracts/proxy/dispatch/modules/DiamondCutFacet.sol @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +import {Address} from "@openzeppelin/contracts/utils/Address.sol"; +import {Context} from "@openzeppelin/contracts/utils/Context.sol"; +import {IDiamondCut} from "../interfaces/IDiamondCut.sol"; +import {Dispatch} from "../utils/Dispatch.sol"; + +/// @custom:stateless +contract DiamondCutFacet is Context, IDiamondCut { + using Dispatch for Dispatch.VMT; + + error DiamondCutFacetAlreadyExist(bytes4 selector); + error DiamondCutFacetAlreadySet(bytes4 selector); + error DiamondCutFacetAlreadyDoesNotExit(bytes4 selector); + + function diamondCut(FacetCut[] calldata _diamondCut, address _init, bytes calldata _calldata) public override { + Dispatch.VMT storage store = Dispatch.instance(); + + store.enforceOwner(_msgSender()); + for (uint256 i = 0; i < _diamondCut.length; ++i) { + FacetCut memory facetcut = _diamondCut[i]; + for (uint256 j = 0; j < facetcut.functionSelectors.length; ++j) { + bytes4 selector = facetcut.functionSelectors[j]; + address currentFacet = store.getFunction(selector); + if (facetcut.action == FacetCutAction.Add && currentFacet != address(0)) { + revert DiamondCutFacetAlreadyExist(selector); + } else if (facetcut.action == FacetCutAction.Replace && currentFacet != facetcut.facetAddress) { + revert DiamondCutFacetAlreadySet(selector); + } else if (facetcut.action == FacetCutAction.Remove && currentFacet == address(0)) { + revert DiamondCutFacetAlreadyDoesNotExit(selector); + } + store.setFunction(selector, facetcut.facetAddress); + } + } + + emit DiamondCut(_diamondCut, _init, _calldata); + + Address.functionCall(_init, _calldata); + } +} diff --git a/contracts/proxy/dispatch/modules/DiamondLoupeFacet.sol b/contracts/proxy/dispatch/modules/DiamondLoupeFacet.sol new file mode 100644 index 00000000..91ec0f29 --- /dev/null +++ b/contracts/proxy/dispatch/modules/DiamondLoupeFacet.sol @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +import {Context} from "@openzeppelin/contracts/utils/Context.sol"; +import {IDiamondLoupe} from "../interfaces/IDiamondLoupe.sol"; +import {Dispatch} from "../utils/Dispatch.sol"; + +/// @custom:stateless +contract DiamondLoupeFacet is Context, IDiamondLoupe { + using Dispatch for Dispatch.VMT; + + function facets() public view override returns (Facet[] memory) { + this; + revert("This implementation doesnt keep an index, use an offchain index instead"); + } + + function facetFunctionSelectors(address _facet) public view override returns (bytes4[] memory) { + this; + _facet; + revert("This implementation doesnt keep an index, use an offchain index instead"); + } + + function facetAddresses() public view override returns (address[] memory) { + this; + revert("This implementation doesnt keep an index, use an offchain index instead"); + } + + function facetAddress(bytes4 _functionSelector) public view override returns (address) { + return Dispatch.instance().getFunction(_functionSelector); + } +} diff --git a/contracts/proxy/dispatch/modules/DispatchOwnershipModule.sol b/contracts/proxy/dispatch/modules/DispatchOwnershipModule.sol new file mode 100644 index 00000000..b515ef57 --- /dev/null +++ b/contracts/proxy/dispatch/modules/DispatchOwnershipModule.sol @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol"; +import {Context} from "@openzeppelin/contracts/utils/Context.sol"; +import {Dispatch} from "../utils/Dispatch.sol"; + +/// @custom:stateless +contract DispatchOwnershipModule is Context { + using Dispatch for Dispatch.VMT; + + /** + * @dev Throws if called by any account other than the owner. + */ + modifier onlyOwner() { + Dispatch.instance().enforceOwner(_msgSender()); + _; + } + + /** + * @dev Reads ownership for the vtable + */ + function owner() public view virtual returns (address) { + return Dispatch.instance().getOwner(); + } + + /** + * @dev Leaves the contract without owner. It will not be possible to call + * `onlyOwner` functions anymore. Can only be called by the current owner. + * + * NOTE: Renouncing ownership will leave the contract without an owner, + * thereby removing any functionality that is only available to the owner. + */ + function renounceOwnership() public virtual onlyOwner { + Dispatch.instance().setOwner(address(0)); + } + + /** + * @dev Transfers ownership of the contract to a new account (`newOwner`). + * Can only be called by the current owner. + */ + function transferOwnership(address newOwner) public virtual onlyOwner { + require(newOwner != address(0), Ownable.OwnableInvalidOwner(newOwner)); + Dispatch.instance().setOwner(newOwner); + } +} diff --git a/contracts/proxy/dispatch/modules/DispatchUpdateModule.sol b/contracts/proxy/dispatch/modules/DispatchUpdateModule.sol new file mode 100644 index 00000000..c1bbf575 --- /dev/null +++ b/contracts/proxy/dispatch/modules/DispatchUpdateModule.sol @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +import {Context} from "@openzeppelin/contracts/utils/Context.sol"; +import {Dispatch} from "../utils/Dispatch.sol"; + +/// @custom:stateless +contract DispatchUpdateModule is Context { + using Dispatch for Dispatch.VMT; + + struct ModuleDefinition { + address implementation; + bytes4[] selectors; + } + + /** + * @dev Updates the vtable + */ + function updateDispatchTable(ModuleDefinition[] calldata modules) public { + Dispatch.VMT storage store = Dispatch.instance(); + + store.enforceOwner(_msgSender()); + for (uint256 i = 0; i < modules.length; ++i) { + ModuleDefinition memory module = modules[i]; + for (uint256 j = 0; j < module.selectors.length; ++j) { + store.setFunction(module.selectors[j], module.implementation); + } + } + } +} diff --git a/contracts/proxy/dispatch/utils/Dispatch.sol b/contracts/proxy/dispatch/utils/Dispatch.sol new file mode 100644 index 00000000..409db6bb --- /dev/null +++ b/contracts/proxy/dispatch/utils/Dispatch.sol @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol"; + +/** + * @title Dispatch + * @dev TODO + */ +library Dispatch { + // keccak256(abi.encode(uint256(keccak256("openzeppelin.storage.Dispatch.VMT")) - 1)) & ~bytes32(uint256(0xff)) + bytes32 private constant _DISPATCH_VMT_SLOT = 0xe6b1591f932b472559c00c679d5b3da28bf0ed2fd643b2ef77392cbec1743c00; + + struct VMT { + address _owner; + mapping(bytes4 => address) _vtable; + } + + /** + * @dev Get singleton instance + */ + function instance() internal pure returns (VMT storage store) { + bytes32 position = _DISPATCH_VMT_SLOT; + assembly { + store.slot := position + } + } + + /** + * @dev Ownership management + */ + function getOwner(VMT storage store) internal view returns (address) { + return store._owner; + } + + function setOwner(VMT storage store, address newOwner) internal { + emit Ownable.OwnershipTransferred(store._owner, newOwner); + store._owner = newOwner; + } + + function enforceOwner(VMT storage store, address account) internal view { + require(getOwner(store) == account, Ownable.OwnableUnauthorizedAccount(account)); + } + + /** + * @dev Delegation management + */ + event VMTUpdate(bytes4 indexed selector, address oldImplementation, address newImplementation); + + function getFunction(VMT storage store, bytes4 selector) internal view returns (address) { + return store._vtable[selector]; + } + + function setFunction(VMT storage store, bytes4 selector, address module) internal { + emit VMTUpdate(selector, store._vtable[selector], module); + store._vtable[selector] = module; + } +} diff --git a/test/proxy/DispatchProxy.test.js b/test/proxy/DispatchProxy.test.js new file mode 100644 index 00000000..e309cc86 --- /dev/null +++ b/test/proxy/DispatchProxy.test.js @@ -0,0 +1,154 @@ +const { ethers } = require('hardhat'); +const { expect } = require('chai'); +const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); + +const getSelectors = iface => + [] + .concat( + iface.fragments.filter(({ type }) => type == 'function').map(({ selector }) => selector), + iface.receive && '0x00000000', + iface.fallback && '0xFFFFFFFF', + ) + .filter(Boolean); + +async function fixture() { + const [admin, other] = await ethers.getSigners(); + + const modules = { + diamondCut: await ethers.deployContract('DiamondCutFacet'), + diamondLoupe: await ethers.deployContract('DiamondLoupeFacet'), + ownership: await ethers.deployContract('DispatchOwnershipModule'), + update: await ethers.deployContract('DispatchUpdateModule'), + mock: await ethers.deployContract('DispatchModuleMock'), + }; + const proxy = await ethers.deployContract('DispatchProxy', [modules.update, admin]); + const proxyAsUpdate = modules.update.attach(proxy.target); + + return { admin, other, modules, proxy, proxyAsUpdate }; +} + +describe('DispatchProxy', async function () { + beforeEach('deploying', async function () { + Object.assign(this, await loadFixture(fixture)); + }); + + it('missing implementation', async function () { + await expect(this.other.sendTransaction({ to: this.proxy })).to.be.revertedWithCustomError( + this.proxy, + 'DispatchProxyMissingImplementation', + ); + }); + + describe('dispatch table update', function () { + it('authorized', async function () { + const modules = [this.modules.diamondCut, this.modules.diamondLoupe, this.modules.ownership]; + + const tx = this.proxyAsUpdate + .connect(this.admin) + .updateDispatchTable(modules.map(module => [module, getSelectors(module.interface)])); + for (const module of modules) { + for (const selector of getSelectors(module.interface)) { + await expect(tx).to.emit(this.proxyAsUpdate, 'VMTUpdate').withArgs(selector, ethers.ZeroAddress, module); + } + } + }); + + it('unauthorized', async function () { + const modules = [this.modules.diamondCut, this.modules.diamondLoupe, this.modules.ownership]; + + await expect( + this.proxyAsUpdate + .connect(this.other) + .updateDispatchTable(modules.map(module => [module, getSelectors(module.interface)])), + ) + .to.be.revertedWithCustomError(this.proxyAsUpdate, 'OwnableUnauthorizedAccount') + .withArgs(this.other); + }); + + it('empty update', async function () { + const tx = await this.proxyAsUpdate.connect(this.admin).updateDispatchTable([]); + const receipt = await tx.wait(); + + expect(receipt.logs.length).to.be.equal(0); + }); + + it('receive', async function () { + const receiver = await ethers.deployContract('EtherReceiverMock'); + await this.proxyAsUpdate.connect(this.admin).updateDispatchTable([[receiver, getSelectors(receiver.interface)]]); + + // does not accept eth + await expect(this.other.sendTransaction({ to: this.proxy, value: 1n })).to.be.revertedWithoutReason(); + + // set accept eth + await receiver.attach(this.proxy.target).setAcceptEther(true); + + // accept eth + await expect(this.other.sendTransaction({ to: this.proxy, value: 1n })).to.not.be.reverted; + }); + + it('fallback', async function () { + const receiver = await ethers.deployContract('EtherReceiverMock'); + await this.proxyAsUpdate.connect(this.admin).updateDispatchTable([[receiver, ['0xffffffff']]]); + + // does not accept eth + await expect(this.other.sendTransaction({ to: this.proxy, value: 1n })).to.be.revertedWithoutReason(); + + // set accept eth + await receiver.attach(this.proxy.target).setAcceptEther(true); + + // accept eth + await expect(this.other.sendTransaction({ to: this.proxy, value: 1n })).to.not.be.reverted; + }); + }); + + describe('with ownership module', function () { + beforeEach(async function () { + await this.proxyAsUpdate + .connect(this.admin) + .updateDispatchTable([[this.modules.ownership, getSelectors(this.modules.ownership.interface)]]); + this.proxyAsOwnership = this.modules.ownership.attach(this.proxy.target); + }); + + it('has an owner', async function () { + await expect(this.proxyAsOwnership.owner()).to.eventually.equal(this.admin); + }); + + describe('transfer ownership', function () { + it('changes owner after transfer', async function () { + await expect(this.proxyAsOwnership.connect(this.admin).transferOwnership(this.other)) + .to.emit(this.proxyAsOwnership, 'OwnershipTransferred') + .withArgs(this.admin, this.other); + + await expect(this.proxyAsOwnership.owner()).to.eventually.equal(this.other); + }); + + it('prevents non-owners from transferring', async function () { + await expect(this.proxyAsOwnership.connect(this.other).transferOwnership(this.other)) + .to.be.revertedWithCustomError(this.proxyAsOwnership, 'OwnableUnauthorizedAccount') + .withArgs(this.other); + }); + + it('guards ownership against stuck state', async function () { + await expect(this.proxyAsOwnership.connect(this.admin).transferOwnership(ethers.ZeroAddress)) + .to.be.revertedWithCustomError(this.proxyAsOwnership, 'OwnableInvalidOwner') + .withArgs(ethers.ZeroAddress); + }); + }); + + describe('renounce ownership', function () { + it('loses owner after renouncement', async function () { + await expect(this.proxyAsOwnership.connect(this.admin).renounceOwnership()) + .to.emit(this.proxyAsOwnership, 'OwnershipTransferred') + .withArgs(this.admin, ethers.ZeroAddress); + + await expect(this.proxyAsOwnership.owner()).to.eventually.equal(ethers.ZeroAddress); + }); + + it('prevents non-owners from renouncement', async function () { + await expect(this.proxyAsOwnership.connect(this.other).renounceOwnership()) + .to.be.revertedWithCustomError(this.proxyAsOwnership, 'OwnableUnauthorizedAccount') + .withArgs(this.other); + }); + }); + }); +});