diff --git a/.changeset/spotty-plums-brush.md b/.changeset/spotty-plums-brush.md new file mode 100644 index 00000000000..c43ced1f85d --- /dev/null +++ b/.changeset/spotty-plums-brush.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`ERC4626`: Allow overriding underlying assets transfer mechanisms through new internal virtual functions (`_transferIn` and `_transferOut`). diff --git a/contracts/token/ERC20/extensions/ERC4626.sol b/contracts/token/ERC20/extensions/ERC4626.sol index 201972815a0..08760a70618 100644 --- a/contracts/token/ERC20/extensions/ERC4626.sol +++ b/contracts/token/ERC20/extensions/ERC4626.sol @@ -267,7 +267,7 @@ abstract contract ERC4626 is ERC20, IERC4626 { // Conclusion: we need to do the transfer before we mint so that any reentrancy would happen before the // assets are transferred and before the shares are minted, which is a valid state. // slither-disable-next-line reentrancy-no-eth - SafeERC20.safeTransferFrom(IERC20(asset()), caller, address(this), assets); + _transferIn(caller, assets); _mint(receiver, shares); emit Deposit(caller, receiver, assets, shares); @@ -294,11 +294,21 @@ abstract contract ERC4626 is ERC20, IERC4626 { // Conclusion: we need to do the transfer after the burn so that any reentrancy would happen after the // shares are burned and after the assets are transferred, which is a valid state. _burn(owner, shares); - SafeERC20.safeTransfer(IERC20(asset()), receiver, assets); + _transferOut(receiver, assets); emit Withdraw(caller, receiver, owner, assets, shares); } + /// @dev Performs a transfer in of underlying assets. The default implementation uses `SafeERC20`. Used by {_deposit}. + function _transferIn(address from, uint256 assets) internal virtual { + SafeERC20.safeTransferFrom(IERC20(asset()), from, address(this), assets); + } + + /// @dev Performs a transfer out of underlying assets. The default implementation uses `SafeERC20`. Used by {_withdraw}. + function _transferOut(address to, uint256 assets) internal virtual { + SafeERC20.safeTransfer(IERC20(asset()), to, assets); + } + function _decimalsOffset() internal view virtual returns (uint8) { return 0; }