Skip to content

Commit 5015284

Browse files
authored
Merge pull request #11748 from ethereum/optimize_signextend
Optimizer rules for signextend.
2 parents b62bb0a + 4480662 commit 5015284

File tree

8 files changed

+223
-3
lines changed

8 files changed

+223
-3
lines changed

libevmasm/RuleList.h

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,8 @@ std::vector<SimplificationRule<Pattern>> simplificationRuleListPart4(
248248

249249
template <class Pattern>
250250
std::vector<SimplificationRule<Pattern>> simplificationRuleListPart4_5(
251-
Pattern,
252-
Pattern,
251+
Pattern A,
252+
Pattern B,
253253
Pattern,
254254
Pattern X,
255255
Pattern Y
@@ -266,13 +266,17 @@ std::vector<SimplificationRule<Pattern>> simplificationRuleListPart4_5(
266266
{Builtins::OR(Y, Builtins::OR(X, Y)), [=]{ return Builtins::OR(X, Y); }},
267267
{Builtins::OR(Builtins::OR(Y, X), Y), [=]{ return Builtins::OR(Y, X); }},
268268
{Builtins::OR(Y, Builtins::OR(Y, X)), [=]{ return Builtins::OR(Y, X); }},
269+
{Builtins::SIGNEXTEND(X, Builtins::SIGNEXTEND(X, Y)), [=]() { return Builtins::SIGNEXTEND(X, Y); }},
270+
{Builtins::SIGNEXTEND(A, Builtins::SIGNEXTEND(B, X)), [=]() {
271+
return Builtins::SIGNEXTEND(A.d() < B.d() ? A.d() : B.d(), X);
272+
}},
269273
};
270274
}
271275

272276
template <class Pattern>
273277
std::vector<SimplificationRule<Pattern>> simplificationRuleListPart5(
274278
Pattern A,
275-
Pattern,
279+
Pattern B,
276280
Pattern,
277281
Pattern X,
278282
Pattern
@@ -314,6 +318,31 @@ std::vector<SimplificationRule<Pattern>> simplificationRuleListPart5(
314318
[=]() { return A.d() >= Pattern::WordSize / 8; }
315319
});
316320

321+
// Replace SIGNEXTEND(A, X), A >= 31 with ID
322+
rules.push_back({
323+
Builtins::SIGNEXTEND(A, X),
324+
[=]() -> Pattern { return X; },
325+
[=]() { return A.d() >= Pattern::WordSize / 8 - 1; }
326+
});
327+
rules.push_back({
328+
Builtins::AND(A, Builtins::SIGNEXTEND(B, X)),
329+
[=]() -> Pattern { return Builtins::AND(A, X); },
330+
[=]() {
331+
return
332+
B.d() < Pattern::WordSize / 8 - 1 &&
333+
(A.d() & ((u256(1) << static_cast<size_t>((B.d() + 1) * 8)) - 1)) == A.d();
334+
}
335+
});
336+
rules.push_back({
337+
Builtins::AND(Builtins::SIGNEXTEND(B, X), A),
338+
[=]() -> Pattern { return Builtins::AND(A, X); },
339+
[=]() {
340+
return
341+
B.d() < Pattern::WordSize / 8 - 1 &&
342+
(A.d() & ((u256(1) << static_cast<size_t>((B.d() + 1) * 8)) - 1)) == A.d();
343+
}
344+
});
345+
317346
for (auto instr: {
318347
Instruction::ADDRESS,
319348
Instruction::CALLER,
@@ -597,6 +626,24 @@ std::vector<SimplificationRule<Pattern>> simplificationRuleListPart7(
597626
}
598627
});
599628

629+
rules.push_back({
630+
Builtins::SHL(A, Builtins::SIGNEXTEND(B, X)),
631+
[=]() -> Pattern { return Builtins::SIGNEXTEND((A.d() >> 3) + B.d(), Builtins::SHL(A, X)); },
632+
[=] { return (A.d() & 7) == 0 && A.d() <= Pattern::WordSize && B.d() <= Pattern::WordSize / 8; }
633+
});
634+
635+
rules.push_back({
636+
Builtins::SIGNEXTEND(A, Builtins::SHR(B, X)),
637+
[=]() -> Pattern { return Builtins::SAR(B, X); },
638+
[=] {
639+
return
640+
B.d() % 8 == 0 &&
641+
B.d() <= Pattern::WordSize &&
642+
A.d() <= Pattern::WordSize &&
643+
(Pattern::WordSize - B.d()) / 8 == A.d() + 1;
644+
}
645+
});
646+
600647
return rules;
601648
}
602649

test/formal/opcodes.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,18 @@ def BYTE(i, x):
6464
BitVecVal(0, x.size()),
6565
(LShR(x, (x.size() - bit))) & 0xff
6666
)
67+
68+
def SIGNEXTEND(i, x):
69+
bitBV = i * 8 + 7
70+
bitInt = BV2Int(i) * 8 + 7
71+
test = BitVecVal(1, x.size()) << bitBV
72+
mask = test - 1
73+
return If(
74+
bitInt >= x.size(),
75+
x,
76+
If(
77+
(x & test) == 0,
78+
x & mask,
79+
x | ~mask
80+
)
81+
)

test/formal/signextend.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from rule import Rule
2+
from opcodes import *
3+
4+
"""
5+
Rule:
6+
1) SIGNEXTEND(A, X) -> X if A >= Pattern::WordSize / 8 - 1;
7+
8+
2) SIGNEXTEND(X, SIGNEXTEND(X, Y)) -> SIGNEXTEND(X, Y)
9+
10+
3) SIGNEXTEND(A, SIGNEXTEND(B, X)) -> SIGNEXTEND(min(A, B), X)
11+
"""
12+
13+
n_bits = 128
14+
15+
# Input vars
16+
X = BitVec('X', n_bits)
17+
Y = BitVec('Y', n_bits)
18+
A = BitVec('A', n_bits)
19+
B = BitVec('B', n_bits)
20+
21+
rule1 = Rule()
22+
# Requirements
23+
rule1.require(UGE(A, BitVecVal(n_bits // 8 - 1, n_bits)))
24+
rule1.check(SIGNEXTEND(A, X), X)
25+
26+
rule2 = Rule()
27+
rule2.check(
28+
SIGNEXTEND(X, SIGNEXTEND(X, Y)),
29+
SIGNEXTEND(X, Y)
30+
)
31+
32+
rule3 = Rule()
33+
rule3.check(
34+
SIGNEXTEND(A, SIGNEXTEND(B, X)),
35+
SIGNEXTEND(If(ULT(A, B), A, B), X)
36+
)
37+

test/formal/signextend_and.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from rule import Rule
2+
from opcodes import *
3+
4+
"""
5+
Rule:
6+
AND(A, SIGNEXTEND(B, X)) -> AND(A, X)
7+
given
8+
B < WordSize / 8 - 1 AND
9+
A & (1 << ((B + 1) * 8) - 1) == A
10+
"""
11+
12+
n_bits = 128
13+
14+
# Input vars
15+
X = BitVec('X', n_bits)
16+
A = BitVec('A', n_bits)
17+
B = BitVec('B', n_bits)
18+
19+
rule = Rule()
20+
# Requirements
21+
rule.require(ULT(B, BitVecVal(n_bits // 8 - 1, n_bits)))
22+
rule.require((A & ((BitVecVal(1, n_bits) << ((B + 1) * 8)) - 1)) == A)
23+
rule.check(
24+
AND(A, SIGNEXTEND(B, X)),
25+
AND(A, X)
26+
)

test/formal/signextend_equivalence.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from rule import Rule
2+
from opcodes import *
3+
4+
"""
5+
Checking the implementation of SIGNEXTEND using Z3's native SignExt and Extract
6+
"""
7+
8+
rule = Rule()
9+
n_bits = 256
10+
11+
x = BitVec('X', n_bits)
12+
13+
def SIGNEXTEND_native(i, x):
14+
return SignExt(256 - 8 * i - 8, Extract(8 * i + 7, 0, x))
15+
16+
for i in range(0, 32):
17+
rule.check(
18+
SIGNEXTEND(BitVecVal(i, n_bits), x),
19+
SIGNEXTEND_native(i, x)
20+
)
21+
22+
i = BitVec('I', n_bits)
23+
rule.require(UGT(i, BitVecVal(31, n_bits)))
24+
rule.check(SIGNEXTEND(i, x), x)

test/formal/signextend_shl.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from rule import Rule
2+
from opcodes import *
3+
4+
"""
5+
Rule:
6+
SHL(A, SIGNEXTEND(B, X)) -> SIGNEXTEND((A >> 3) + B, SHL(A, X))
7+
given return A & 7 == 0 AND A <= WordSize AND B <= WordSize / 8
8+
"""
9+
10+
n_bits = 256
11+
12+
# Input vars
13+
X = BitVec('X', n_bits)
14+
Y = BitVec('Y', n_bits)
15+
A = BitVec('A', n_bits)
16+
B = BitVec('B', n_bits)
17+
18+
rule = Rule()
19+
rule.require(A & 7 == 0)
20+
rule.require(ULE(A, n_bits))
21+
rule.require(ULE(B, n_bits / 8))
22+
rule.check(
23+
SHL(A, SIGNEXTEND(B, X)),
24+
SIGNEXTEND(LShR(A, 3) + B, SHL(A, X))
25+
)

test/formal/signextend_shr.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from rule import Rule
2+
from opcodes import *
3+
4+
"""
5+
Rule:
6+
SIGNEXTEND(A, SHR(B, X)) -> SAR(B, X)
7+
given
8+
B % 8 == 0 AND
9+
A <= WordSize AND
10+
B <= wordSize AND
11+
(WordSize - B) / 8 == A + 1
12+
"""
13+
14+
n_bits = 256
15+
16+
# Input vars
17+
X = BitVec('X', n_bits)
18+
Y = BitVec('Y', n_bits)
19+
A = BitVec('A', n_bits)
20+
B = BitVec('B', n_bits)
21+
22+
rule = Rule()
23+
rule.require(B % 8 == 0)
24+
rule.require(ULE(A, n_bits))
25+
rule.require(ULE(B, n_bits))
26+
rule.require((BitVecVal(n_bits, n_bits) - B) / 8 == A + 1)
27+
rule.check(
28+
SIGNEXTEND(A, SHR(B, X)),
29+
SAR(B, X)
30+
)
31+
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
let x := calldataload(0)
3+
let a := shl(sub(256, 8), signextend(0, x))
4+
sstore(0, a)
5+
}
6+
// ====
7+
// EVMVersion: >=constantinople
8+
// ----
9+
// step: fullSuite
10+
//
11+
// {
12+
// {
13+
// sstore(0, shl(248, calldataload(0)))
14+
// }
15+
// }

0 commit comments

Comments
 (0)