Skip to content

Commit 17a8955

Browse files
CodeSandwichAmxxernestognw
authored
Optimize Math operations using branchless bool to uint translation. (#4878)
Co-authored-by: Hadrien Croubois <[email protected]> Co-authored-by: ernestognw <[email protected]>
1 parent 0a757ec commit 17a8955

File tree

5 files changed

+94
-62
lines changed

5 files changed

+94
-62
lines changed

.changeset/nervous-pans-grow.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'openzeppelin-solidity': patch
3+
---
4+
5+
`SafeCast`: Add `toUint(bool)` for operating on `bool` values as `uint256`.

contracts/utils/math/Math.sol

Lines changed: 54 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pragma solidity ^0.8.20;
55

66
import {Address} from "../Address.sol";
77
import {Panic} from "../Panic.sol";
8+
import {SafeCast} from "./SafeCast.sol";
89

910
/**
1011
* @dev Standard math utilities missing in the Solidity language.
@@ -210,11 +211,7 @@ library Math {
210211
* @dev Calculates x * y / denominator with full precision, following the selected rounding direction.
211212
*/
212213
function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) {
213-
uint256 result = mulDiv(x, y, denominator);
214-
if (unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0) {
215-
result += 1;
216-
}
217-
return result;
214+
return mulDiv(x, y, denominator) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0);
218215
}
219216

220217
/**
@@ -383,7 +380,7 @@ library Math {
383380
function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
384381
unchecked {
385382
uint256 result = sqrt(a);
386-
return result + (unsignedRoundsUp(rounding) && result * result < a ? 1 : 0);
383+
return result + SafeCast.toUint(unsignedRoundsUp(rounding) && result * result < a);
387384
}
388385
}
389386

@@ -393,38 +390,37 @@ library Math {
393390
*/
394391
function log2(uint256 value) internal pure returns (uint256) {
395392
uint256 result = 0;
393+
uint256 exp;
396394
unchecked {
397-
if (value >> 128 > 0) {
398-
value >>= 128;
399-
result += 128;
400-
}
401-
if (value >> 64 > 0) {
402-
value >>= 64;
403-
result += 64;
404-
}
405-
if (value >> 32 > 0) {
406-
value >>= 32;
407-
result += 32;
408-
}
409-
if (value >> 16 > 0) {
410-
value >>= 16;
411-
result += 16;
412-
}
413-
if (value >> 8 > 0) {
414-
value >>= 8;
415-
result += 8;
416-
}
417-
if (value >> 4 > 0) {
418-
value >>= 4;
419-
result += 4;
420-
}
421-
if (value >> 2 > 0) {
422-
value >>= 2;
423-
result += 2;
424-
}
425-
if (value >> 1 > 0) {
426-
result += 1;
427-
}
395+
exp = 128 * SafeCast.toUint(value > (1 << 128) - 1);
396+
value >>= exp;
397+
result += exp;
398+
399+
exp = 64 * SafeCast.toUint(value > (1 << 64) - 1);
400+
value >>= exp;
401+
result += exp;
402+
403+
exp = 32 * SafeCast.toUint(value > (1 << 32) - 1);
404+
value >>= exp;
405+
result += exp;
406+
407+
exp = 16 * SafeCast.toUint(value > (1 << 16) - 1);
408+
value >>= exp;
409+
result += exp;
410+
411+
exp = 8 * SafeCast.toUint(value > (1 << 8) - 1);
412+
value >>= exp;
413+
result += exp;
414+
415+
exp = 4 * SafeCast.toUint(value > (1 << 4) - 1);
416+
value >>= exp;
417+
result += exp;
418+
419+
exp = 2 * SafeCast.toUint(value > (1 << 2) - 1);
420+
value >>= exp;
421+
result += exp;
422+
423+
result += SafeCast.toUint(value > 1);
428424
}
429425
return result;
430426
}
@@ -436,7 +432,7 @@ library Math {
436432
function log2(uint256 value, Rounding rounding) internal pure returns (uint256) {
437433
unchecked {
438434
uint256 result = log2(value);
439-
return result + (unsignedRoundsUp(rounding) && 1 << result < value ? 1 : 0);
435+
return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << result < value);
440436
}
441437
}
442438

@@ -485,7 +481,7 @@ library Math {
485481
function log10(uint256 value, Rounding rounding) internal pure returns (uint256) {
486482
unchecked {
487483
uint256 result = log10(value);
488-
return result + (unsignedRoundsUp(rounding) && 10 ** result < value ? 1 : 0);
484+
return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 10 ** result < value);
489485
}
490486
}
491487

@@ -497,26 +493,25 @@ library Math {
497493
*/
498494
function log256(uint256 value) internal pure returns (uint256) {
499495
uint256 result = 0;
496+
uint256 isGt;
500497
unchecked {
501-
if (value >> 128 > 0) {
502-
value >>= 128;
503-
result += 16;
504-
}
505-
if (value >> 64 > 0) {
506-
value >>= 64;
507-
result += 8;
508-
}
509-
if (value >> 32 > 0) {
510-
value >>= 32;
511-
result += 4;
512-
}
513-
if (value >> 16 > 0) {
514-
value >>= 16;
515-
result += 2;
516-
}
517-
if (value >> 8 > 0) {
518-
result += 1;
519-
}
498+
isGt = SafeCast.toUint(value > (1 << 128) - 1);
499+
value >>= isGt * 128;
500+
result += isGt * 16;
501+
502+
isGt = SafeCast.toUint(value > (1 << 64) - 1);
503+
value >>= isGt * 64;
504+
result += isGt * 8;
505+
506+
isGt = SafeCast.toUint(value > (1 << 32) - 1);
507+
value >>= isGt * 32;
508+
result += isGt * 4;
509+
510+
isGt = SafeCast.toUint(value > (1 << 16) - 1);
511+
value >>= isGt * 16;
512+
result += isGt * 2;
513+
514+
result += SafeCast.toUint(value > (1 << 8) - 1);
520515
}
521516
return result;
522517
}
@@ -528,7 +523,7 @@ library Math {
528523
function log256(uint256 value, Rounding rounding) internal pure returns (uint256) {
529524
unchecked {
530525
uint256 result = log256(value);
531-
return result + (unsignedRoundsUp(rounding) && 1 << (result << 3) < value ? 1 : 0);
526+
return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << (result << 3) < value);
532527
}
533528
}
534529

contracts/utils/math/SafeCast.sol

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
pragma solidity ^0.8.20;
66

77
/**
8-
* @dev Wrappers over Solidity's uintXX/intXX casting operators with added overflow
8+
* @dev Wrappers over Solidity's uintXX/intXX/bool casting operators with added overflow
99
* checks.
1010
*
1111
* Downcasting from uint256/int256 in Solidity does not revert on overflow. This can
@@ -1150,4 +1150,14 @@ library SafeCast {
11501150
}
11511151
return int256(value);
11521152
}
1153+
1154+
/**
1155+
* @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump.
1156+
*/
1157+
function toUint(bool b) internal pure returns (uint256 u) {
1158+
/// @solidity memory-safe-assembly
1159+
assembly {
1160+
u := iszero(iszero(b))
1161+
}
1162+
}
11531163
}

scripts/generate/templates/SafeCast.js

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ const header = `\
77
pragma solidity ^0.8.20;
88
99
/**
10-
* @dev Wrappers over Solidity's uintXX/intXX casting operators with added overflow
10+
* @dev Wrappers over Solidity's uintXX/intXX/bool casting operators with added overflow
1111
* checks.
1212
*
1313
* Downcasting from uint256/int256 in Solidity does not revert on overflow. This can
@@ -116,11 +116,23 @@ function toUint${length}(int${length} value) internal pure returns (uint${length
116116
}
117117
`;
118118

119+
const boolToUint = `
120+
/**
121+
* @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump.
122+
*/
123+
function toUint(bool b) internal pure returns (uint256 u) {
124+
/// @solidity memory-safe-assembly
125+
assembly {
126+
u := iszero(iszero(b))
127+
}
128+
}
129+
`;
130+
119131
// GENERATE
120132
module.exports = format(
121133
header.trimEnd(),
122134
'library SafeCast {',
123135
errors,
124-
[...LENGTHS.map(toUintDownCast), toUint(256), ...LENGTHS.map(toIntDownCast), toInt(256)],
136+
[...LENGTHS.map(toUintDownCast), toUint(256), ...LENGTHS.map(toIntDownCast), toInt(256), boolToUint],
125137
'}',
126138
);

test/utils/math/SafeCast.test.js

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,4 +146,14 @@ describe('SafeCast', function () {
146146
.withArgs(ethers.MaxUint256);
147147
});
148148
});
149+
150+
describe('toUint (bool)', function () {
151+
it('toUint(false) should be 0', async function () {
152+
expect(await this.mock.$toUint(false)).to.equal(0n);
153+
});
154+
155+
it('toUint(true) should be 1', async function () {
156+
expect(await this.mock.$toUint(true)).to.equal(1n);
157+
});
158+
});
149159
});

0 commit comments

Comments
 (0)