Skip to content

Commit 03c05fb

Browse files
authored
[mlir] Support named function return parameters (#72)
1 parent da44b02 commit 03c05fb

File tree

5 files changed

+725
-11
lines changed

5 files changed

+725
-11
lines changed

libsolidity/codegen/mlir/SolidityToMLIR.cpp

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,33 @@ void SolidityToMLIRPass::genZeroedVal(mlir::sol::AllocaOp addr) {
610610
// TODO: Do we need to zero-init here?
611611
val = b.create<mlir::sol::MallocOp>(loc, stringTy, /*zeroInit=*/false,
612612
/*size=*/mlir::Value{});
613+
} else if (auto addressTy =
614+
mlir::dyn_cast<mlir::sol::AddressType>(pointeeTy)) {
615+
auto uint160Ty = b.getIntegerType(160, /*isSigned=*/false);
616+
auto zero = b.create<mlir::sol::ConstantOp>(
617+
loc, b.getIntegerAttr(uint160Ty, llvm::APInt(160, 0)));
618+
val = genCast(zero, addressTy);
619+
} else if (auto bytesTy = mlir::dyn_cast<mlir::sol::BytesType>(pointeeTy)) {
620+
unsigned width = bytesTy.getSize() * 8;
621+
auto uintTy = b.getIntegerType(width, /*isSigned=*/false);
622+
auto zero = b.create<mlir::sol::ConstantOp>(
623+
loc, b.getIntegerAttr(uintTy, llvm::APInt(width, 0)));
624+
val = genCast(zero, bytesTy);
625+
} else if (auto contractTy =
626+
mlir::dyn_cast<mlir::sol::ContractType>(pointeeTy)) {
627+
auto uint160Ty = b.getIntegerType(160, /*isSigned=*/false);
628+
auto zero = b.create<mlir::sol::ConstantOp>(
629+
loc, b.getIntegerAttr(uint160Ty, llvm::APInt(160, 0)));
630+
auto addrTy =
631+
mlir::sol::AddressType::get(b.getContext(), contractTy.getPayable());
632+
val = genCast(genCast(zero, addrTy), contractTy);
633+
} else if (auto enumTy = mlir::dyn_cast<mlir::sol::EnumType>(pointeeTy)) {
634+
auto ui256Ty = b.getIntegerType(256, /*isSigned=*/false);
635+
auto zero = b.create<mlir::sol::ConstantOp>(
636+
loc, b.getIntegerAttr(ui256Ty, llvm::APInt(256, 0)));
637+
val = genCast(zero, enumTy);
613638
}
639+
614640
assert(val);
615641

616642
b.create<mlir::sol::StoreOp>(loc, val, addr);
@@ -2488,6 +2514,20 @@ mlir::sol::FuncOp SolidityToMLIRPass::lower(FunctionDefinition const &fn) {
24882514
b.create<mlir::sol::StoreOp>(inpLoc, arg, addr);
24892515
}
24902516

2517+
// Allocate and zero-initialize named return parameters so they can be
2518+
// loaded at implicit-return sites.
2519+
for (const auto &param : fn.returnParameters()) {
2520+
if (param->name().empty())
2521+
continue;
2522+
mlir::Location paramLoc = getLoc(*param);
2523+
mlir::Type paramTy = getType(param->annotation().type);
2524+
auto addr = b.create<mlir::sol::AllocaOp>(
2525+
paramLoc, mlir::sol::PointerType::get(b.getContext(), paramTy,
2526+
mlir::sol::DataLocation::Stack));
2527+
trackLocalVarAddr(*param, addr);
2528+
genZeroedVal(addr);
2529+
}
2530+
24912531
// Generate the call to the next ctor (if any) if `fn` is a ctor.
24922532
if (fn.isConstructor()) {
24932533
// Get base contract of `currContract`
@@ -2523,18 +2563,50 @@ mlir::sol::FuncOp SolidityToMLIRPass::lower(FunctionDefinition const &fn) {
25232563
// Lower the body.
25242564
lower(fn.body());
25252565

2526-
// Return stmt lowering generates an empty block that might be empty at this
2527-
// stage.
2528-
if (outTys.empty())
2529-
b.create<mlir::sol::ReturnOp>(getLoc(fn));
25302566
mlir::Block *currBlk = b.getBlock();
2531-
if (currBlk->empty()) {
2567+
assert(currBlk && "insertion point lost after lowering function body");
2568+
if (!currBlk->empty() &&
2569+
currBlk->back().hasTrait<mlir::OpTrait::IsTerminator>()) {
2570+
b.setInsertionPointAfter(op);
2571+
return op;
2572+
}
2573+
2574+
// A return statement lowers into a new trailing block for post-return
2575+
// code. If nothing follows the return, that block is empty and can be
2576+
// dropped.
2577+
if (currBlk->empty() && op.getBody().getBlocks().size() > 1) {
25322578
op.getBody().back().erase();
2533-
} else if (!currBlk->back().hasTrait<mlir::OpTrait::IsTerminator>()) {
2534-
// FIXME: Generate "default" return statement for non-empty return types
2535-
// (zero/zero-pointer).
2536-
llvm_unreachable("NYI");
2579+
b.setInsertionPointAfter(op);
2580+
return op;
2581+
}
2582+
2583+
// Handle void function.
2584+
if (fn.returnParameters().empty()) {
2585+
b.create<mlir::sol::ReturnOp>(fnLoc);
2586+
b.setInsertionPointAfter(op);
2587+
return op;
2588+
}
2589+
2590+
mlir::SmallVector<mlir::Value> retVals;
2591+
// Unnamed params are lowered as zero-initialized temporaries. Named
2592+
// params are loaded from their corresponding local variables.
2593+
for (const auto &param : fn.returnParameters()) {
2594+
mlir::Type paramTy = getType(param->annotation().type);
2595+
if (param->name().empty()) {
2596+
// Unnamed return param: allocate and zero-initialize a temporary, then
2597+
// load it as the return value.
2598+
auto addr = b.create<mlir::sol::AllocaOp>(
2599+
fnLoc, mlir::sol::PointerType::get(b.getContext(), paramTy,
2600+
mlir::sol::DataLocation::Stack));
2601+
genZeroedVal(addr);
2602+
retVals.push_back(b.create<mlir::sol::LoadOp>(fnLoc, addr));
2603+
} else {
2604+
// Named return param: load from its local variable address.
2605+
retVals.push_back(
2606+
b.create<mlir::sol::LoadOp>(fnLoc, getLocalVarAddr(*param)));
2607+
}
25372608
}
2609+
b.create<mlir::sol::ReturnOp>(fnLoc, retVals);
25382610

25392611
b.setInsertionPointAfter(op);
25402612
return op;
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
contract CC {}
2+
3+
contract Test {
4+
uint8[33] a;
5+
enum Color { Red, Green, Blue }
6+
7+
function f_basic() public returns (uint a) { a = 42; }
8+
9+
function f_default() public returns (uint a) {}
10+
11+
function f_multi() public returns (uint a, uint b) { a = 1; b = 2; }
12+
13+
function f_partial() public returns (uint a, uint b) { a = 7; }
14+
15+
function f_bool() public returns (bool ok) { ok = true; }
16+
17+
function f_cond(bool flag) public returns (uint result) {
18+
if (flag) { result = 10; } else { result = 20; }
19+
}
20+
21+
function f_loop(uint n) public returns (uint sum) {
22+
for (uint i = 0; i < n; i++) { sum += i; }
23+
}
24+
25+
function f_call() public returns (uint a) { a = helper(); }
26+
function helper() internal returns (uint) { return 5; }
27+
28+
function f_explicit() public returns (uint) { return 99; }
29+
30+
function f_noname(uint b) public returns (uint) { }
31+
32+
function f_int_default() public returns (int256 a) {}
33+
function f_int_neg() public returns (int256 a) { a = -5; }
34+
35+
function f_contract_default() public returns (CC c) {}
36+
37+
// Named enum return — falls off end; c must default to first enumerator (Red = 0).
38+
function f_enum_default() public returns (Color c) {}
39+
40+
// Named enum in a mixed tuple — c defaults to Red (0), unnamed slot is explicit.
41+
function f_enum_mixed_default() public returns (Color c, uint) {
42+
return (c, 5);
43+
}
44+
45+
function f_addr_default() public returns (address a) {}
46+
function f_addr_set() public returns (address a) { a = address(1); }
47+
48+
function f_bytes1_default() public returns (bytes1 a) {}
49+
function f_bytes1_set() public returns (bytes1 a) { a = 0x41; }
50+
51+
function f_bytes4_default() public returns (bytes4 a) {}
52+
53+
function f_bytes32_default() public returns (bytes32 a) {}
54+
55+
function f_early(bool flag) public returns (uint r) {
56+
r = 1;
57+
if (flag) return r;
58+
r = 2;
59+
}
60+
61+
function f_early_multi(bool flag) public returns (uint a, uint b) {
62+
a = 10; b = 20;
63+
if (flag) return (a, b);
64+
a = 1; b = 2;
65+
}
66+
67+
function f_bytes_default() public returns (bytes memory a) {}
68+
function f_bytes_set() public returns (bytes memory a) { a = "hi"; }
69+
70+
function f_str_default() public returns (string memory a) {}
71+
function f_str_set() public returns (string memory a) { a = "hello"; }
72+
73+
function f_arr_default() public returns (uint[] memory a) {}
74+
function f_arr_set() public returns (uint[] memory a) {
75+
a = new uint[](2);
76+
a[0] = 1; a[1] = 2;
77+
}
78+
79+
function f_fixed_arr_default() public returns (uint[2] memory a) {}
80+
function f_fixed_arr_set() public returns (uint[2] memory a) {
81+
a[0] = 3; a[1] = 4;
82+
}
83+
84+
function f_int_u() public returns (int256) { return -5; }
85+
function f_addr_u() public returns (address) { return address(1); }
86+
function f_bytes1_u() public returns (bytes1) { return 0x41; }
87+
88+
function f_bytes_u() public returns (bytes memory) { return "hi"; }
89+
function f_str_u() public returns (string memory) { return "hello"; }
90+
function f_arr_u() public returns (uint[] memory) {
91+
uint[] memory a = new uint[](2);
92+
a[0] = 5; a[1] = 6;
93+
return a;
94+
}
95+
96+
function f_multi_u() public returns (uint, bool) { return (42, true); }
97+
98+
// Mixed named and unnamed return parameters.
99+
// a is named (assigned, falls off end), b is unnamed (explicit return value).
100+
function f_mixed(uint x) public returns (uint a, uint) {
101+
a = x + 2;
102+
return (a, x + 1);
103+
}
104+
105+
// Mixed with default: a is named (zero default), b is unnamed (explicit).
106+
function f_mixed_default() public returns (uint a, uint) {
107+
return (a, 99);
108+
}
109+
110+
function f_arr() public returns (uint8, uint8, uint8) {
111+
a[0] = 2;
112+
a[16] = 3;
113+
a[32] = 4;
114+
return (a[0], a[16], a[32]);
115+
}
116+
}
117+
118+
// ====
119+
// compileViaMlir: true
120+
// ----
121+
// f_basic() -> 42
122+
// f_default() -> 0
123+
// f_noname(uint256): 7 -> 0
124+
// f_multi() -> 1, 2
125+
// f_partial() -> 7, 0
126+
// f_bool() -> true
127+
// f_cond(bool): true -> 10
128+
// f_cond(bool): false -> 20
129+
// f_loop(uint256): 5 -> 10
130+
// f_call() -> 5
131+
// f_explicit() -> 99
132+
// f_int_default() -> 0
133+
// f_int_neg() -> -5
134+
// f_contract_default() -> 0
135+
// f_enum_default() -> 0
136+
// f_enum_mixed_default() -> 0, 5
137+
// f_addr_default() -> 0
138+
// f_addr_set() -> 1
139+
// f_bytes1_default() -> 0
140+
// f_bytes1_set() -> left(0x41)
141+
// f_bytes4_default() -> 0
142+
// f_bytes32_default() -> 0
143+
// f_early(bool): true -> 1
144+
// f_early(bool): false -> 2
145+
// f_early_multi(bool): true -> 10, 20
146+
// f_early_multi(bool): false -> 1, 2
147+
// f_bytes_default() -> 32, 0
148+
// f_bytes_set() -> 32, 2, "hi"
149+
// f_str_default() -> 32, 0
150+
// f_str_set() -> 32, 5, "hello"
151+
// f_arr_default() -> 32, 0
152+
// f_arr_set() -> 32, 2, 1, 2
153+
// f_fixed_arr_default() -> 0, 0
154+
// f_fixed_arr_set() -> 3, 4
155+
// f_int_u() -> -5
156+
// f_addr_u() -> 1
157+
// f_bytes1_u() -> left(0x41)
158+
// f_bytes_u() -> 32, 2, "hi"
159+
// f_str_u() -> 32, 5, "hello"
160+
// f_arr_u() -> 32, 2, 5, 6
161+
// f_multi_u() -> 42, true
162+
// f_mixed(uint256): 5 -> 7, 6
163+
// f_mixed(uint256): 0 -> 2, 1
164+
// f_mixed_default() -> 0, 99
165+
// f_arr() -> 2, 3, 4

0 commit comments

Comments
 (0)