diff --git a/.changeset/quick-pianos-press.md b/.changeset/quick-pianos-press.md new file mode 100644 index 00000000000..e9eae71946c --- /dev/null +++ b/.changeset/quick-pianos-press.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`ReentrancyGuard` and `ReentrancyGuardTransient`: Add `nonReentrantView`, a read-only version of the `nonReentrant` modifier. diff --git a/contracts/mocks/ReentrancyAttack.sol b/contracts/mocks/ReentrancyAttack.sol index 3df2d1c2b23..fec1478153e 100644 --- a/contracts/mocks/ReentrancyAttack.sol +++ b/contracts/mocks/ReentrancyAttack.sol @@ -9,4 +9,9 @@ contract ReentrancyAttack is Context { (bool success, ) = _msgSender().call(data); require(success, "ReentrancyAttack: failed call"); } + + function staticcallSender(bytes calldata data) public view { + (bool success, ) = _msgSender().staticcall(data); + require(success, "ReentrancyAttack: failed call"); + } } diff --git a/contracts/mocks/ReentrancyMock.sol b/contracts/mocks/ReentrancyMock.sol index 39e2d5ed850..34971812920 100644 --- a/contracts/mocks/ReentrancyMock.sol +++ b/contracts/mocks/ReentrancyMock.sol @@ -16,6 +16,10 @@ contract ReentrancyMock is ReentrancyGuard { _count(); } + function viewCallback() external view nonReentrantView returns (uint256) { + return counter; + } + function countLocalRecursive(uint256 n) public nonReentrant { if (n > 0) { _count(); @@ -36,6 +40,11 @@ contract ReentrancyMock is ReentrancyGuard { attacker.callSender(abi.encodeCall(this.callback, ())); } + function countAndCallView(ReentrancyAttack attacker) public nonReentrant { + _count(); + attacker.staticcallSender(abi.encodeCall(this.viewCallback, ())); + } + function _count() private { counter += 1; } diff --git a/contracts/mocks/ReentrancyTransientMock.sol b/contracts/mocks/ReentrancyTransientMock.sol index f0e61ea8caa..436b245109d 100644 --- a/contracts/mocks/ReentrancyTransientMock.sol +++ b/contracts/mocks/ReentrancyTransientMock.sol @@ -16,6 +16,10 @@ contract ReentrancyTransientMock is ReentrancyGuardTransient { _count(); } + function viewCallback() external view nonReentrantView returns (uint256) { + return counter; + } + function countLocalRecursive(uint256 n) public nonReentrant { if (n > 0) { _count(); @@ -36,6 +40,11 @@ contract ReentrancyTransientMock is ReentrancyGuardTransient { attacker.callSender(abi.encodeCall(this.callback, ())); } + function countAndCallView(ReentrancyAttack attacker) public nonReentrant { + _count(); + attacker.staticcallSender(abi.encodeCall(this.viewCallback, ())); + } + function _count() private { counter += 1; } diff --git a/contracts/utils/ReentrancyGuard.sol b/contracts/utils/ReentrancyGuard.sol index a95fb512f31..854c48bb944 100644 --- a/contracts/utils/ReentrancyGuard.sol +++ b/contracts/utils/ReentrancyGuard.sol @@ -61,6 +61,25 @@ abstract contract ReentrancyGuard { _nonReentrantAfter(); } + /** + * @dev A `view` only version of {nonReentrant}. Use to block view functions + * from being called, preventing reading from inconsistent contract state. + * + * CAUTION: This is a "view" modifier and does not change the reentrancy + * status. Use it only on view functions. For payable or non-payable functions, + * use the standard {nonReentrant} modifier instead. + */ + modifier nonReentrantView() { + _nonReentrantBeforeView(); + _; + } + + function _nonReentrantBeforeView() private view { + if (_status == ENTERED) { + revert ReentrancyGuardReentrantCall(); + } + } + function _nonReentrantBefore() private { // On the first call to nonReentrant, _status will be NOT_ENTERED if (_status == ENTERED) { diff --git a/contracts/utils/ReentrancyGuardTransient.sol b/contracts/utils/ReentrancyGuardTransient.sol index a1318c86f3c..701d587f9b8 100644 --- a/contracts/utils/ReentrancyGuardTransient.sol +++ b/contracts/utils/ReentrancyGuardTransient.sol @@ -37,6 +37,25 @@ abstract contract ReentrancyGuardTransient { _nonReentrantAfter(); } + /** + * @dev A `view` only version of {nonReentrant}. Use to block view functions + * from being called, preventing reading from inconsistent contract state. + * + * CAUTION: This is a "view" modifier and does not change the reentrancy + * status. Use it only on view functions. For payable or non-payable functions, + * use the standard {nonReentrant} modifier instead. + */ + modifier nonReentrantView() { + _nonReentrantBeforeView(); + _; + } + + function _nonReentrantBeforeView() private view { + if (_reentrancyGuardEntered()) { + revert ReentrancyGuardReentrantCall(); + } + } + function _nonReentrantBefore() private { // On the first call to nonReentrant, REENTRANCY_GUARD_STORAGE.asBoolean().tload() will be false if (_reentrancyGuardEntered()) { diff --git a/test/utils/ReentrancyGuard.test.js b/test/utils/ReentrancyGuard.test.js index c4418563eb5..4a157864998 100644 --- a/test/utils/ReentrancyGuard.test.js +++ b/test/utils/ReentrancyGuard.test.js @@ -7,7 +7,8 @@ for (const variant of ['', 'Transient']) { async function fixture() { const name = `Reentrancy${variant}Mock`; const mock = await ethers.deployContract(name); - return { name, mock }; + const attacker = await ethers.deployContract('ReentrancyAttack'); + return { name, mock, attacker }; } beforeEach(async function () { @@ -20,9 +21,16 @@ for (const variant of ['', 'Transient']) { expect(await this.mock.counter()).to.equal(1n); }); - it('does not allow remote callback', async function () { - const attacker = await ethers.deployContract('ReentrancyAttack'); - await expect(this.mock.countAndCall(attacker)).to.be.revertedWith('ReentrancyAttack: failed call'); + it('nonReentrantView function can be called', async function () { + await this.mock.viewCallback(); + }); + + it('does not allow remote callback to nonReentrant function', async function () { + await expect(this.mock.countAndCall(this.attacker)).to.be.revertedWith('ReentrancyAttack: failed call'); + }); + + it('does not allow remote callback to nonReentrantView function', async function () { + await expect(this.mock.countAndCallView(this.attacker)).to.be.revertedWith('ReentrancyAttack: failed call'); }); it('_reentrancyGuardEntered should be true when guarded', async function () {