Skip to content

Commit 7c16530

Browse files
committed
[mlir] Lower logical not, bitwise not and ternary expressions
1 parent 03c05fb commit 7c16530

File tree

16 files changed

+4276
-3795
lines changed

16 files changed

+4276
-3795
lines changed

libsolidity/codegen/mlir/SolidityToMLIR.cpp

Lines changed: 106 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,9 @@ class SolidityToMLIRPass {
351351
/// Returns the mlir expression for the binary operation.
352352
mlir::Value genExpr(BinaryOperation const &binOp);
353353

354+
/// Returns the mlir expressions for the conditional (ternary) operation.
355+
mlir::SmallVector<mlir::Value> genExprs(Conditional const &cond);
356+
354357
/// Returns the mlir expression for the call.
355358
mlir::SmallVector<mlir::Value> genExprs(FunctionCall const &call);
356359

@@ -369,8 +372,10 @@ class SolidityToMLIRPass {
369372
/// to the corresponding mlir type of `resTy`.
370373
mlir::Value genRValExpr(Expression const &expr,
371374
std::optional<mlir::Type> resTy = std::nullopt);
372-
mlir::Value genRValExpr(mlir::Value val, mlir::Location loc);
373-
mlir::SmallVector<mlir::Value> genRValExprs(Expression const &expr);
375+
mlir::Value genRValExpr(mlir::Value val, mlir::Location loc,
376+
std::optional<mlir::Type> resTy = std::nullopt);
377+
mlir::SmallVector<mlir::Value> genRValExprs(Expression const &expr,
378+
mlir::TypeRange resTys = {});
374379

375380
/// Generates an ir that assigns `rhs` to `lhs`.
376381
void genAssign(mlir::Value lhs, mlir::Value rhs, mlir::Location loc);
@@ -849,6 +854,31 @@ mlir::Value SolidityToMLIRPass::genExpr(UnaryOperation const &unaryOp) {
849854
b.create<mlir::sol::ConstantOp>(loc, b.getIntegerAttr(mlirTy, 0));
850855
return genBinExpr(Token::Sub, zero, expr, loc);
851856
}
857+
// Logical not
858+
case Token::Not: {
859+
mlir::Value expr = genRValExpr(unaryOp.subExpression());
860+
mlir::Value zero = b.create<mlir::sol::ConstantOp>(
861+
loc, b.getIntegerAttr(expr.getType(), 0));
862+
return b.create<mlir::sol::CmpOp>(loc, mlir::sol::CmpPredicate::eq, expr,
863+
zero);
864+
}
865+
// Bitwise not (~x == x ^ -1)
866+
case Token::BitNot: {
867+
mlir::Value expr = genRValExpr(unaryOp.subExpression());
868+
869+
// Convert bytes to int if needed.
870+
if (auto bytesTy = mlir::dyn_cast<mlir::sol::BytesType>(expr.getType())) {
871+
mlir::Type intTy =
872+
b.getIntegerType(bytesTy.getSize() * 8, /*isSigned=*/false);
873+
expr = genCast(expr, intTy);
874+
}
875+
876+
auto intTy = mlir::cast<mlir::IntegerType>(expr.getType());
877+
mlir::Value allOnes = b.create<mlir::sol::ConstantOp>(
878+
loc,
879+
b.getIntegerAttr(intTy, llvm::APInt::getAllOnes(intTy.getWidth())));
880+
return b.create<mlir::sol::XorOp>(loc, expr, allOnes);
881+
}
852882
default:
853883
break;
854884
}
@@ -910,6 +940,38 @@ mlir::Value SolidityToMLIRPass::genExpr(BinaryOperation const &binOp) {
910940
return genBinExpr(binOp.getOperator(), lhs, rhs, loc);
911941
}
912942

943+
mlir::SmallVector<mlir::Value>
944+
SolidityToMLIRPass::genExprs(Conditional const &cond) {
945+
mlir::Location loc = getLoc(cond);
946+
mlir::Value condVal = genRValExpr(cond.condition());
947+
948+
// Get result types - could be single type or tuple.
949+
mlir::SmallVector<mlir::Type> resTys;
950+
if (TupleType const *tupleTy =
951+
dynamic_cast<TupleType const *>(cond.annotation().type)) {
952+
for (const Type *astTy : tupleTy->components())
953+
resTys.push_back(getType(astTy));
954+
} else {
955+
resTys.push_back(getType(cond.annotation().type));
956+
}
957+
958+
auto ifOp =
959+
b.create<mlir::scf::IfOp>(loc, resTys, condVal, /*withElse=*/true);
960+
mlir::OpBuilder::InsertionGuard guard(b);
961+
962+
// True branch
963+
b.setInsertionPointToStart(&ifOp.getThenRegion().front());
964+
b.create<mlir::scf::YieldOp>(loc,
965+
genRValExprs(cond.trueExpression(), resTys));
966+
967+
// False branch
968+
b.setInsertionPointToStart(&ifOp.getElseRegion().front());
969+
b.create<mlir::scf::YieldOp>(loc,
970+
genRValExprs(cond.falseExpression(), resTys));
971+
972+
return ifOp.getResults();
973+
}
974+
913975
mlir::Value SolidityToMLIRPass::genExpr(IndexAccess const &idxAcc) {
914976
mlir::Location loc = getLoc(idxAcc);
915977

@@ -1935,6 +1997,13 @@ mlir::Value SolidityToMLIRPass::genLValExpr(Expression const &expr) {
19351997
return {};
19361998
}
19371999

2000+
// Conditional (ternary operator)
2001+
if (const auto *cond = dynamic_cast<Conditional const *>(&expr)) {
2002+
mlir::SmallVector<mlir::Value> exprs = genExprs(*cond);
2003+
assert(exprs.size() == 1);
2004+
return exprs[0];
2005+
}
2006+
19382007
llvm_unreachable("NYI");
19392008
}
19402009

@@ -1948,38 +2017,46 @@ SolidityToMLIRPass::genLValExprs(Expression const &expr) {
19482017
if (const auto *call = dynamic_cast<FunctionCall const *>(&expr))
19492018
return genExprs(*call);
19502019

2020+
// Conditional (ternary)
2021+
if (const auto *cond = dynamic_cast<Conditional const *>(&expr))
2022+
return genExprs(*cond);
2023+
19512024
mlir::SmallVector<mlir::Value, 1> vals;
19522025
vals.push_back(genLValExpr(expr));
19532026
return vals;
19542027
}
19552028

1956-
mlir::Value SolidityToMLIRPass::genRValExpr(mlir::Value val,
1957-
mlir::Location loc) {
2029+
mlir::Value SolidityToMLIRPass::genRValExpr(mlir::Value val, mlir::Location loc,
2030+
std::optional<mlir::Type> resTy) {
19582031
if (mlir::isa<mlir::sol::PointerType>(val.getType()))
1959-
return b.create<mlir::sol::LoadOp>(loc, val);
2032+
val = b.create<mlir::sol::LoadOp>(loc, val);
2033+
if (resTy)
2034+
return genCast(val, *resTy);
19602035
return val;
19612036
}
19622037

19632038
mlir::Value SolidityToMLIRPass::genRValExpr(Expression const &expr,
19642039
std::optional<mlir::Type> resTy) {
19652040
mlir::Value lVal = genLValExpr(expr);
19662041
assert(lVal);
1967-
1968-
mlir::Value val = genRValExpr(lVal, getLoc(expr));
1969-
// Generate cast (optional).
1970-
if (resTy)
1971-
return genCast(val, *resTy);
1972-
return val;
2042+
return genRValExpr(lVal, getLoc(expr), resTy);
19732043
}
19742044

19752045
mlir::SmallVector<mlir::Value>
1976-
SolidityToMLIRPass::genRValExprs(Expression const &expr) {
2046+
SolidityToMLIRPass::genRValExprs(Expression const &expr,
2047+
mlir::TypeRange resTys) {
19772048
mlir::SmallVector<mlir::Value> lVals = genLValExprs(expr);
19782049
assert(!lVals.empty());
2050+
assert(resTys.empty() || lVals.size() == resTys.size());
19792051

19802052
mlir::SmallVector<mlir::Value, 2> rVals;
1981-
for (mlir::Value lVal : lVals)
1982-
rVals.push_back(genRValExpr(lVal, getLoc(expr)));
2053+
if (resTys.empty()) {
2054+
for (mlir::Value lVal : lVals)
2055+
rVals.push_back(genRValExpr(lVal, getLoc(expr)));
2056+
} else {
2057+
for (auto [lVal, resTy] : llvm::zip(lVals, resTys))
2058+
rVals.push_back(genRValExpr(lVal, getLoc(expr), resTy));
2059+
}
19832060

19842061
return rVals;
19852062
}
@@ -1992,22 +2069,23 @@ void SolidityToMLIRPass::lower(
19922069
VariableDeclarationStatement const &varDeclStmt) {
19932070
mlir::Location loc = getLoc(varDeclStmt);
19942071

2072+
mlir::SmallVector<mlir::Type> varTys;
2073+
for (auto const &varDeclPtr : varDeclStmt.declarations())
2074+
varTys.push_back(getType(varDeclPtr->type(), /*indirectFn=*/true));
2075+
19952076
mlir::SmallVector<mlir::Value> initExprs(varDeclStmt.declarations().size());
19962077
if (Expression const *initExpr = varDeclStmt.initialValue())
1997-
initExprs = genRValExprs(*initExpr);
1998-
1999-
for (auto [varDeclPtr, initExpr] :
2000-
llvm::zip(varDeclStmt.declarations(), initExprs)) {
2001-
VariableDeclaration const &varDecl = *varDeclPtr;
2078+
initExprs = genRValExprs(*initExpr, varTys);
20022079

2003-
mlir::Type varTy = getType(varDecl.type(), /*indirectFn=*/true);
2080+
for (auto [varDeclPtr, varTy, initExpr] :
2081+
llvm::zip(varDeclStmt.declarations(), varTys, initExprs)) {
20042082
mlir::Type allocTy = mlir::sol::PointerType::get(
20052083
b.getContext(), varTy, mlir::sol::DataLocation::Stack);
20062084

20072085
auto addr = b.create<mlir::sol::AllocaOp>(loc, allocTy);
2008-
trackLocalVarAddr(varDecl, addr);
2086+
trackLocalVarAddr(*varDeclPtr, addr);
20092087
if (initExpr)
2010-
b.create<mlir::sol::StoreOp>(loc, genCast(initExpr, varTy), addr);
2088+
b.create<mlir::sol::StoreOp>(loc, initExpr, addr);
20112089
else
20122090
genZeroedVal(addr);
20132091
}
@@ -2038,22 +2116,17 @@ void SolidityToMLIRPass::lower(PlaceholderStatement const &placeholder) {
20382116
}
20392117

20402118
void SolidityToMLIRPass::lower(Return const &ret) {
2041-
TypePointers fnResTys;
2119+
mlir::SmallVector<mlir::Type> fnResTys;
20422120
for (ASTPointer<VariableDeclaration> const &retParam :
20432121
ret.annotation().function->returnParameters())
2044-
fnResTys.push_back(retParam->type());
2122+
fnResTys.push_back(getType(retParam->type()));
20452123

20462124
Expression const *astExpr = ret.expression();
2047-
if (astExpr) {
2048-
mlir::SmallVector<mlir::Value> exprs = genRValExprs(*astExpr);
2049-
mlir::SmallVector<mlir::Value> castedExprs;
2050-
for (auto [expr, dstTy] : llvm::zip(exprs, fnResTys)) {
2051-
castedExprs.push_back(genCast(expr, getType(dstTy)));
2052-
}
2053-
b.create<mlir::sol::ReturnOp>(getLoc(ret), castedExprs);
2054-
} else {
2125+
if (astExpr)
2126+
b.create<mlir::sol::ReturnOp>(getLoc(ret),
2127+
genRValExprs(*astExpr, fnResTys));
2128+
else
20552129
b.create<mlir::sol::ReturnOp>(getLoc(ret));
2056-
}
20572130
b.setInsertionPointToStart(b.createBlock(b.getBlock()->getParent()));
20582131
}
20592132

test/libsolidity/semanticTests/mlir/arith.sol

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ contract C {
7777
return (a & b, a | b, a ^ b);
7878
}
7979

80+
function bnot(uint a) public returns (uint) {
81+
return ~a;
82+
}
83+
84+
function bnotBytes(bytes4 a) public pure returns (bytes4) {
85+
return ~a;
86+
}
87+
8088
function addmod8(uint8 a, uint8 b, uint8 m) public returns (uint256) {
8189
return addmod(a, b, m);
8290
}
@@ -123,6 +131,8 @@ contract C {
123131
// shr8(uint8,uint256): 1, 8 -> 0
124132
// bit(int256,int256): 6, 3 -> 2, 7, 5
125133
// bit8(int8,int8): 6, 3 -> 2, 7, 5
134+
// bnot(uint256): 0 -> 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
135+
// bnotBytes(bytes4): left(0x12345678) -> left(0xedcba987)
126136
// addmod8(uint8,uint8,uint8): 42, 25, 0 -> FAILURE, hex"4e487b71", 0x12
127137
// addmod8(uint8,uint8,uint8): 42, 25, 24 -> 19
128138
// addmod8(uint8,uint8,uint8): 200, 100, 7 -> 6

test/libsolidity/semanticTests/mlir/cf.sol

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
contract C {
2-
function f(uint a, uint b) public returns (uint) {
2+
uint public counter;
3+
4+
function for_brk(uint a, uint b) public returns (uint) {
35
uint r = 1;
46
for (uint i = 0; i < a; ++i) {
57
if (i == b)
@@ -9,7 +11,7 @@ contract C {
911
return r;
1012
}
1113

12-
function g(uint a) public returns (uint) {
14+
function while_cont(uint a) public returns (uint) {
1315
uint r = 1;
1416
do {
1517
r = 2;
@@ -26,10 +28,48 @@ contract C {
2628
}
2729
return r;
2830
}
31+
32+
function tern(bool c, uint a, uint b) public pure returns (uint) {
33+
return c ? a : b;
34+
}
35+
36+
function tern_bytes(bool c, bytes4 a, bytes4 b) public pure returns (bytes4) {
37+
return c ? a : b;
38+
}
39+
40+
function tern_cast(bool c, uint8 a, uint256 b) public pure returns (uint256) {
41+
return c ? a : b;
42+
}
43+
44+
function tern_short_circuit(bool c) public returns (uint, uint) {
45+
counter = 0;
46+
uint result = c ? ++counter : ++counter;
47+
return (result, counter);
48+
}
49+
50+
function tern_tuple(bool c, uint a, uint b, uint x, uint y) public pure returns (uint, uint) {
51+
return c ? (a, b) : (x, y);
52+
}
53+
54+
function tern_const(bool c) public pure returns (uint) {
55+
return c ? 1 : 2;
56+
}
2957
}
3058

3159
// ====
3260
// compileViaMlir: true
3361
// ----
34-
// f(uint256,uint256): 20, 10 -> 1024
35-
// g(uint256): 10 -> 22
62+
// for_brk(uint256,uint256): 20, 10 -> 1024
63+
// while_cont(uint256): 10 -> 22
64+
// tern(bool,uint256,uint256): true, 10, 20 -> 10
65+
// tern(bool,uint256,uint256): false, 10, 20 -> 20
66+
// tern_bytes(bool,bytes4,bytes4): true, left(0x12345678), left(0xaabbccdd) -> left(0x12345678)
67+
// tern_bytes(bool,bytes4,bytes4): false, left(0x12345678), left(0xaabbccdd) -> left(0xaabbccdd)
68+
// tern_cast(bool,uint8,uint256): true, 42, 1000 -> 42
69+
// tern_cast(bool,uint8,uint256): false, 42, 1000 -> 1000
70+
// tern_short_circuit(bool): true -> 1, 1
71+
// tern_short_circuit(bool): false -> 1, 1
72+
// tern_tuple(bool,uint256,uint256,uint256,uint256): true, 10, 20, 30, 40 -> 10, 20
73+
// tern_tuple(bool,uint256,uint256,uint256,uint256): false, 10, 20, 30, 40 -> 30, 40
74+
// tern_const(bool): true -> 1
75+
// tern_const(bool): false -> 2

test/libsolidity/semanticTests/mlir/logical.sol

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ contract C {
2323
x = 0;
2424
return (a || s(b), x);
2525
}
26+
27+
function j(bool a) public pure returns (bool) {
28+
return !a;
29+
}
2630
}
2731

2832
// ====
@@ -42,3 +46,5 @@ contract C {
4246
// i(bool,bool): 1, 0 -> true, 0
4347
// i(bool,bool): 0, 1 -> true, 1
4448
// i(bool,bool): 0, 0 -> false, 1
49+
// j(bool): 0 -> true
50+
// j(bool): 1 -> false

0 commit comments

Comments
 (0)