Skip to content

Commit 2596e32

Browse files
committed
[mlir] Impl address(<contract>).code lowering (incl extcodeopy)
1 parent 1f36ed6 commit 2596e32

File tree

13 files changed

+471
-18
lines changed

13 files changed

+471
-18
lines changed

libsolidity/codegen/mlir/Sol/SolOps.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,17 @@ void EmitOp::print(OpAsmPrinter &p) {
290290
p << " : " << getArgs().getTypes();
291291
}
292292

293+
//===----------------------------------------------------------------------===//
294+
// CodeOp
295+
//===----------------------------------------------------------------------===//
296+
297+
void CodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
298+
Value contAddr) {
299+
Type resTy =
300+
sol::StringType::get(odsBuilder.getContext(), sol::DataLocation::Memory);
301+
build(odsBuilder, odsState, resTy, contAddr);
302+
}
303+
293304
//===----------------------------------------------------------------------===//
294305
// IfOp
295306
//===----------------------------------------------------------------------===//

libsolidity/codegen/mlir/Sol/SolOps.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,12 +634,27 @@ def Sol_NewOp : Sol_Op<"new", [AttrSizedOperandSegments]> {
634634
Optional<UI256>:$salt,
635635
Variadic<AnyType>:$ctorArgs);
636636
let results = (outs Sol_AddressType:$out);
637+
637638
let assemblyFormat = [{
638639
$objName `value` `=` $val (`salt` `=` $salt^)?
639640
`ctor` `(` ($ctorArgs^ `:` type($ctorArgs))? `)` attr-dict `:` type($out)
640641
}];
641642
}
642643

644+
def Sol_CodeOp : Sol_Op<"code"> {
645+
let arguments = (ins Sol_AddressType:$contAddr);
646+
// FIXME: AnyType -> Sol_StringType causes the asm printer to miss the prefix!
647+
let results = (outs AnyType:$out);
648+
649+
let builders = [
650+
OpBuilder<(ins "Value":$contAddr)>
651+
];
652+
653+
let assemblyFormat = [{
654+
$contAddr attr-dict `:` type($contAddr) `->` type($out)
655+
}];
656+
}
657+
643658
def Sol_ReturnOp : Sol_Op<"return", [Pure, Terminator]> {
644659
// TODO: Verifier should check if a func op is an ancestor. ParentOneOf
645660
// doesn't work for this.

libsolidity/codegen/mlir/SolToStandardPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ struct ConvertSolToStandard
159159
sol::DecodeOp,
160160
sol::ExtCallOp,
161161
sol::NewOp,
162+
sol::CodeOp,
162163
sol::EmitOp,
163164
sol::RequireOp,
164165
sol::ConvCastOp,

libsolidity/codegen/mlir/SolidityToMLIR.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,9 @@ mlir::Type SolidityToMLIRPass::getType(Type const *ty) {
472472

473473
return b.getFunctionType(inTys, outTys);
474474
}
475+
case Type::Category::Contract:
476+
// FIXME: 256 -> 160
477+
return b.getIntegerType(256, /*isSigned=*/false);
475478
default:
476479
break;
477480
}
@@ -752,10 +755,15 @@ mlir::Value SolidityToMLIRPass::genExpr(MemberAccess const &memberAcc) {
752755
/*numBits=*/64, loc);
753756
return b.create<mlir::sol::GepOp>(loc, genRValExpr(memberAcc.expression()),
754757
memberIdx);
755-
break;
756758
}
757759
default:
758760
break;
761+
case Type::Category::Address: {
762+
if (memberName == "code") {
763+
return b.create<mlir::sol::CodeOp>(loc,
764+
genRValExpr(memberAcc.expression()));
765+
}
766+
}
759767
}
760768

761769
llvm_unreachable("NYI");

libsolidity/codegen/mlir/Target/EVM/SolToYul.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,6 +1283,28 @@ struct NewOpLowering : public OpConversionPattern<sol::NewOp> {
12831283
}
12841284
};
12851285

1286+
struct CodeOpLowering : public OpConversionPattern<sol::CodeOp> {
1287+
using OpConversionPattern<sol::CodeOp>::OpConversionPattern;
1288+
1289+
LogicalResult matchAndRewrite(sol::CodeOp op, OpAdaptor adaptor,
1290+
ConversionPatternRewriter &r) const override {
1291+
1292+
Location loc = op.getLoc();
1293+
solidity::mlirgen::BuilderExt bExt(r, loc);
1294+
evm::Builder evmB(r, loc);
1295+
1296+
auto extCodeSize = r.create<yul::ExtCodeSizeOp>(loc, adaptor.getContAddr());
1297+
Value alloc = evmB.genMemAlloc(op.getType(), /*zeroInit=*/false,
1298+
/*initVals=*/{}, extCodeSize);
1299+
auto codeAddr = evmB.genDataAddrPtr(alloc, sol::DataLocation::Memory);
1300+
r.create<yul::ExtCodeCopyOp>(
1301+
loc, adaptor.getContAddr(), /*dstOffset=*/codeAddr,
1302+
/*srcOffset=*/bExt.genI256Const(0), extCodeSize);
1303+
r.replaceOp(op, alloc);
1304+
return success();
1305+
}
1306+
};
1307+
12861308
struct TryOpLowering : public OpConversionPattern<sol::TryOp> {
12871309
using OpConversionPattern<sol::TryOp>::OpConversionPattern;
12881310

@@ -2231,8 +2253,8 @@ void evm::populateAbiPats(mlir::RewritePatternSet &pats,
22312253
}
22322254

22332255
void evm::populateExtCallPat(RewritePatternSet &pats, TypeConverter &tyConv) {
2234-
pats.add<ExtCallOpLowering, TryOpLowering, NewOpLowering>(tyConv,
2235-
pats.getContext());
2256+
pats.add<ExtCallOpLowering, TryOpLowering, NewOpLowering, CodeOpLowering>(
2257+
tyConv, pats.getContext());
22362258
}
22372259

22382260
void evm::populateEmitPat(RewritePatternSet &pats, TypeConverter &tyConv) {

libsolidity/codegen/mlir/Target/EVM/YulToStandard.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,23 @@ struct ExtCodeSizeOpLowering : public OpRewritePattern<yul::ExtCodeSizeOp> {
315315
}
316316
};
317317

318+
struct ExtCodeCopyOpLowering : public OpRewritePattern<yul::ExtCodeCopyOp> {
319+
using OpRewritePattern<yul::ExtCodeCopyOp>::OpRewritePattern;
320+
321+
LogicalResult matchAndRewrite(yul::ExtCodeCopyOp op,
322+
PatternRewriter &r) const override {
323+
evm::Builder evmB(r, op.getLoc());
324+
325+
r.replaceOpWithNewOp<LLVM::IntrCallOp>(
326+
op, llvm::Intrinsic::evm_extcodecopy, /*resTy=*/Type{}, /*ins=*/
327+
ValueRange{op.getAddr(),
328+
/*dst=*/evmB.genHeapPtr(op.getDst()),
329+
/*src=*/evmB.genCodePtr(op.getSrc()), op.getSize()},
330+
"evm.extcodecopy");
331+
return success();
332+
}
333+
};
334+
318335
struct CreateOpLowering : public OpRewritePattern<yul::CreateOp> {
319336
using OpRewritePattern<yul::CreateOp>::OpRewritePattern;
320337

@@ -699,6 +716,7 @@ void evm::populateYulPats(RewritePatternSet &pats) {
699716
CodeSizeOpLowering,
700717
CodeCopyOpLowering,
701718
ExtCodeSizeOpLowering,
719+
ExtCodeCopyOpLowering,
702720
CreateOpLowering,
703721
Create2OpLowering,
704722
MLoadOpLowering,

libsolidity/codegen/mlir/Yul/YulOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,12 @@ def Yul_ExtCodeSizeOp : Yul_Op<"extcodesize", [Pure]> {
254254
let assemblyFormat = "$addr attr-dict";
255255
}
256256

257+
def Yul_ExtCodeCopyOp : Yul_Op<"extcodecopy"> {
258+
let summary = "Represents the `extcodecopy` call in yul";
259+
let arguments = (ins I256:$addr, I256:$dst, I256:$src, I256:$size);
260+
let assemblyFormat = "$addr `,` $dst `,` $src `,` $size attr-dict";
261+
}
262+
257263
def Yul_CreateOp : Yul_Op<"create"> {
258264
let summary = "Represents the `create` call in yul";
259265
let arguments = (ins I256:$val, I256:$addr, I256:$size);

libsolidity/codegen/mlir/YulToMLIR.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ void YulToMLIRPass::populateBuiltinGenMap() {
308308
defSimpleBuiltinGenNoRet<CodeCopyOp>("codecopy");
309309
defSimpleBuiltinGen<CodeSizeOp>("codesize");
310310
defSimpleBuiltinGen<ExtCodeSizeOp>("extcodesize");
311+
defSimpleBuiltinGenNoRet<ExtCodeCopyOp>("extcodecopy");
311312
defSimpleBuiltinGen<CreateOp>("create");
312313
defSimpleBuiltinGen<Create2Op>("create2");
313314
defSimpleBuiltinGen<SLoadOp>("sload");
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
contract C {
2+
function f() public returns (uint) { return 42; }
3+
}
4+
5+
contract D {
6+
function g() public returns (uint) {
7+
C c1 = new C();
8+
C c2 = c1;
9+
require(address(c1).code.length > 50);
10+
return 1;
11+
}
12+
}
13+
14+
// ====
15+
// compileViaMlir: true
16+
// ----
17+
// g() -> 1

0 commit comments

Comments
 (0)