Skip to content

Commit 0f9626c

Browse files
committed
[mlir] Lower logical not, bitwise not and ternary expressions
1 parent a84c6d7 commit 0f9626c

File tree

17 files changed

+4230
-3895
lines changed

17 files changed

+4230
-3895
lines changed

libsolidity/codegen/mlir/SolidityToMLIR.cpp

Lines changed: 106 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,9 @@ class SolidityToMLIRPass {
335335
/// Returns the mlir expression for the binary operation.
336336
mlir::Value genExpr(BinaryOperation const &binOp);
337337

338+
/// Returns the mlir expressions for the conditional (ternary) operation.
339+
mlir::SmallVector<mlir::Value> genExprs(Conditional const &cond);
340+
338341
/// Returns the mlir expression for the call.
339342
mlir::SmallVector<mlir::Value> genExprs(FunctionCall const &call);
340343

@@ -353,8 +356,10 @@ class SolidityToMLIRPass {
353356
/// to the corresponding mlir type of `resTy`.
354357
mlir::Value genRValExpr(Expression const &expr,
355358
std::optional<mlir::Type> resTy = std::nullopt);
356-
mlir::Value genRValExpr(mlir::Value val, mlir::Location loc);
357-
mlir::SmallVector<mlir::Value> genRValExprs(Expression const &expr);
359+
mlir::Value genRValExpr(mlir::Value val, mlir::Location loc,
360+
std::optional<mlir::Type> resTy = std::nullopt);
361+
mlir::SmallVector<mlir::Value> genRValExprs(Expression const &expr,
362+
mlir::TypeRange resTys = {});
358363

359364
/// Generates an ir that assigns `rhs` to `lhs`.
360365
void genAssign(mlir::Value lhs, mlir::Value rhs, mlir::Location loc);
@@ -770,6 +775,31 @@ mlir::Value SolidityToMLIRPass::genExpr(UnaryOperation const &unaryOp) {
770775
b.create<mlir::sol::ConstantOp>(loc, b.getIntegerAttr(mlirTy, 0));
771776
return genBinExpr(Token::Sub, zero, expr, loc);
772777
}
778+
// Logical not
779+
case Token::Not: {
780+
mlir::Value expr = genRValExpr(unaryOp.subExpression());
781+
mlir::Value zero = b.create<mlir::sol::ConstantOp>(
782+
loc, b.getIntegerAttr(expr.getType(), 0));
783+
return b.create<mlir::sol::CmpOp>(loc, mlir::sol::CmpPredicate::eq, expr,
784+
zero);
785+
}
786+
// Bitwise not (~x == x ^ -1)
787+
case Token::BitNot: {
788+
mlir::Value expr = genRValExpr(unaryOp.subExpression());
789+
790+
// Convert bytes to int if needed.
791+
if (auto bytesTy = mlir::dyn_cast<mlir::sol::BytesType>(expr.getType())) {
792+
mlir::Type intTy =
793+
b.getIntegerType(bytesTy.getSize() * 8, /*isSigned=*/false);
794+
expr = genCast(expr, intTy);
795+
}
796+
797+
auto intTy = mlir::cast<mlir::IntegerType>(expr.getType());
798+
mlir::Value allOnes = b.create<mlir::sol::ConstantOp>(
799+
loc,
800+
b.getIntegerAttr(intTy, llvm::APInt::getAllOnes(intTy.getWidth())));
801+
return b.create<mlir::sol::XorOp>(loc, expr, allOnes);
802+
}
773803
default:
774804
break;
775805
}
@@ -831,6 +861,38 @@ mlir::Value SolidityToMLIRPass::genExpr(BinaryOperation const &binOp) {
831861
return genBinExpr(binOp.getOperator(), lhs, rhs, loc);
832862
}
833863

864+
mlir::SmallVector<mlir::Value>
865+
SolidityToMLIRPass::genExprs(Conditional const &cond) {
866+
mlir::Location loc = getLoc(cond);
867+
mlir::Value condVal = genRValExpr(cond.condition());
868+
869+
// Get result types - could be single type or tuple.
870+
mlir::SmallVector<mlir::Type> resTys;
871+
if (TupleType const *tupleTy =
872+
dynamic_cast<TupleType const *>(cond.annotation().type)) {
873+
for (const Type *astTy : tupleTy->components())
874+
resTys.push_back(getType(astTy));
875+
} else {
876+
resTys.push_back(getType(cond.annotation().type));
877+
}
878+
879+
auto ifOp =
880+
b.create<mlir::scf::IfOp>(loc, resTys, condVal, /*withElse=*/true);
881+
mlir::OpBuilder::InsertionGuard guard(b);
882+
883+
// True branch
884+
b.setInsertionPointToStart(&ifOp.getThenRegion().front());
885+
b.create<mlir::scf::YieldOp>(loc,
886+
genRValExprs(cond.trueExpression(), resTys));
887+
888+
// False branch
889+
b.setInsertionPointToStart(&ifOp.getElseRegion().front());
890+
b.create<mlir::scf::YieldOp>(loc,
891+
genRValExprs(cond.falseExpression(), resTys));
892+
893+
return ifOp.getResults();
894+
}
895+
834896
mlir::Value SolidityToMLIRPass::genExpr(IndexAccess const &idxAcc) {
835897
mlir::Location loc = getLoc(idxAcc);
836898

@@ -1776,6 +1838,13 @@ mlir::Value SolidityToMLIRPass::genLValExpr(Expression const &expr) {
17761838
return {};
17771839
}
17781840

1841+
// Conditional (ternary operator)
1842+
if (const auto *cond = dynamic_cast<Conditional const *>(&expr)) {
1843+
mlir::SmallVector<mlir::Value> exprs = genExprs(*cond);
1844+
assert(exprs.size() == 1);
1845+
return exprs[0];
1846+
}
1847+
17791848
llvm_unreachable("NYI");
17801849
}
17811850

@@ -1789,38 +1858,46 @@ SolidityToMLIRPass::genLValExprs(Expression const &expr) {
17891858
if (const auto *call = dynamic_cast<FunctionCall const *>(&expr))
17901859
return genExprs(*call);
17911860

1861+
// Conditional (ternary)
1862+
if (const auto *cond = dynamic_cast<Conditional const *>(&expr))
1863+
return genExprs(*cond);
1864+
17921865
mlir::SmallVector<mlir::Value, 1> vals;
17931866
vals.push_back(genLValExpr(expr));
17941867
return vals;
17951868
}
17961869

1797-
mlir::Value SolidityToMLIRPass::genRValExpr(mlir::Value val,
1798-
mlir::Location loc) {
1870+
mlir::Value SolidityToMLIRPass::genRValExpr(mlir::Value val, mlir::Location loc,
1871+
std::optional<mlir::Type> resTy) {
17991872
if (mlir::isa<mlir::sol::PointerType>(val.getType()))
1800-
return b.create<mlir::sol::LoadOp>(loc, val);
1873+
val = b.create<mlir::sol::LoadOp>(loc, val);
1874+
if (resTy)
1875+
return genCast(val, *resTy);
18011876
return val;
18021877
}
18031878

18041879
mlir::Value SolidityToMLIRPass::genRValExpr(Expression const &expr,
18051880
std::optional<mlir::Type> resTy) {
18061881
mlir::Value lVal = genLValExpr(expr);
18071882
assert(lVal);
1808-
1809-
mlir::Value val = genRValExpr(lVal, getLoc(expr));
1810-
// Generate cast (optional).
1811-
if (resTy)
1812-
return genCast(val, *resTy);
1813-
return val;
1883+
return genRValExpr(lVal, getLoc(expr), resTy);
18141884
}
18151885

18161886
mlir::SmallVector<mlir::Value>
1817-
SolidityToMLIRPass::genRValExprs(Expression const &expr) {
1887+
SolidityToMLIRPass::genRValExprs(Expression const &expr,
1888+
mlir::TypeRange resTys) {
18181889
mlir::SmallVector<mlir::Value> lVals = genLValExprs(expr);
18191890
assert(!lVals.empty());
1891+
assert(resTys.empty() || lVals.size() == resTys.size());
18201892

18211893
mlir::SmallVector<mlir::Value, 2> rVals;
1822-
for (mlir::Value lVal : lVals)
1823-
rVals.push_back(genRValExpr(lVal, getLoc(expr)));
1894+
if (resTys.empty()) {
1895+
for (mlir::Value lVal : lVals)
1896+
rVals.push_back(genRValExpr(lVal, getLoc(expr)));
1897+
} else {
1898+
for (auto [lVal, resTy] : llvm::zip(lVals, resTys))
1899+
rVals.push_back(genRValExpr(lVal, getLoc(expr), resTy));
1900+
}
18241901

18251902
return rVals;
18261903
}
@@ -1833,22 +1910,23 @@ void SolidityToMLIRPass::lower(
18331910
VariableDeclarationStatement const &varDeclStmt) {
18341911
mlir::Location loc = getLoc(varDeclStmt);
18351912

1913+
mlir::SmallVector<mlir::Type> varTys;
1914+
for (auto const &varDeclPtr : varDeclStmt.declarations())
1915+
varTys.push_back(getType(varDeclPtr->type(), /*indirectFn=*/true));
1916+
18361917
mlir::SmallVector<mlir::Value> initExprs(varDeclStmt.declarations().size());
18371918
if (Expression const *initExpr = varDeclStmt.initialValue())
1838-
initExprs = genRValExprs(*initExpr);
1839-
1840-
for (auto [varDeclPtr, initExpr] :
1841-
llvm::zip(varDeclStmt.declarations(), initExprs)) {
1842-
VariableDeclaration const &varDecl = *varDeclPtr;
1919+
initExprs = genRValExprs(*initExpr, varTys);
18431920

1844-
mlir::Type varTy = getType(varDecl.type(), /*indirectFn=*/true);
1921+
for (auto [varDeclPtr, varTy, initExpr] :
1922+
llvm::zip(varDeclStmt.declarations(), varTys, initExprs)) {
18451923
mlir::Type allocTy = mlir::sol::PointerType::get(
18461924
b.getContext(), varTy, mlir::sol::DataLocation::Stack);
18471925

18481926
auto addr = b.create<mlir::sol::AllocaOp>(loc, allocTy);
1849-
trackLocalVarAddr(varDecl, addr);
1927+
trackLocalVarAddr(*varDeclPtr, addr);
18501928
if (initExpr)
1851-
b.create<mlir::sol::StoreOp>(loc, genCast(initExpr, varTy), addr);
1929+
b.create<mlir::sol::StoreOp>(loc, initExpr, addr);
18521930
else
18531931
genZeroedVal(addr);
18541932
}
@@ -1879,22 +1957,17 @@ void SolidityToMLIRPass::lower(PlaceholderStatement const &placeholder) {
18791957
}
18801958

18811959
void SolidityToMLIRPass::lower(Return const &ret) {
1882-
TypePointers fnResTys;
1960+
mlir::SmallVector<mlir::Type> fnResTys;
18831961
for (ASTPointer<VariableDeclaration> const &retParam :
18841962
ret.annotation().function->returnParameters())
1885-
fnResTys.push_back(retParam->type());
1963+
fnResTys.push_back(getType(retParam->type()));
18861964

18871965
Expression const *astExpr = ret.expression();
1888-
if (astExpr) {
1889-
mlir::SmallVector<mlir::Value> exprs = genRValExprs(*astExpr);
1890-
mlir::SmallVector<mlir::Value> castedExprs;
1891-
for (auto [expr, dstTy] : llvm::zip(exprs, fnResTys)) {
1892-
castedExprs.push_back(genCast(expr, getType(dstTy)));
1893-
}
1894-
b.create<mlir::sol::ReturnOp>(getLoc(ret), castedExprs);
1895-
} else {
1966+
if (astExpr)
1967+
b.create<mlir::sol::ReturnOp>(getLoc(ret),
1968+
genRValExprs(*astExpr, fnResTys));
1969+
else
18961970
b.create<mlir::sol::ReturnOp>(getLoc(ret));
1897-
}
18981971
b.setInsertionPointToStart(b.createBlock(b.getBlock()->getParent()));
18991972
}
19001973

test/libsolidity/semanticTests/mlir/arith.sol

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ contract C {
7777
return (a & b, a | b, a ^ b);
7878
}
7979

80+
function bnot(uint a) public returns (uint) {
81+
return ~a;
82+
}
83+
84+
function bnotBytes(bytes4 a) public pure returns (bytes4) {
85+
return ~a;
86+
}
87+
8088
function addmod8(uint8 a, uint8 b, uint8 m) public returns (uint256) {
8189
return addmod(a, b, m);
8290
}
@@ -122,6 +130,8 @@ contract C {
122130
// shr8(uint8,uint256): 1, 8 -> 0
123131
// bit(int256,int256): 6, 3 -> 2, 7, 5
124132
// bit8(int8,int8): 6, 3 -> 2, 7, 5
133+
// bnot(uint256): 0 -> 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
134+
// bnotBytes(bytes4): left(0x12345678) -> left(0xedcba987)
125135
// addmod8(uint8,uint8,uint8): 42, 25, 0 -> FAILURE, hex"4e487b71", 0x12
126136
// addmod8(uint8,uint8,uint8): 42, 25, 24 -> 19
127137
// addmod8(uint8,uint8,uint8): 200, 100, 7 -> 6

test/libsolidity/semanticTests/mlir/cf.sol

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,28 @@ contract C {
2626
}
2727
return r;
2828
}
29+
30+
function h(bool c, uint a, uint b) public pure returns (uint) {
31+
return c ? a : b;
32+
}
33+
34+
function i(bool c, bytes4 a, bytes4 b) public pure returns (bytes4) {
35+
return c ? a : b;
36+
}
37+
38+
function j(bool c, uint8 a, uint256 b) public pure returns (uint256) {
39+
return c ? a : b;
40+
}
2941
}
3042

3143
// ====
3244
// compileViaMlir: true
3345
// ----
3446
// f(uint256,uint256): 20, 10 -> 1024
3547
// g(uint256): 10 -> 22
48+
// h(bool,uint256,uint256): true, 10, 20 -> 10
49+
// h(bool,uint256,uint256): false, 10, 20 -> 20
50+
// i(bool,bytes4,bytes4): true, left(0x12345678), left(0xaabbccdd) -> left(0x12345678)
51+
// i(bool,bytes4,bytes4): false, left(0x12345678), left(0xaabbccdd) -> left(0xaabbccdd)
52+
// j(bool,uint8,uint256): true, 42, 1000 -> 42
53+
// j(bool,uint8,uint256): false, 42, 1000 -> 1000

test/libsolidity/semanticTests/mlir/logical.sol

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ contract C {
2323
x = 0;
2424
return (a || s(b), x);
2525
}
26+
27+
function j(bool a) public pure returns (bool) {
28+
return !a;
29+
}
2630
}
2731

2832
// ====
@@ -42,3 +46,5 @@ contract C {
4246
// i(bool,bool): 1, 0 -> true, 0
4347
// i(bool,bool): 0, 1 -> true, 1
4448
// i(bool,bool): 0, 0 -> false, 1
49+
// j(bool): 0 -> true
50+
// j(bool): 1 -> false

0 commit comments

Comments
 (0)