Skip to content

Commit 3161354

Browse files
committed
[mlir] Lower logical not, bitwise not and ternary expressions
1 parent a84c6d7 commit 3161354

File tree

11 files changed

+4037
-3710
lines changed

11 files changed

+4037
-3710
lines changed

libsolidity/codegen/mlir/SolidityToMLIR.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,9 @@ class SolidityToMLIRPass {
335335
/// Returns the mlir expression for the binary operation.
336336
mlir::Value genExpr(BinaryOperation const &binOp);
337337

338+
/// Returns the mlir expressions for the conditional (ternary) operation.
339+
mlir::SmallVector<mlir::Value> genExprs(Conditional const &cond);
340+
338341
/// Returns the mlir expression for the call.
339342
mlir::SmallVector<mlir::Value> genExprs(FunctionCall const &call);
340343

@@ -770,6 +773,31 @@ mlir::Value SolidityToMLIRPass::genExpr(UnaryOperation const &unaryOp) {
770773
b.create<mlir::sol::ConstantOp>(loc, b.getIntegerAttr(mlirTy, 0));
771774
return genBinExpr(Token::Sub, zero, expr, loc);
772775
}
776+
// Logical not
777+
case Token::Not: {
778+
mlir::Value expr = genRValExpr(unaryOp.subExpression());
779+
mlir::Value zero = b.create<mlir::sol::ConstantOp>(
780+
loc, b.getIntegerAttr(expr.getType(), 0));
781+
return b.create<mlir::sol::CmpOp>(loc, mlir::sol::CmpPredicate::eq, expr,
782+
zero);
783+
}
784+
// Bitwise not (~x == x ^ -1)
785+
case Token::BitNot: {
786+
mlir::Value expr = genRValExpr(unaryOp.subExpression());
787+
788+
// Convert bytes to int if needed.
789+
if (auto bytesTy = mlir::dyn_cast<mlir::sol::BytesType>(expr.getType())) {
790+
mlir::Type intTy =
791+
b.getIntegerType(bytesTy.getSize() * 8, /*isSigned=*/false);
792+
expr = genCast(expr, intTy);
793+
}
794+
795+
auto intTy = mlir::cast<mlir::IntegerType>(expr.getType());
796+
mlir::Value allOnes = b.create<mlir::sol::ConstantOp>(
797+
loc,
798+
b.getIntegerAttr(intTy, llvm::APInt::getAllOnes(intTy.getWidth())));
799+
return b.create<mlir::sol::XorOp>(loc, expr, allOnes);
800+
}
773801
default:
774802
break;
775803
}
@@ -831,6 +859,39 @@ mlir::Value SolidityToMLIRPass::genExpr(BinaryOperation const &binOp) {
831859
return genBinExpr(binOp.getOperator(), lhs, rhs, loc);
832860
}
833861

862+
mlir::SmallVector<mlir::Value>
863+
SolidityToMLIRPass::genExprs(Conditional const &cond) {
864+
mlir::Location loc = getLoc(cond);
865+
mlir::Value condVal = genRValExpr(cond.condition());
866+
867+
// Get result types - could be single type or tuple.
868+
mlir::SmallVector<mlir::Type> resTys;
869+
if (TupleType const *tupleTy =
870+
dynamic_cast<TupleType const *>(cond.annotation().type)) {
871+
for (const Type *astTy : tupleTy->components())
872+
resTys.push_back(getType(astTy));
873+
} else {
874+
resTys.push_back(getType(cond.annotation().type));
875+
}
876+
877+
auto ifOp =
878+
b.create<mlir::scf::IfOp>(loc, resTys, condVal, /*withElse=*/true);
879+
mlir::OpBuilder::InsertionGuard guard(b);
880+
881+
// True branch
882+
b.setInsertionPointToStart(&ifOp.getThenRegion().front());
883+
mlir::SmallVector<mlir::Value> trueVals = genRValExprs(cond.trueExpression());
884+
b.create<mlir::scf::YieldOp>(loc, trueVals);
885+
886+
// False branch
887+
b.setInsertionPointToStart(&ifOp.getElseRegion().front());
888+
mlir::SmallVector<mlir::Value> falseVals =
889+
genRValExprs(cond.falseExpression());
890+
b.create<mlir::scf::YieldOp>(loc, falseVals);
891+
892+
return ifOp.getResults();
893+
}
894+
834895
mlir::Value SolidityToMLIRPass::genExpr(IndexAccess const &idxAcc) {
835896
mlir::Location loc = getLoc(idxAcc);
836897

@@ -1776,6 +1837,13 @@ mlir::Value SolidityToMLIRPass::genLValExpr(Expression const &expr) {
17761837
return {};
17771838
}
17781839

1840+
// Conditional (ternary operator)
1841+
if (const auto *cond = dynamic_cast<Conditional const *>(&expr)) {
1842+
mlir::SmallVector<mlir::Value> exprs = genExprs(*cond);
1843+
assert(exprs.size() == 1);
1844+
return exprs[0];
1845+
}
1846+
17791847
llvm_unreachable("NYI");
17801848
}
17811849

@@ -1789,6 +1857,10 @@ SolidityToMLIRPass::genLValExprs(Expression const &expr) {
17891857
if (const auto *call = dynamic_cast<FunctionCall const *>(&expr))
17901858
return genExprs(*call);
17911859

1860+
// Conditional (ternary)
1861+
if (const auto *cond = dynamic_cast<Conditional const *>(&expr))
1862+
return genExprs(*cond);
1863+
17921864
mlir::SmallVector<mlir::Value, 1> vals;
17931865
vals.push_back(genLValExpr(expr));
17941866
return vals;

test/libsolidity/semanticTests/mlir/arith.sol

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ contract C {
7777
return (a & b, a | b, a ^ b);
7878
}
7979

80+
function bnot(uint a) public returns (uint) {
81+
return ~a;
82+
}
83+
84+
function bnotBytes(bytes4 a) public pure returns (bytes4) {
85+
return ~a;
86+
}
87+
8088
function addmod8(uint8 a, uint8 b, uint8 m) public returns (uint256) {
8189
return addmod(a, b, m);
8290
}
@@ -122,6 +130,8 @@ contract C {
122130
// shr8(uint8,uint256): 1, 8 -> 0
123131
// bit(int256,int256): 6, 3 -> 2, 7, 5
124132
// bit8(int8,int8): 6, 3 -> 2, 7, 5
133+
// bnot(uint256): 0 -> 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
134+
// bnotBytes(bytes4): left(0x12345678) -> left(0xedcba987)
125135
// addmod8(uint8,uint8,uint8): 42, 25, 0 -> FAILURE, hex"4e487b71", 0x12
126136
// addmod8(uint8,uint8,uint8): 42, 25, 24 -> 19
127137
// addmod8(uint8,uint8,uint8): 200, 100, 7 -> 6

test/libsolidity/semanticTests/mlir/cf.sol

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,22 @@ contract C {
2626
}
2727
return r;
2828
}
29+
30+
function h(bool c, uint a, uint b) public pure returns (uint) {
31+
return c ? a : b;
32+
}
33+
34+
function i(bool c, bytes4 a, bytes4 b) public pure returns (bytes4) {
35+
return c ? a : b;
36+
}
2937
}
3038

3139
// ====
3240
// compileViaMlir: true
3341
// ----
3442
// f(uint256,uint256): 20, 10 -> 1024
3543
// g(uint256): 10 -> 22
44+
// h(bool,uint256,uint256): true, 10, 20 -> 10
45+
// h(bool,uint256,uint256): false, 10, 20 -> 20
46+
// i(bool,bytes4,bytes4): true, left(0x12345678), left(0xaabbccdd) -> left(0x12345678)
47+
// i(bool,bytes4,bytes4): false, left(0x12345678), left(0xaabbccdd) -> left(0xaabbccdd)

test/libsolidity/semanticTests/mlir/logical.sol

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ contract C {
2323
x = 0;
2424
return (a || s(b), x);
2525
}
26+
27+
function j(bool a) public pure returns (bool) {
28+
return !a;
29+
}
2630
}
2731

2832
// ====
@@ -42,3 +46,5 @@ contract C {
4246
// i(bool,bool): 1, 0 -> true, 0
4347
// i(bool,bool): 0, 1 -> true, 1
4448
// i(bool,bool): 0, 0 -> false, 1
49+
// j(bool): 0 -> true
50+
// j(bool): 1 -> false

0 commit comments

Comments
 (0)