Skip to content

Commit 1c8d4b4

Browse files
Sabnock01mds1
andauthored
feat: Refactor assumeNot* std cheats (#407)
* fix: exclude vm address from assumeNoPrecompiles * forge fmt * oops * update assumePayableIsNot * forge fmt * forge fmt v2 * remove exposed_assumePayable test * modify _isPayable to be compiler agnostic * compiler agnostic v2 * warnings * minor fixes * minor fixes v2 * fix: make suggested changes * forge fmt * fix: correct use of console address * fix: remove extcodesize check * chore: tweak function arg order and other small cleanup --------- Co-authored-by: Matt Solomon <[email protected]>
1 parent adec12d commit 1c8d4b4

File tree

2 files changed

+156
-24
lines changed

2 files changed

+156
-24
lines changed

src/StdCheats.sol

Lines changed: 99 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import {Vm} from "./Vm.sol";
99
abstract contract StdCheatsSafe {
1010
Vm private constant vm = Vm(address(uint160(uint256(keccak256("hevm cheat code")))));
1111

12+
uint256 private constant UINT256_MAX =
13+
115792089237316195423570985008687907853269984665640564039457584007913129639935;
14+
1215
bool private gasMeteringOff;
1316

1417
// Data structures to parse Transaction objects from the broadcast artifact
@@ -193,6 +196,14 @@ abstract contract StdCheatsSafe {
193196
uint256 key;
194197
}
195198

199+
enum AddressType {
200+
Payable,
201+
NonPayable,
202+
ZeroAddress,
203+
Precompile,
204+
ForgeAddress
205+
}
206+
196207
// Checks that `addr` is not blacklisted by token contracts that have a blacklist.
197208
function assumeNotBlacklisted(address token, address addr) internal view virtual {
198209
// Nothing to check if `token` is not a contract.
@@ -222,11 +233,91 @@ abstract contract StdCheatsSafe {
222233
assumeNotBlacklisted(token, addr);
223234
}
224235

225-
function assumeNoPrecompiles(address addr) internal pure virtual {
226-
assumeNoPrecompiles(addr, _pureChainId());
236+
function assumeAddressIsNot(address addr, AddressType addressType) internal virtual {
237+
if (addressType == AddressType.Payable) {
238+
assumeNotPayable(addr);
239+
} else if (addressType == AddressType.NonPayable) {
240+
assumePayable(addr);
241+
} else if (addressType == AddressType.ZeroAddress) {
242+
assumeNotZeroAddress(addr);
243+
} else if (addressType == AddressType.Precompile) {
244+
assumeNotPrecompile(addr);
245+
} else if (addressType == AddressType.ForgeAddress) {
246+
assumeNotForgeAddress(addr);
247+
}
248+
}
249+
250+
function assumeAddressIsNot(address addr, AddressType addressType1, AddressType addressType2) internal virtual {
251+
assumeAddressIsNot(addr, addressType1);
252+
assumeAddressIsNot(addr, addressType2);
253+
}
254+
255+
function assumeAddressIsNot(
256+
address addr,
257+
AddressType addressType1,
258+
AddressType addressType2,
259+
AddressType addressType3
260+
) internal virtual {
261+
assumeAddressIsNot(addr, addressType1);
262+
assumeAddressIsNot(addr, addressType2);
263+
assumeAddressIsNot(addr, addressType3);
264+
}
265+
266+
function assumeAddressIsNot(
267+
address addr,
268+
AddressType addressType1,
269+
AddressType addressType2,
270+
AddressType addressType3,
271+
AddressType addressType4
272+
) internal virtual {
273+
assumeAddressIsNot(addr, addressType1);
274+
assumeAddressIsNot(addr, addressType2);
275+
assumeAddressIsNot(addr, addressType3);
276+
assumeAddressIsNot(addr, addressType4);
277+
}
278+
279+
// This function checks whether an address, `addr`, is payable. It works by sending 1 wei to
280+
// `addr` and checking the `success` return value.
281+
// NOTE: This function may result in state changes depending on the fallback/receive logic
282+
// implemented by `addr`, which should be taken into account when this function is used.
283+
function _isPayable(address addr) private returns (bool) {
284+
require(
285+
addr.balance < UINT256_MAX,
286+
"StdCheats _isPayable(address): Balance equals max uint256, so it cannot receive any more funds"
287+
);
288+
uint256 origBalanceTest = address(this).balance;
289+
uint256 origBalanceAddr = address(addr).balance;
290+
291+
vm.deal(address(this), 1);
292+
(bool success,) = payable(addr).call{value: 1}("");
293+
294+
// reset balances
295+
vm.deal(address(this), origBalanceTest);
296+
vm.deal(addr, origBalanceAddr);
297+
298+
return success;
299+
}
300+
301+
// NOTE: This function may result in state changes depending on the fallback/receive logic
302+
// implemented by `addr`, which should be taken into account when this function is used. See the
303+
// `_isPayable` method for more information.
304+
function assumePayable(address addr) internal virtual {
305+
vm.assume(_isPayable(addr));
306+
}
307+
308+
function assumeNotPayable(address addr) internal virtual {
309+
vm.assume(!_isPayable(addr));
227310
}
228311

229-
function assumeNoPrecompiles(address addr, uint256 chainId) internal pure virtual {
312+
function assumeNotZeroAddress(address addr) internal pure virtual {
313+
vm.assume(addr != address(0));
314+
}
315+
316+
function assumeNotPrecompile(address addr) internal pure virtual {
317+
assumeNotPrecompile(addr, _pureChainId());
318+
}
319+
320+
function assumeNotPrecompile(address addr, uint256 chainId) internal pure virtual {
230321
// Note: For some chains like Optimism these are technically predeploys (i.e. bytecode placed at a specific
231322
// address), but the same rationale for excluding them applies so we include those too.
232323

@@ -249,6 +340,11 @@ abstract contract StdCheatsSafe {
249340
// forgefmt: disable-end
250341
}
251342

343+
function assumeNotForgeAddress(address addr) internal pure virtual {
344+
// vm and console addresses
345+
vm.assume(addr != address(vm) || addr != 0x000000000000000000636F6e736F6c652e6c6f67);
346+
}
347+
252348
function readEIP1559ScriptArtifact(string memory path)
253349
internal
254350
view
@@ -512,13 +608,6 @@ abstract contract StdCheatsSafe {
512608
}
513609
}
514610

515-
// a cheat for fuzzing addresses that are payable only
516-
// see https://github.com/foundry-rs/foundry/issues/3631
517-
function assumePayable(address addr) internal virtual {
518-
(bool success,) = payable(addr).call{value: 0}("");
519-
vm.assume(success);
520-
}
521-
522611
// We use this complex approach of `_viewChainId` and `_pureChainId` to ensure there are no
523612
// compiler warnings when accessing chain ID in any solidity version supported by forge-std. We
524613
// can't simply access the chain ID in a normal view or pure function because the solc View Pure

test/StdCheats.t.sol

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -335,21 +335,20 @@ contract StdCheatsTest is Test {
335335
return number;
336336
}
337337

338-
function testAssumeNoPrecompiles(address addr) external {
339-
assumeNoPrecompiles(addr, getChain("optimism_goerli").chainId);
340-
assertTrue(
341-
addr < address(1) || (addr > address(9) && addr < address(0x4200000000000000000000000000000000000000))
342-
|| addr > address(0x4200000000000000000000000000000000000800)
343-
);
344-
}
345-
346-
function _assumePayable(address addr) public {
347-
assumePayable(addr);
338+
function testAssumeAddressIsNot(address addr) external {
339+
// skip over Payable and NonPayable enums
340+
for (uint8 i = 2; i < uint8(type(AddressType).max); i++) {
341+
assumeAddressIsNot(addr, AddressType(i));
342+
}
343+
assertTrue(addr != address(0));
344+
assertTrue(addr < address(1) || addr > address(9));
345+
assertTrue(addr != address(vm) || addr != 0x000000000000000000636F6e736F6c652e6c6f67);
348346
}
349347

350348
function testAssumePayable() external {
351349
// We deploy a mock version so we can properly test the revert.
352350
StdCheatsMock stdCheatsMock = new StdCheatsMock();
351+
353352
// all should revert since these addresses are not payable
354353

355354
// VM address
@@ -363,13 +362,49 @@ contract StdCheatsTest is Test {
363362
// Create2Deployer
364363
vm.expectRevert();
365364
stdCheatsMock.exposed_assumePayable(0x4e59b44847b379578588920cA78FbF26c0B4956C);
365+
366+
// all should pass since these addresses are payable
367+
368+
// vitalik.eth
369+
stdCheatsMock.exposed_assumePayable(0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045);
370+
371+
// mock payable contract
372+
MockContractPayable cp = new MockContractPayable();
373+
stdCheatsMock.exposed_assumePayable(address(cp));
366374
}
367375

368-
function testAssumePayable(address addr) external {
369-
assumePayable(addr);
376+
function testAssumeNotPayable() external {
377+
// We deploy a mock version so we can properly test the revert.
378+
StdCheatsMock stdCheatsMock = new StdCheatsMock();
379+
380+
// all should pass since these addresses are not payable
381+
382+
// VM address
383+
stdCheatsMock.exposed_assumeNotPayable(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D);
384+
385+
// Console address
386+
stdCheatsMock.exposed_assumeNotPayable(0x000000000000000000636F6e736F6c652e6c6f67);
387+
388+
// Create2Deployer
389+
stdCheatsMock.exposed_assumeNotPayable(0x4e59b44847b379578588920cA78FbF26c0B4956C);
390+
391+
// all should revert since these addresses are payable
392+
393+
// vitalik.eth
394+
vm.expectRevert();
395+
stdCheatsMock.exposed_assumeNotPayable(0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045);
396+
397+
// mock payable contract
398+
MockContractPayable cp = new MockContractPayable();
399+
vm.expectRevert();
400+
stdCheatsMock.exposed_assumeNotPayable(address(cp));
401+
}
402+
403+
function testAssumeNotPrecompile(address addr) external {
404+
assumeNotPrecompile(addr, getChain("optimism_goerli").chainId);
370405
assertTrue(
371-
addr != 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D && addr != 0x000000000000000000636F6e736F6c652e6c6f67
372-
&& addr != 0x4e59b44847b379578588920cA78FbF26c0B4956C
406+
addr < address(1) || (addr > address(9) && addr < address(0x4200000000000000000000000000000000000000))
407+
|| addr > address(0x4200000000000000000000000000000000000800)
373408
);
374409
}
375410

@@ -406,6 +441,10 @@ contract StdCheatsMock is StdCheats {
406441
assumePayable(addr);
407442
}
408443

444+
function exposed_assumeNotPayable(address addr) external {
445+
assumeNotPayable(addr);
446+
}
447+
409448
// We deploy a mock version so we can properly test expected reverts.
410449
function exposed_assumeNotBlacklisted(address token, address addr) external view {
411450
return assumeNotBlacklisted(token, addr);
@@ -557,3 +596,7 @@ contract MockContractWithConstructorArgs {
557596
z = _z;
558597
}
559598
}
599+
600+
contract MockContractPayable {
601+
receive() external payable {}
602+
}

0 commit comments

Comments
 (0)