Skip to content

Commit 02fdcb3

Browse files
authored
Merge pull request #12976 from dflupu/mulmod-opti
Add simplification rules for `mod(mul(X, Y), A)` & `mod(add(X, Y), A)`
2 parents 75300c3 + 8498fdf commit 02fdcb3

File tree

7 files changed

+122
-3
lines changed

7 files changed

+122
-3
lines changed

Changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ Language Features:
66
Compiler Features:
77
* TypeChecker: Support using library constants in initializers of other constants.
88
* Yul IR Code Generation: Improved copy routines for arrays with packed storage layout.
9+
* Yul Optimizer: Add rule to convert `mod(mul(X, Y), A)` into `mulmod(X, Y, A)`, if `A` is a power of two.
10+
* Yul Optimizer: Add rule to convert `mod(add(X, Y), A)` into `addmod(X, Y, A)`, if `A` is a power of two.
911

1012

1113
Bugfixes:

libevmasm/RuleList.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,18 +275,41 @@ std::vector<SimplificationRule<Pattern>> simplificationRuleListPart4_5(
275275

276276
template <class Pattern>
277277
std::vector<SimplificationRule<Pattern>> simplificationRuleListPart5(
278+
bool _forYulOptimizer,
278279
Pattern A,
279280
Pattern B,
280281
Pattern,
281282
Pattern X,
282-
Pattern
283+
Pattern Y
283284
)
284285
{
285286
using Word = typename Pattern::Word;
286287
using Builtins = typename Pattern::Builtins;
287288

288289
std::vector<SimplificationRule<Pattern>> rules;
289290

291+
// The libevmasm optimizer does not support rules resulting in opcodes with more than two arguments.
292+
if (_forYulOptimizer)
293+
{
294+
// Replace MOD(MUL(X, Y), A) with MULMOD(X, Y, A) iff A=2**N
295+
rules.push_back({
296+
Builtins::MOD(Builtins::MUL(X, Y), A),
297+
[=]() -> Pattern { return Builtins::MULMOD(X, Y, A); },
298+
[=] {
299+
return A.d() > 0 && ((A.d() & (A.d() - 1)) == 0);
300+
}
301+
});
302+
303+
// Replace MOD(ADD(X, Y), A) with ADDMOD(X, Y, A) iff A=2**N
304+
rules.push_back({
305+
Builtins::MOD(Builtins::ADD(X, Y), A),
306+
[=]() -> Pattern { return Builtins::ADDMOD(X, Y, A); },
307+
[=] {
308+
return A.d() > 0 && ((A.d() & (A.d() - 1)) == 0);
309+
}
310+
});
311+
}
312+
290313
// Replace MOD X, <power-of-two> with AND X, <power-of-two> - 1
291314
for (size_t i = 0; i < Pattern::WordSize; ++i)
292315
{
@@ -798,7 +821,7 @@ std::vector<SimplificationRule<Pattern>> simplificationRuleList(
798821
rules += simplificationRuleListPart3(A, B, C, W, X);
799822
rules += simplificationRuleListPart4(A, B, C, W, X);
800823
rules += simplificationRuleListPart4_5(A, B, C, W, X);
801-
rules += simplificationRuleListPart5(A, B, C, W, X);
824+
rules += simplificationRuleListPart5(_evmVersion.has_value(), A, B, C, W, X);
802825
rules += simplificationRuleListPart6(A, B, C, W, X);
803826
rules += simplificationRuleListPart7(A, B, C, W, X);
804827
rules += simplificationRuleListPart8(A, B, C, W, X);

test/formal/mod_add_to_addmod.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from opcodes import MOD, ADD, ADDMOD
2+
from rule import Rule
3+
from z3 import BitVec
4+
5+
"""
6+
Rule:
7+
MOD(ADD(X, Y), A) -> ADDMOD(X, Y, A)
8+
given
9+
A > 0
10+
A & (A - 1) == 0
11+
"""
12+
13+
rule = Rule()
14+
15+
n_bits = 32
16+
17+
# Input vars
18+
X = BitVec('X', n_bits)
19+
Y = BitVec('Y', n_bits)
20+
A = BitVec('A', n_bits)
21+
22+
# Non optimized result
23+
nonopt = MOD(ADD(X, Y), A)
24+
25+
# Optimized result
26+
opt = ADDMOD(X, Y, A)
27+
28+
rule.require(A > 0)
29+
rule.require(((A & (A - 1)) == 0))
30+
31+
rule.check(nonopt, opt)

test/formal/mod_mul_to_mulmod.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from opcodes import MOD, MUL, MULMOD
2+
from rule import Rule
3+
from z3 import BitVec
4+
5+
"""
6+
Rule:
7+
MOD(MUL(X, Y), A) -> MULMOD(X, Y, A)
8+
given
9+
A > 0
10+
A & (A - 1) == 0
11+
"""
12+
13+
rule = Rule()
14+
15+
n_bits = 8
16+
17+
# Input vars
18+
X = BitVec('X', n_bits)
19+
Y = BitVec('Y', n_bits)
20+
A = BitVec('A', n_bits)
21+
22+
# Non optimized result
23+
nonopt = MOD(MUL(X, Y), A)
24+
25+
# Optimized result
26+
opt = MULMOD(X, Y, A)
27+
28+
rule.require(A > 0)
29+
rule.require(((A & (A - 1)) == 0))
30+
31+
rule.check(nonopt, opt)

test/formal/opcodes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from z3 import BitVecVal, BV2Int, If, LShR, UDiv, ULT, UGT, URem
1+
from z3 import BitVecVal, BV2Int, If, LShR, UDiv, ULT, UGT, URem, ZeroExt, Extract
22

33
def ADD(x, y):
44
return x + y
@@ -18,6 +18,12 @@ def SDIV(x, y):
1818
def MOD(x, y):
1919
return If(y == 0, 0, URem(x, y))
2020

21+
def MULMOD(x, y, m):
22+
return If(m == 0, 0, Extract(x.size() - 1, 0, URem(ZeroExt(x.size(), x) * ZeroExt(x.size(), y), ZeroExt(m.size(), m))))
23+
24+
def ADDMOD(x, y, m):
25+
return If(m == 0, 0, Extract(x.size() - 1, 0, URem(ZeroExt(1, x) + ZeroExt(1, y), ZeroExt(1, m))))
26+
2127
def SMOD(x, y):
2228
return If(y == 0, 0, x % y)
2329

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
mstore(0, mod(add(mload(0), mload(1)), 32))
3+
}
4+
// ----
5+
// step: expressionSimplifier
6+
//
7+
// {
8+
// {
9+
// let _3 := mload(1)
10+
// let _4 := 0
11+
// mstore(_4, addmod(mload(_4), _3, 32))
12+
// }
13+
// }
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
mstore(0, mod(mul(mload(0), mload(1)), 32))
3+
}
4+
// ----
5+
// step: expressionSimplifier
6+
//
7+
// {
8+
// {
9+
// let _3 := mload(1)
10+
// let _4 := 0
11+
// mstore(_4, mulmod(mload(_4), _3, 32))
12+
// }
13+
// }

0 commit comments

Comments
 (0)