Skip to content

Commit e549422

Browse files
committed
[mlir] Lower logical not, bitwise not and ternary expressions
1 parent 3ae1568 commit e549422

File tree

16 files changed

+4142
-3779
lines changed

16 files changed

+4142
-3779
lines changed

libsolidity/codegen/mlir/SolidityToMLIR.cpp

Lines changed: 106 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,9 @@ class SolidityToMLIRPass {
337337
/// Returns the mlir expression for the binary operation.
338338
mlir::Value genExpr(BinaryOperation const &binOp);
339339

340+
/// Returns the mlir expressions for the conditional (ternary) operation.
341+
mlir::SmallVector<mlir::Value> genExprs(Conditional const &cond);
342+
340343
/// Returns the mlir expression for the call.
341344
mlir::SmallVector<mlir::Value> genExprs(FunctionCall const &call);
342345

@@ -355,8 +358,10 @@ class SolidityToMLIRPass {
355358
/// to the corresponding mlir type of `resTy`.
356359
mlir::Value genRValExpr(Expression const &expr,
357360
std::optional<mlir::Type> resTy = std::nullopt);
358-
mlir::Value genRValExpr(mlir::Value val, mlir::Location loc);
359-
mlir::SmallVector<mlir::Value> genRValExprs(Expression const &expr);
361+
mlir::Value genRValExpr(mlir::Value val, mlir::Location loc,
362+
std::optional<mlir::Type> resTy = std::nullopt);
363+
mlir::SmallVector<mlir::Value> genRValExprs(Expression const &expr,
364+
mlir::TypeRange resTys = {});
360365

361366
/// Generates an ir that assigns `rhs` to `lhs`.
362367
void genAssign(mlir::Value lhs, mlir::Value rhs, mlir::Location loc);
@@ -812,6 +817,31 @@ mlir::Value SolidityToMLIRPass::genExpr(UnaryOperation const &unaryOp) {
812817
b.create<mlir::sol::ConstantOp>(loc, b.getIntegerAttr(mlirTy, 0));
813818
return genBinExpr(Token::Sub, zero, expr, loc);
814819
}
820+
// Logical not
821+
case Token::Not: {
822+
mlir::Value expr = genRValExpr(unaryOp.subExpression());
823+
mlir::Value zero = b.create<mlir::sol::ConstantOp>(
824+
loc, b.getIntegerAttr(expr.getType(), 0));
825+
return b.create<mlir::sol::CmpOp>(loc, mlir::sol::CmpPredicate::eq, expr,
826+
zero);
827+
}
828+
// Bitwise not (~x == x ^ -1)
829+
case Token::BitNot: {
830+
mlir::Value expr = genRValExpr(unaryOp.subExpression());
831+
832+
// Convert bytes to int if needed.
833+
if (auto bytesTy = mlir::dyn_cast<mlir::sol::BytesType>(expr.getType())) {
834+
mlir::Type intTy =
835+
b.getIntegerType(bytesTy.getSize() * 8, /*isSigned=*/false);
836+
expr = genCast(expr, intTy);
837+
}
838+
839+
auto intTy = mlir::cast<mlir::IntegerType>(expr.getType());
840+
mlir::Value allOnes = b.create<mlir::sol::ConstantOp>(
841+
loc,
842+
b.getIntegerAttr(intTy, llvm::APInt::getAllOnes(intTy.getWidth())));
843+
return b.create<mlir::sol::XorOp>(loc, expr, allOnes);
844+
}
815845
default:
816846
break;
817847
}
@@ -873,6 +903,38 @@ mlir::Value SolidityToMLIRPass::genExpr(BinaryOperation const &binOp) {
873903
return genBinExpr(binOp.getOperator(), lhs, rhs, loc);
874904
}
875905

906+
mlir::SmallVector<mlir::Value>
907+
SolidityToMLIRPass::genExprs(Conditional const &cond) {
908+
mlir::Location loc = getLoc(cond);
909+
mlir::Value condVal = genRValExpr(cond.condition());
910+
911+
// Get result types - could be single type or tuple.
912+
mlir::SmallVector<mlir::Type> resTys;
913+
if (TupleType const *tupleTy =
914+
dynamic_cast<TupleType const *>(cond.annotation().type)) {
915+
for (const Type *astTy : tupleTy->components())
916+
resTys.push_back(getType(astTy));
917+
} else {
918+
resTys.push_back(getType(cond.annotation().type));
919+
}
920+
921+
auto ifOp =
922+
b.create<mlir::scf::IfOp>(loc, resTys, condVal, /*withElse=*/true);
923+
mlir::OpBuilder::InsertionGuard guard(b);
924+
925+
// True branch
926+
b.setInsertionPointToStart(&ifOp.getThenRegion().front());
927+
b.create<mlir::scf::YieldOp>(loc,
928+
genRValExprs(cond.trueExpression(), resTys));
929+
930+
// False branch
931+
b.setInsertionPointToStart(&ifOp.getElseRegion().front());
932+
b.create<mlir::scf::YieldOp>(loc,
933+
genRValExprs(cond.falseExpression(), resTys));
934+
935+
return ifOp.getResults();
936+
}
937+
876938
mlir::Value SolidityToMLIRPass::genExpr(IndexAccess const &idxAcc) {
877939
mlir::Location loc = getLoc(idxAcc);
878940

@@ -1820,6 +1882,13 @@ mlir::Value SolidityToMLIRPass::genLValExpr(Expression const &expr) {
18201882
return {};
18211883
}
18221884

1885+
// Conditional (ternary operator)
1886+
if (const auto *cond = dynamic_cast<Conditional const *>(&expr)) {
1887+
mlir::SmallVector<mlir::Value> exprs = genExprs(*cond);
1888+
assert(exprs.size() == 1);
1889+
return exprs[0];
1890+
}
1891+
18231892
llvm_unreachable("NYI");
18241893
}
18251894

@@ -1833,38 +1902,46 @@ SolidityToMLIRPass::genLValExprs(Expression const &expr) {
18331902
if (const auto *call = dynamic_cast<FunctionCall const *>(&expr))
18341903
return genExprs(*call);
18351904

1905+
// Conditional (ternary)
1906+
if (const auto *cond = dynamic_cast<Conditional const *>(&expr))
1907+
return genExprs(*cond);
1908+
18361909
mlir::SmallVector<mlir::Value, 1> vals;
18371910
vals.push_back(genLValExpr(expr));
18381911
return vals;
18391912
}
18401913

1841-
mlir::Value SolidityToMLIRPass::genRValExpr(mlir::Value val,
1842-
mlir::Location loc) {
1914+
mlir::Value SolidityToMLIRPass::genRValExpr(mlir::Value val, mlir::Location loc,
1915+
std::optional<mlir::Type> resTy) {
18431916
if (mlir::isa<mlir::sol::PointerType>(val.getType()))
1844-
return b.create<mlir::sol::LoadOp>(loc, val);
1917+
val = b.create<mlir::sol::LoadOp>(loc, val);
1918+
if (resTy)
1919+
return genCast(val, *resTy);
18451920
return val;
18461921
}
18471922

18481923
mlir::Value SolidityToMLIRPass::genRValExpr(Expression const &expr,
18491924
std::optional<mlir::Type> resTy) {
18501925
mlir::Value lVal = genLValExpr(expr);
18511926
assert(lVal);
1852-
1853-
mlir::Value val = genRValExpr(lVal, getLoc(expr));
1854-
// Generate cast (optional).
1855-
if (resTy)
1856-
return genCast(val, *resTy);
1857-
return val;
1927+
return genRValExpr(lVal, getLoc(expr), resTy);
18581928
}
18591929

18601930
mlir::SmallVector<mlir::Value>
1861-
SolidityToMLIRPass::genRValExprs(Expression const &expr) {
1931+
SolidityToMLIRPass::genRValExprs(Expression const &expr,
1932+
mlir::TypeRange resTys) {
18621933
mlir::SmallVector<mlir::Value> lVals = genLValExprs(expr);
18631934
assert(!lVals.empty());
1935+
assert(resTys.empty() || lVals.size() == resTys.size());
18641936

18651937
mlir::SmallVector<mlir::Value, 2> rVals;
1866-
for (mlir::Value lVal : lVals)
1867-
rVals.push_back(genRValExpr(lVal, getLoc(expr)));
1938+
if (resTys.empty()) {
1939+
for (mlir::Value lVal : lVals)
1940+
rVals.push_back(genRValExpr(lVal, getLoc(expr)));
1941+
} else {
1942+
for (auto [lVal, resTy] : llvm::zip(lVals, resTys))
1943+
rVals.push_back(genRValExpr(lVal, getLoc(expr), resTy));
1944+
}
18681945

18691946
return rVals;
18701947
}
@@ -1877,22 +1954,23 @@ void SolidityToMLIRPass::lower(
18771954
VariableDeclarationStatement const &varDeclStmt) {
18781955
mlir::Location loc = getLoc(varDeclStmt);
18791956

1957+
mlir::SmallVector<mlir::Type> varTys;
1958+
for (auto const &varDeclPtr : varDeclStmt.declarations())
1959+
varTys.push_back(getType(varDeclPtr->type(), /*indirectFn=*/true));
1960+
18801961
mlir::SmallVector<mlir::Value> initExprs(varDeclStmt.declarations().size());
18811962
if (Expression const *initExpr = varDeclStmt.initialValue())
1882-
initExprs = genRValExprs(*initExpr);
1883-
1884-
for (auto [varDeclPtr, initExpr] :
1885-
llvm::zip(varDeclStmt.declarations(), initExprs)) {
1886-
VariableDeclaration const &varDecl = *varDeclPtr;
1963+
initExprs = genRValExprs(*initExpr, varTys);
18871964

1888-
mlir::Type varTy = getType(varDecl.type(), /*indirectFn=*/true);
1965+
for (auto [varDeclPtr, varTy, initExpr] :
1966+
llvm::zip(varDeclStmt.declarations(), varTys, initExprs)) {
18891967
mlir::Type allocTy = mlir::sol::PointerType::get(
18901968
b.getContext(), varTy, mlir::sol::DataLocation::Stack);
18911969

18921970
auto addr = b.create<mlir::sol::AllocaOp>(loc, allocTy);
1893-
trackLocalVarAddr(varDecl, addr);
1971+
trackLocalVarAddr(*varDeclPtr, addr);
18941972
if (initExpr)
1895-
b.create<mlir::sol::StoreOp>(loc, genCast(initExpr, varTy), addr);
1973+
b.create<mlir::sol::StoreOp>(loc, initExpr, addr);
18961974
else
18971975
genZeroedVal(addr);
18981976
}
@@ -1923,22 +2001,17 @@ void SolidityToMLIRPass::lower(PlaceholderStatement const &placeholder) {
19232001
}
19242002

19252003
void SolidityToMLIRPass::lower(Return const &ret) {
1926-
TypePointers fnResTys;
2004+
mlir::SmallVector<mlir::Type> fnResTys;
19272005
for (ASTPointer<VariableDeclaration> const &retParam :
19282006
ret.annotation().function->returnParameters())
1929-
fnResTys.push_back(retParam->type());
2007+
fnResTys.push_back(getType(retParam->type()));
19302008

19312009
Expression const *astExpr = ret.expression();
1932-
if (astExpr) {
1933-
mlir::SmallVector<mlir::Value> exprs = genRValExprs(*astExpr);
1934-
mlir::SmallVector<mlir::Value> castedExprs;
1935-
for (auto [expr, dstTy] : llvm::zip(exprs, fnResTys)) {
1936-
castedExprs.push_back(genCast(expr, getType(dstTy)));
1937-
}
1938-
b.create<mlir::sol::ReturnOp>(getLoc(ret), castedExprs);
1939-
} else {
2010+
if (astExpr)
2011+
b.create<mlir::sol::ReturnOp>(getLoc(ret),
2012+
genRValExprs(*astExpr, fnResTys));
2013+
else
19402014
b.create<mlir::sol::ReturnOp>(getLoc(ret));
1941-
}
19422015
b.setInsertionPointToStart(b.createBlock(b.getBlock()->getParent()));
19432016
}
19442017

test/libsolidity/semanticTests/mlir/arith.sol

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ contract C {
7777
return (a & b, a | b, a ^ b);
7878
}
7979

80+
function bnot(uint a) public returns (uint) {
81+
return ~a;
82+
}
83+
84+
function bnotBytes(bytes4 a) public pure returns (bytes4) {
85+
return ~a;
86+
}
87+
8088
function addmod8(uint8 a, uint8 b, uint8 m) public returns (uint256) {
8189
return addmod(a, b, m);
8290
}
@@ -122,6 +130,8 @@ contract C {
122130
// shr8(uint8,uint256): 1, 8 -> 0
123131
// bit(int256,int256): 6, 3 -> 2, 7, 5
124132
// bit8(int8,int8): 6, 3 -> 2, 7, 5
133+
// bnot(uint256): 0 -> 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
134+
// bnotBytes(bytes4): left(0x12345678) -> left(0xedcba987)
125135
// addmod8(uint8,uint8,uint8): 42, 25, 0 -> FAILURE, hex"4e487b71", 0x12
126136
// addmod8(uint8,uint8,uint8): 42, 25, 24 -> 19
127137
// addmod8(uint8,uint8,uint8): 200, 100, 7 -> 6

test/libsolidity/semanticTests/mlir/cf.sol

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
contract C {
2-
function f(uint a, uint b) public returns (uint) {
2+
uint public counter;
3+
4+
function for_brk(uint a, uint b) public returns (uint) {
35
uint r = 1;
46
for (uint i = 0; i < a; ++i) {
57
if (i == b)
@@ -9,7 +11,7 @@ contract C {
911
return r;
1012
}
1113

12-
function g(uint a) public returns (uint) {
14+
function while_cont(uint a) public returns (uint) {
1315
uint r = 1;
1416
do {
1517
r = 2;
@@ -26,10 +28,48 @@ contract C {
2628
}
2729
return r;
2830
}
31+
32+
function tern(bool c, uint a, uint b) public pure returns (uint) {
33+
return c ? a : b;
34+
}
35+
36+
function tern_bytes(bool c, bytes4 a, bytes4 b) public pure returns (bytes4) {
37+
return c ? a : b;
38+
}
39+
40+
function tern_cast(bool c, uint8 a, uint256 b) public pure returns (uint256) {
41+
return c ? a : b;
42+
}
43+
44+
function tern_short_circuit(bool c) public returns (uint, uint) {
45+
counter = 0;
46+
uint result = c ? ++counter : ++counter;
47+
return (result, counter);
48+
}
49+
50+
function tern_tuple(bool c, uint a, uint b, uint x, uint y) public pure returns (uint, uint) {
51+
return c ? (a, b) : (x, y);
52+
}
53+
54+
function tern_const(bool c) public pure returns (uint) {
55+
return c ? 1 : 2;
56+
}
2957
}
3058

3159
// ====
3260
// compileViaMlir: true
3361
// ----
34-
// f(uint256,uint256): 20, 10 -> 1024
35-
// g(uint256): 10 -> 22
62+
// for_brk(uint256,uint256): 20, 10 -> 1024
63+
// while_cont(uint256): 10 -> 22
64+
// tern(bool,uint256,uint256): true, 10, 20 -> 10
65+
// tern(bool,uint256,uint256): false, 10, 20 -> 20
66+
// tern_bytes(bool,bytes4,bytes4): true, left(0x12345678), left(0xaabbccdd) -> left(0x12345678)
67+
// tern_bytes(bool,bytes4,bytes4): false, left(0x12345678), left(0xaabbccdd) -> left(0xaabbccdd)
68+
// tern_cast(bool,uint8,uint256): true, 42, 1000 -> 42
69+
// tern_cast(bool,uint8,uint256): false, 42, 1000 -> 1000
70+
// tern_short_circuit(bool): true -> 1, 1
71+
// tern_short_circuit(bool): false -> 1, 1
72+
// tern_tuple(bool,uint256,uint256,uint256,uint256): true, 10, 20, 30, 40 -> 10, 20
73+
// tern_tuple(bool,uint256,uint256,uint256,uint256): false, 10, 20, 30, 40 -> 30, 40
74+
// tern_const(bool): true -> 1
75+
// tern_const(bool): false -> 2

test/libsolidity/semanticTests/mlir/logical.sol

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ contract C {
2323
x = 0;
2424
return (a || s(b), x);
2525
}
26+
27+
function j(bool a) public pure returns (bool) {
28+
return !a;
29+
}
2630
}
2731

2832
// ====
@@ -42,3 +46,5 @@ contract C {
4246
// i(bool,bool): 1, 0 -> true, 0
4347
// i(bool,bool): 0, 1 -> true, 1
4448
// i(bool,bool): 0, 0 -> false, 1
49+
// j(bool): 0 -> true
50+
// j(bool): 1 -> false

0 commit comments

Comments
 (0)