@@ -534,7 +534,7 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
534
534
}
535
535
}
536
536
537
- auto getLLVM = [&](Expr *E) -> mlir::Value {
537
+ auto getLLVM = [&](Expr *E, bool isRef = false ) -> mlir::Value {
538
538
auto sub = Visit (E);
539
539
if (!sub.val ) {
540
540
expr->dump ();
@@ -564,23 +564,46 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
564
564
auto shape = std::vector<int64_t >(mt.getShape ());
565
565
assert (shape.size () == 2 );
566
566
567
- OpBuilder abuilder (builder.getContext ());
568
- abuilder.setInsertionPointToStart (allocationScope);
569
- auto one = abuilder.create <ConstantIntOp>(loc, 1 , 64 );
570
- auto alloc = abuilder.create <mlir::LLVM::AllocaOp>(
571
- loc,
567
+ auto PT =
572
568
LLVM::LLVMPointerType::get (Glob.typeTranslator .translateType (
573
569
anonymize (getLLVMType (E->getType ()))),
574
- 0 ),
575
- one, 0 );
576
- ValueCategory (alloc, /* isRef*/ true )
577
- .store (loc, builder, sub, /* isArray*/ isArray);
578
- sub = ValueCategory (alloc, /* isRef*/ true );
570
+ 0 );
571
+ if (true ) {
572
+ sub = ValueCategory (
573
+ builder.create <polygeist::Memref2PointerOp>(loc, PT, sub.val ),
574
+ sub.isReference );
575
+ } else {
576
+ OpBuilder abuilder (builder.getContext ());
577
+ abuilder.setInsertionPointToStart (allocationScope);
578
+ auto one = abuilder.create <ConstantIntOp>(loc, 1 , 64 );
579
+ auto alloc = abuilder.create <mlir::LLVM::AllocaOp>(loc, PT, one, 0 );
580
+ ValueCategory (alloc, /* isRef*/ true )
581
+ .store (loc, builder, sub, /* isArray*/ isArray);
582
+ sub = ValueCategory (alloc, /* isRef*/ true );
583
+ }
584
+ }
585
+ mlir::Value val;
586
+ clang::QualType ct;
587
+ if (!isRef) {
588
+ val = sub.getValue (loc, builder);
589
+ ct = E->getType ();
590
+ } else {
591
+ if (!sub.isReference ) {
592
+ OpBuilder abuilder (builder.getContext ());
593
+ abuilder.setInsertionPointToStart (allocationScope);
594
+ auto one = abuilder.create <ConstantIntOp>(loc, 1 , 64 );
595
+ auto alloc = abuilder.create <mlir::LLVM::AllocaOp>(
596
+ loc, LLVM::LLVMPointerType::get (sub.val .getType ()), one, 0 );
597
+ ValueCategory (alloc, /* isRef*/ true )
598
+ .store (loc, builder, sub, /* isArray*/ isArray);
599
+ sub = ValueCategory (alloc, /* isRef*/ true );
600
+ }
601
+ assert (sub.isReference );
602
+ val = sub.val ;
603
+ ct = Glob.CGM .getContext ().getLValueReferenceType (E->getType ());
579
604
}
580
- auto val = sub.getValue (loc, builder);
581
605
if (auto mt = val.getType ().dyn_cast <MemRefType>()) {
582
- auto nt = Glob.typeTranslator
583
- .translateType (anonymize (getLLVMType (E->getType ())))
606
+ auto nt = Glob.typeTranslator .translateType (anonymize (getLLVMType (ct)))
584
607
.cast <LLVM::LLVMPointerType>();
585
608
val = builder.create <polygeist::Memref2PointerOp>(loc, nt, val);
586
609
}
@@ -1483,7 +1506,7 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
1483
1506
1484
1507
std::vector<mlir::Value> args;
1485
1508
for (auto *a : expr->arguments ()) {
1486
- args.push_back (getLLVM (a));
1509
+ args.push_back (getLLVM (a, /* isRef */ false ));
1487
1510
}
1488
1511
mlir::Value called;
1489
1512
@@ -1492,7 +1515,8 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
1492
1515
called = builder.create <mlir::LLVM::CallOp>(loc, strcmpF, args)
1493
1516
.getResult ();
1494
1517
} else {
1495
- args.insert (args.begin (), getLLVM (expr->getCallee ()));
1518
+ args.insert (args.begin (),
1519
+ getLLVM (expr->getCallee (), /* isRef*/ false ));
1496
1520
SmallVector<mlir::Type> RTs = {Glob.typeTranslator .translateType (
1497
1521
anonymize (getLLVMType (expr->getType ())))};
1498
1522
if (RTs[0 ].isa <LLVM::LLVMVoidType>())
@@ -1509,31 +1533,154 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
1509
1533
if (!callee || callee->isVariadic ()) {
1510
1534
bool isReference = expr->isLValue () || expr->isXValue ();
1511
1535
std::vector<mlir::Value> args;
1512
- for (auto *a : expr->arguments ()) {
1513
- args.push_back (getLLVM (a));
1514
- }
1515
1536
mlir::Value called;
1516
1537
if (callee) {
1517
1538
auto strcmpF = Glob.GetOrCreateLLVMFunction (callee);
1539
+ std::vector<clang::QualType> types;
1540
+ if (auto CC = dyn_cast<CXXMethodDecl>(callee)) {
1541
+ types.push_back (CC->getThisType ());
1542
+ }
1543
+ for (auto parm : callee->parameters ()) {
1544
+ types.push_back (parm->getOriginalType ());
1545
+ }
1546
+ int i = 0 ;
1547
+ for (auto *a : expr->arguments ()) {
1548
+ bool isRef = false ;
1549
+ if (i < types.size ())
1550
+ isRef = types[i]->isReferenceType ();
1551
+ i++;
1552
+ args.push_back (getLLVM (a, isRef));
1553
+ }
1518
1554
called =
1519
1555
builder.create <mlir::LLVM::CallOp>(loc, strcmpF, args).getResult ();
1520
1556
} else {
1521
- args.insert (args.begin (), getLLVM (expr->getCallee ()));
1557
+ mlir::Value fn = Visit (expr->getCallee ()).getValue (loc, builder);
1558
+ if (auto MT = fn.getType ().dyn_cast <MemRefType>()) {
1559
+ fn = builder.create <polygeist::Memref2PointerOp>(
1560
+ loc, LLVM::LLVMPointerType::get (MT.getElementType (), 0 ), fn);
1561
+ }
1562
+ auto PTF = fn.getType ()
1563
+ .cast <LLVM::LLVMPointerType>()
1564
+ .getElementType ()
1565
+ .cast <LLVM::LLVMFunctionType>();
1566
+ SmallVector<mlir::Type, 1 > argtys;
1567
+ bool needsChange = false ;
1568
+ for (auto FT : PTF.getParams ()) {
1569
+ if (auto mt = FT.dyn_cast <MemRefType>()) {
1570
+ argtys.push_back (LLVM::LLVMPointerType::get (mt.getElementType (), 0 ));
1571
+ needsChange = true ;
1572
+ } else
1573
+ argtys.push_back (FT);
1574
+ }
1575
+ auto rt = PTF.getReturnType ();
1576
+ if (auto mt = rt.dyn_cast <MemRefType>()) {
1577
+ rt = LLVM::LLVMPointerType::get (mt.getElementType (), 0 );
1578
+ needsChange = true ;
1579
+ }
1580
+ if (needsChange)
1581
+ fn = builder.create <LLVM::BitcastOp>(
1582
+ loc,
1583
+ LLVM::LLVMPointerType::get (
1584
+ LLVM::LLVMFunctionType::get (rt, argtys, PTF.isVarArg ()), 0 ),
1585
+ fn);
1586
+
1587
+ args.push_back (fn);
1522
1588
auto CT = expr->getType ();
1523
- if (isReference)
1524
- CT = Glob.CGM .getContext ().getLValueReferenceType (CT);
1525
- SmallVector<mlir::Type> RTs = {
1526
- Glob. typeTranslator . translateType ( anonymize ( getLLVMType (CT)) )};
1589
+ // if (isReference)
1590
+ // CT = Glob.CGM.getContext().getLValueReferenceType(CT);
1591
+ SmallVector<mlir::Type> RTs = {rt};
1592
+ // getMLIRType(CT )};
1527
1593
1528
1594
auto ft = args[0 ]
1529
1595
.getType ()
1530
1596
.cast <LLVM::LLVMPointerType>()
1531
1597
.getElementType ()
1532
1598
.cast <LLVM::LLVMFunctionType>();
1533
- assert (RTs[0 ] == ft.getReturnType ());
1534
- if (RTs[0 ].isa <LLVM::LLVMVoidType>())
1599
+ auto ETy = expr->getCallee ()->getType ()->getUnqualifiedDesugaredType ();
1600
+ ETy = cast<clang::PointerType>(ETy)
1601
+ ->getPointeeType ()
1602
+ ->getUnqualifiedDesugaredType ();
1603
+ auto CFT = dyn_cast<clang::FunctionProtoType>(ETy);
1604
+ std::vector<clang::QualType> types;
1605
+ if (CFT) {
1606
+ for (auto t : CFT->getParamTypes ())
1607
+ types.push_back (t);
1608
+ } else {
1609
+ assert (isa<clang::FunctionNoProtoType>(ETy));
1610
+ }
1611
+
1612
+ auto ETy2 = ETy->getCanonicalTypeUnqualified ();
1613
+
1614
+ const clang::CodeGen::CGFunctionInfo *FI;
1615
+ if (const FunctionProtoType *FPT = dyn_cast<FunctionProtoType>(ETy2)) {
1616
+ FI = &Glob.CGM .getTypes ().arrangeFreeFunctionType (
1617
+ CanQual<FunctionProtoType>::CreateUnsafe (QualType (FPT, 0 )));
1618
+ } else {
1619
+ const FunctionNoProtoType *FNPT = cast<FunctionNoProtoType>(ETy2);
1620
+ FI = &Glob.CGM .getTypes ().arrangeFreeFunctionType (
1621
+ CanQual<FunctionNoProtoType>::CreateUnsafe (QualType (FNPT, 0 )));
1622
+ }
1623
+
1624
+ int i = 0 ;
1625
+ for (auto *a : expr->arguments ()) {
1626
+ bool isRef = false ;
1627
+ bool isArray = false ;
1628
+ if (i < types.size ()) {
1629
+ isRef = types[i]->isReferenceType ();
1630
+ // auto inf = FI->arguments()[i].info;
1631
+ // isRef |= inf.isIndirect();
1632
+ Glob.getMLIRType (types[i], &isArray);
1633
+ isRef |= isArray;
1634
+ }
1635
+
1636
+ auto sub = Visit (a);
1637
+ mlir::Value v;
1638
+ if (isRef) {
1639
+ if (!sub.isReference ) {
1640
+ OpBuilder abuilder (builder.getContext ());
1641
+ abuilder.setInsertionPointToStart (allocationScope);
1642
+ auto one = abuilder.create <ConstantIntOp>(loc, 1 , 64 );
1643
+ auto alloc = abuilder.create <mlir::LLVM::AllocaOp>(
1644
+ loc, LLVM::LLVMPointerType::get (sub.val .getType ()), one, 0 );
1645
+ ValueCategory (alloc, /* isRef*/ true )
1646
+ .store (loc, builder, sub, /* isArray*/ false );
1647
+ sub = ValueCategory (alloc, /* isRef*/ true );
1648
+ }
1649
+ assert (sub.isReference );
1650
+ v = sub.val ;
1651
+ } else {
1652
+ v = sub.getValue (loc, builder);
1653
+ }
1654
+ if (i < FI->arg_size ()) {
1655
+ // TODO expand full calling conv
1656
+ /*
1657
+ auto inf = FI->arguments()[i].info;
1658
+ if (inf.isIgnore() || inf.isInAlloca()) {
1659
+ i++;
1660
+ continue;
1661
+ }
1662
+ if (inf.isExpand()) {
1663
+ i++;
1664
+ continue;
1665
+ }
1666
+ */
1667
+ }
1668
+ i++;
1669
+ if (auto mt = v.getType ().dyn_cast <MemRefType>()) {
1670
+ v = builder.create <polygeist::Memref2PointerOp>(
1671
+ loc, LLVM::LLVMPointerType::get (mt.getElementType (), 0 ), v);
1672
+ }
1673
+ args.push_back (v);
1674
+ }
1675
+ if (RTs[0 ].isa <mlir::NoneType>() || RTs[0 ].isa <LLVM::LLVMVoidType>())
1535
1676
RTs.clear ();
1677
+ else
1678
+ assert (RTs[0 ] == ft.getReturnType ());
1536
1679
called = builder.create <mlir::LLVM::CallOp>(loc, RTs, args).getResult ();
1680
+ if (PTF.getReturnType () != ft.getReturnType ()) {
1681
+ called = builder.create <polygeist::Pointer2MemrefOp>(
1682
+ loc, PTF.getReturnType (), called);
1683
+ }
1537
1684
}
1538
1685
if (isReference) {
1539
1686
if (!(called.getType ().isa <LLVM::LLVMPointerType>() ||
0 commit comments