@@ -1489,11 +1489,22 @@ Context::declareCallableImpl(const slang::ast::SubroutineSymbol &subroutine,
14891489 auto funcTy = getFunctionSignature (*this , subroutine, extraParams);
14901490 if (!funcTy)
14911491 return nullptr ;
1492- auto funcOp = mlir::func::FuncOp::create (builder, loc, qualifiedName, funcTy);
14931492
1494- SymbolTable::setSymbolVisibility (funcOp, SymbolTable::Visibility::Private);
1493+ // Create a coroutine for tasks (which can suspend) or a function for
1494+ // functions (which cannot).
1495+ Operation *funcOp;
1496+ if (subroutine.subroutineKind == slang::ast::SubroutineKind::Task) {
1497+ auto op = moore::CoroutineOp::create (builder, loc, qualifiedName, funcTy);
1498+ SymbolTable::setSymbolVisibility (op, SymbolTable::Visibility::Private);
1499+ lowering->op = op;
1500+ funcOp = op;
1501+ } else {
1502+ auto op = mlir::func::FuncOp::create (builder, loc, qualifiedName, funcTy);
1503+ SymbolTable::setSymbolVisibility (op, SymbolTable::Visibility::Private);
1504+ lowering->op = op;
1505+ funcOp = op;
1506+ }
14951507 orderedRootOps.insert (it, {subroutine.location , funcOp});
1496- lowering->op = funcOp;
14971508
14981509 // Add the function to the symbol table of the MLIR module, which uniquifies
14991510 // its name.
@@ -1506,59 +1517,66 @@ Context::declareCallableImpl(const slang::ast::SubroutineSymbol &subroutine,
15061517// / Special case handling for recursive functions with captures;
15071518// / this function fixes the in-body call of the recursive function with
15081519// / the captured arguments.
1509- static LogicalResult rewriteCallSitesToPassCaptures (mlir::func::FuncOp callee,
1510- ArrayRef<Value> captures) {
1520+ static LogicalResult
1521+ rewriteCallSitesToPassCaptures (FunctionLowering &lowering) {
1522+ auto &captures = lowering.captures ;
15111523 if (captures.empty ())
15121524 return success ();
15131525
1526+ auto *callee = lowering.op .getOperation ();
15141527 mlir::ModuleOp module = callee->getParentOfType <mlir::ModuleOp>();
15151528 if (!module )
1516- return callee .emitError (" expected callee to be nested under ModuleOp" );
1529+ return lowering. op .emitError (" expected callee to be nested under ModuleOp" );
15171530
15181531 auto usesOpt = mlir::SymbolTable::getSymbolUses (callee, module );
15191532 if (!usesOpt)
1520- return callee .emitError (" failed to compute symbol uses" );
1533+ return lowering. op .emitError (" failed to compute symbol uses" );
15211534
1522- // Snapshot the relevant users before we mutate IR.
1523- SmallVector<mlir::func::CallOp, 8 > callSites;
1524- callSites.reserve (std::distance (usesOpt->begin (), usesOpt->end ()));
1535+ // Snapshot the relevant call users before we mutate IR.
1536+ SmallVector<Operation *, 8 > callSites;
15251537 for (const mlir::SymbolTable::SymbolUse &use : *usesOpt) {
1526- if (auto call = llvm::dyn_cast<mlir::func::CallOp>(use.getUser ()))
1527- callSites.push_back (call);
1538+ auto *user = use.getUser ();
1539+ if (isa<mlir::func::CallOp>(user) || isa<moore::CallCoroutineOp>(user))
1540+ callSites.push_back (user);
15281541 }
15291542 if (callSites.empty ())
15301543 return success ();
15311544
1532- Block &entry = callee. getBody ().front ();
1545+ Block &entry = lowering. op . getFunctionBody ().front ();
15331546 const unsigned numCaps = captures.size ();
15341547 const unsigned numEntryArgs = entry.getNumArguments ();
15351548 if (numEntryArgs < numCaps)
1536- return callee .emitError (" entry block has fewer args than captures" );
1549+ return lowering. op .emitError (" entry block has fewer args than captures" );
15371550 const unsigned capArgStart = numEntryArgs - numCaps;
15381551
1539- // Current (finalized) function type.
1540- auto fTy = callee.getFunctionType ();
1552+ auto fTy = cast<FunctionType>(lowering.op .getFunctionType ());
15411553
1542- for (auto call : callSites) {
1543- SmallVector<Value> newOperands (call.getArgOperands ().begin (),
1544- call.getArgOperands ().end ());
1554+ for (auto *callOp : callSites) {
1555+ // Get the existing operands from the call.
1556+ auto argOperands = callOp->getOperands ();
1557+ SmallVector<Value> newOperands (argOperands.begin (), argOperands.end ());
15451558
1546- const bool inSameFunc = callee->isProperAncestor (call );
1559+ const bool inSameFunc = callee->isProperAncestor (callOp );
15471560 if (inSameFunc) {
1548- // Append the function’s *capture block arguments* in order.
15491561 for (unsigned i = 0 ; i < numCaps; ++i)
15501562 newOperands.push_back (entry.getArgument (capArgStart + i));
15511563 } else {
1552- // External call site: pass the captured SSA values.
15531564 newOperands.append (captures.begin (), captures.end ());
15541565 }
15551566
1556- OpBuilder b (call);
1557- auto flatRef = mlir::FlatSymbolRefAttr::get (callee);
1558- auto newCall = mlir::func::CallOp::create (
1559- b, call.getLoc (), fTy .getResults (), flatRef, newOperands);
1560- call->replaceAllUsesWith (newCall.getOperation ());
1561- call->erase ();
1567+ OpBuilder b (callOp);
1568+ auto flatRef = mlir::FlatSymbolRefAttr::get (callee->getContext (),
1569+ lowering.op .getName ());
1570+ Operation *newCall;
1571+ if (lowering.isCoroutine ()) {
1572+ newCall = moore::CallCoroutineOp::create (
1573+ b, callOp->getLoc (), fTy .getResults (), flatRef, newOperands);
1574+ } else {
1575+ newCall = mlir::func::CallOp::create (
1576+ b, callOp->getLoc (), fTy .getResults (), flatRef, newOperands);
1577+ }
1578+ callOp->replaceAllUsesWith (newCall);
1579+ callOp->erase ();
15621580 }
15631581
15641582 return success ();
@@ -1616,21 +1634,22 @@ Context::convertFunction(const slang::ast::SubroutineSymbol &subroutine) {
16161634
16171635 // Create a function body block and populate it with block arguments.
16181636 SmallVector<moore::VariableOp> argVariables;
1619- auto &block = lowering->op .getBody ().emplaceBlock ();
1637+ auto &block = lowering->op .getFunctionBody ().emplaceBlock ();
16201638
16211639 // If this is a class method, the first input is %this :
16221640 // !moore.class<@C>
16231641 if (isMethod) {
16241642 auto thisLoc = convertLocation (subroutine.location );
1625- auto thisType = lowering->op .getFunctionType ().getInput (0 );
1643+ auto thisType =
1644+ cast<FunctionType>(lowering->op .getFunctionType ()).getInput (0 );
16261645 auto thisArg = block.addArgument (thisType, thisLoc);
16271646
16281647 // Bind `this` so NamedValue/MemberAccess can find it.
16291648 valueSymbols.insert (subroutine.thisVar , thisArg);
16301649 }
16311650
16321651 // Add user-defined block arguments
1633- auto inputs = lowering->op .getFunctionType ().getInputs ();
1652+ auto inputs = cast<FunctionType>( lowering->op .getFunctionType () ).getInputs ();
16341653 auto astArgs = subroutine.getArguments ();
16351654 auto valInputs = llvm::ArrayRef<Type>(inputs).drop_front (isMethod ? 1 : 0 );
16361655
@@ -1692,7 +1711,7 @@ Context::convertFunction(const slang::ast::SubroutineSymbol &subroutine) {
16921711
16931712 // Don't capture anything that's a local reference
16941713 mlir::Region *defReg = ref.getParentRegion ();
1695- if (defReg && lowering->op .getBody ().isAncestor (defReg))
1714+ if (defReg && lowering->op .getFunctionBody ().isAncestor (defReg))
16961715 return ;
16971716
16981717 // If we've already recorded this capture, skip.
@@ -1725,7 +1744,7 @@ Context::convertFunction(const slang::ast::SubroutineSymbol &subroutine) {
17251744
17261745 // Don't capture anything that's a local reference
17271746 mlir::Region *defReg = dstRef.getParentRegion ();
1728- if (defReg && lowering->op .getBody ().isAncestor (defReg))
1747+ if (defReg && lowering->op .getFunctionBody ().isAncestor (defReg))
17291748 return ;
17301749
17311750 // If we've already recorded this capture, skip.
@@ -1759,13 +1778,15 @@ Context::convertFunction(const slang::ast::SubroutineSymbol &subroutine) {
17591778
17601779 // For the special case of recursive functions, fix the call sites within the
17611780 // body
1762- if (failed (rewriteCallSitesToPassCaptures (lowering-> op , lowering-> captures )))
1781+ if (failed (rewriteCallSitesToPassCaptures (* lowering)))
17631782 return failure ();
17641783
17651784 // If there was no explicit return statement provided by the user, insert a
17661785 // default one.
17671786 if (builder.getBlock ()) {
1768- if (returnVar && !subroutine.getReturnType ().isVoid ()) {
1787+ if (lowering->isCoroutine ()) {
1788+ moore::ReturnOp::create (builder, lowering->op .getLoc ());
1789+ } else if (returnVar && !subroutine.getReturnType ().isVoid ()) {
17691790 Value read =
17701791 moore::ReadOp::create (builder, returnVar.getLoc (), returnVar);
17711792 mlir::func::ReturnOp::create (builder, lowering->op .getLoc (), read);
@@ -1800,8 +1821,9 @@ Context::finalizeFunctionBodyCaptures(FunctionLowering &lowering) {
18001821 MLIRContext *ctx = getContext ();
18011822
18021823 // Build new input type list: existing inputs + capture ref types.
1803- SmallVector<Type> newInputs (lowering.op .getFunctionType ().getInputs ().begin (),
1804- lowering.op .getFunctionType ().getInputs ().end ());
1824+ SmallVector<Type> newInputs (
1825+ cast<FunctionType>(lowering.op .getFunctionType ()).getInputs ().begin (),
1826+ cast<FunctionType>(lowering.op .getFunctionType ()).getInputs ().end ());
18051827
18061828 for (Value cap : lowering.captures ) {
18071829 // Expect captures to be refs.
@@ -1815,11 +1837,12 @@ Context::finalizeFunctionBodyCaptures(FunctionLowering &lowering) {
18151837
18161838 // Results unchanged.
18171839 auto newFuncTy = FunctionType::get (
1818- ctx, newInputs, lowering.op .getFunctionType ().getResults ());
1819- lowering.op .setFunctionType (newFuncTy);
1840+ ctx, newInputs,
1841+ cast<FunctionType>(lowering.op .getFunctionType ()).getResults ());
1842+ lowering.op .setType (newFuncTy);
18201843
18211844 // Add the new block arguments to the entry block.
1822- Block &entry = lowering.op .getBody ().front ();
1845+ Block &entry = lowering.op .getFunctionBody ().front ();
18231846 SmallVector<Value> capArgs;
18241847 capArgs.reserve (lowering.captures .size ());
18251848 for (Type t :
@@ -2096,7 +2119,7 @@ struct ClassMethodVisitor : ClassDeclVisitorBase {
20962119 return success ();
20972120
20982121 // Grab the finalized function type from the lowered func.op.
2099- FunctionType fnTy = lowering->op .getFunctionType ();
2122+ FunctionType fnTy = cast<FunctionType>( lowering->op .getFunctionType () );
21002123 // Emit the method decl into the class body, preserving source order.
21012124 moore::ClassMethodDeclOp::create (builder, loc, fn.name , fnTy,
21022125 SymbolRefAttr::get (lowering->op ));
0 commit comments