Skip to content

Commit 4e7e6e5

Browse files
ernestognwAmxx
andauthored
Add bytes memory version of Math.modExp (#4893)
Co-authored-by: Hadrien Croubois <[email protected]>
1 parent ae1bafc commit 4e7e6e5

File tree

5 files changed

+180
-36
lines changed

5 files changed

+180
-36
lines changed

.changeset/shiny-poets-whisper.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
'openzeppelin-solidity': minor
33
---
44

5-
`Math`: Add `modExp` function that exposes the `EIP-198` precompile.
5+
`Math`: Add `modExp` function that exposes the `EIP-198` precompile. Includes `uint256` and `bytes memory` versions.

contracts/utils/math/Math.sol

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
pragma solidity ^0.8.20;
55

6-
import {Address} from "../Address.sol";
76
import {Panic} from "../Panic.sol";
87
import {SafeCast} from "./SafeCast.sol";
98

@@ -289,11 +288,7 @@ library Math {
289288
function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
290289
(bool success, uint256 result) = tryModExp(b, e, m);
291290
if (!success) {
292-
if (m == 0) {
293-
Panic.panic(Panic.DIVISION_BY_ZERO);
294-
} else {
295-
revert Address.FailedInnerCall();
296-
}
291+
Panic.panic(Panic.DIVISION_BY_ZERO);
297292
}
298293
return result;
299294
}
@@ -335,6 +330,57 @@ library Math {
335330
}
336331
}
337332

333+
/**
334+
* @dev Variant of {modExp} that supports inputs of arbitrary length.
335+
*/
336+
function modExp(bytes memory b, bytes memory e, bytes memory m) internal view returns (bytes memory) {
337+
(bool success, bytes memory result) = tryModExp(b, e, m);
338+
if (!success) {
339+
Panic.panic(Panic.DIVISION_BY_ZERO);
340+
}
341+
return result;
342+
}
343+
344+
/**
345+
* @dev Variant of {tryModExp} that supports inputs of arbitrary length.
346+
*/
347+
function tryModExp(
348+
bytes memory b,
349+
bytes memory e,
350+
bytes memory m
351+
) internal view returns (bool success, bytes memory result) {
352+
if (_zeroBytes(m)) return (false, new bytes(0));
353+
354+
uint256 mLen = m.length;
355+
356+
// Encode call args in result and move the free memory pointer
357+
result = abi.encodePacked(b.length, e.length, mLen, b, e, m);
358+
359+
/// @solidity memory-safe-assembly
360+
assembly {
361+
let dataPtr := add(result, 0x20)
362+
// Write result on top of args to avoid allocating extra memory.
363+
success := staticcall(gas(), 0x05, dataPtr, mload(result), dataPtr, mLen)
364+
// Overwrite the length.
365+
// result.length > returndatasize() is guaranteed because returndatasize() == m.length
366+
mstore(result, mLen)
367+
// Set the memory pointer after the returned data.
368+
mstore(0x40, add(dataPtr, mLen))
369+
}
370+
}
371+
372+
/**
373+
* @dev Returns whether the provided byte array is zero.
374+
*/
375+
function _zeroBytes(bytes memory byteArray) private pure returns (bool) {
376+
for (uint256 i = 0; i < byteArray.length; ++i) {
377+
if (byteArray[i] != 0) {
378+
return false;
379+
}
380+
}
381+
return true;
382+
}
383+
338384
/**
339385
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
340386
* towards zero.

test/helpers/math.js

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,31 @@ const max = (...values) => values.slice(1).reduce((x, y) => (x > y ? x : y), val
33
const min = (...values) => values.slice(1).reduce((x, y) => (x < y ? x : y), values.at(0));
44
const sum = (...values) => values.slice(1).reduce((x, y) => x + y, values.at(0));
55

6+
// Computes modexp without BigInt overflow for large numbers
7+
function modExp(b, e, m) {
8+
let result = 1n;
9+
10+
// If e is a power of two, modexp can be calculated as:
11+
// for (let result = b, i = 0; i < log2(e); i++) result = modexp(result, 2, m)
12+
//
13+
// Given any natural number can be written in terms of powers of 2 (i.e. binary)
14+
// then modexp can be calculated for any e, by multiplying b**i for all i where
15+
// binary(e)[i] is 1 (i.e. a power of two).
16+
for (let base = b % m; e > 0n; base = base ** 2n % m) {
17+
// Least significant bit is 1
18+
if (e % 2n == 1n) {
19+
result = (result * base) % m;
20+
}
21+
22+
e /= 2n; // Binary pop
23+
}
24+
25+
return result;
26+
}
27+
628
module.exports = {
729
min,
830
max,
931
sum,
32+
modExp,
1033
};

test/utils/math/Math.t.sol

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,33 @@ contract MathTest is Test {
226226
}
227227
}
228228

229+
function testModExpMemory(uint256 b, uint256 e, uint256 m) public {
230+
if (m == 0) {
231+
vm.expectRevert(stdError.divisionError);
232+
}
233+
bytes memory result = Math.modExp(abi.encodePacked(b), abi.encodePacked(e), abi.encodePacked(m));
234+
assertEq(result.length, 0x20);
235+
uint256 res = abi.decode(result, (uint256));
236+
assertLt(res, m);
237+
assertEq(res, _nativeModExp(b, e, m));
238+
}
239+
240+
function testTryModExpMemory(uint256 b, uint256 e, uint256 m) public {
241+
(bool success, bytes memory result) = Math.tryModExp(
242+
abi.encodePacked(b),
243+
abi.encodePacked(e),
244+
abi.encodePacked(m)
245+
);
246+
if (success) {
247+
assertEq(result.length, 0x20); // m is a uint256, so abi.encodePacked(m).length is 0x20
248+
uint256 res = abi.decode(result, (uint256));
249+
assertLt(res, m);
250+
assertEq(res, _nativeModExp(b, e, m));
251+
} else {
252+
assertEq(result.length, 0);
253+
}
254+
}
255+
229256
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
230257
if (m == 1) return 0;
231258
uint256 r = 1;

test/utils/math/Math.test.js

Lines changed: 77 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,19 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
44
const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
55

66
const { Rounding } = require('../../helpers/enums');
7-
const { min, max } = require('../../helpers/math');
7+
const { min, max, modExp } = require('../../helpers/math');
88
const { generators } = require('../../helpers/random');
9+
const { range } = require('../../../scripts/helpers');
10+
const { product } = require('../../helpers/iterate');
911

1012
const RoundingDown = [Rounding.Floor, Rounding.Trunc];
1113
const RoundingUp = [Rounding.Ceil, Rounding.Expand];
1214

15+
const bytes = (value, width = undefined) => ethers.Typed.bytes(ethers.toBeHex(value, width));
16+
const uint256 = value => ethers.Typed.uint256(value);
17+
bytes.zero = '0x';
18+
uint256.zero = 0n;
19+
1320
async function testCommutative(fn, lhs, rhs, expected, ...extra) {
1421
expect(await fn(lhs, rhs, ...extra)).to.deep.equal(expected);
1522
expect(await fn(rhs, lhs, ...extra)).to.deep.equal(expected);
@@ -141,24 +148,6 @@ describe('Math', function () {
141148
});
142149
});
143150

144-
describe('tryModExp', function () {
145-
it('is correctly returning true and calculating modulus', async function () {
146-
const base = 3n;
147-
const exponent = 200n;
148-
const modulus = 50n;
149-
150-
expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]);
151-
});
152-
153-
it('is correctly returning false when modulus is 0', async function () {
154-
const base = 3n;
155-
const exponent = 200n;
156-
const modulus = 0n;
157-
158-
expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([false, 0n]);
159-
});
160-
});
161-
162151
describe('max', function () {
163152
it('is correctly detected in both position', async function () {
164153
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
@@ -354,20 +343,79 @@ describe('Math', function () {
354343
});
355344

356345
describe('modExp', function () {
357-
it('is correctly calculating modulus', async function () {
358-
const base = 3n;
359-
const exponent = 200n;
360-
const modulus = 50n;
346+
for (const [name, type] of Object.entries({ uint256, bytes })) {
347+
describe(`with ${name} inputs`, function () {
348+
it('is correctly calculating modulus', async function () {
349+
const b = 3n;
350+
const e = 200n;
351+
const m = 50n;
352+
353+
expect(await this.mock.$modExp(type(b), type(e), type(m))).to.equal(type(b ** e % m).value);
354+
});
361355

362-
expect(await this.mock.$modExp(base, exponent, modulus)).to.equal(base ** exponent % modulus);
356+
it('is correctly reverting when modulus is zero', async function () {
357+
const b = 3n;
358+
const e = 200n;
359+
const m = 0n;
360+
361+
await expect(this.mock.$modExp(type(b), type(e), type(m))).to.be.revertedWithPanic(
362+
PANIC_CODES.DIVISION_BY_ZERO,
363+
);
364+
});
365+
});
366+
}
367+
368+
describe('with large bytes inputs', function () {
369+
for (const [[b, log2b], [e, log2e], [m, log2m]] of product(
370+
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
371+
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
372+
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
373+
)) {
374+
it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
375+
const mLength = ethers.dataLength(ethers.toBeHex(m));
376+
377+
expect(await this.mock.$modExp(bytes(b), bytes(e), bytes(m))).to.equal(bytes(modExp(b, e, m), mLength).value);
378+
});
379+
}
363380
});
381+
});
382+
383+
describe('tryModExp', function () {
384+
for (const [name, type] of Object.entries({ uint256, bytes })) {
385+
describe(`with ${name} inputs`, function () {
386+
it('is correctly calculating modulus', async function () {
387+
const b = 3n;
388+
const e = 200n;
389+
const m = 50n;
390+
391+
expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([true, type(b ** e % m).value]);
392+
});
364393

365-
it('is correctly reverting when modulus is zero', async function () {
366-
const base = 3n;
367-
const exponent = 200n;
368-
const modulus = 0n;
394+
it('is correctly reverting when modulus is zero', async function () {
395+
const b = 3n;
396+
const e = 200n;
397+
const m = 0n;
369398

370-
await expect(this.mock.$modExp(base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO);
399+
expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([false, type.zero]);
400+
});
401+
});
402+
}
403+
404+
describe('with large bytes inputs', function () {
405+
for (const [[b, log2b], [e, log2e], [m, log2m]] of product(
406+
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
407+
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
408+
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
409+
)) {
410+
it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
411+
const mLength = ethers.dataLength(ethers.toBeHex(m));
412+
413+
expect(await this.mock.$tryModExp(bytes(b), bytes(e), bytes(m))).to.deep.equal([
414+
true,
415+
bytes(modExp(b, e, m), mLength).value,
416+
]);
417+
});
418+
}
371419
});
372420
});
373421

0 commit comments

Comments
 (0)