diff --git a/include/circt/Dialect/Moore/MooreOps.td b/include/circt/Dialect/Moore/MooreOps.td index 26f063b66ccd..9072f272559b 100644 --- a/include/circt/Dialect/Moore/MooreOps.td +++ b/include/circt/Dialect/Moore/MooreOps.td @@ -21,6 +21,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" // Base class for the operations in this dialect. class MooreOp traits = []> : @@ -194,9 +195,9 @@ def ProcedureOp : MooreOp<"procedure", [ } def ReturnOp : MooreOp<"return", [ - Pure, Terminator, HasParent<"ProcedureOp"> + Pure, Terminator, ParentOneOf<["ProcedureOp", "CoroutineOp"]> ]> { - let summary = "Return from a procedure"; + let summary = "Return from a procedure or coroutine"; let assemblyFormat = [{ attr-dict }]; } @@ -212,6 +213,152 @@ def UnreachableOp : MooreOp<"unreachable", [Terminator]> { let assemblyFormat = "attr-dict"; } +//===----------------------------------------------------------------------===// +// Coroutines +//===----------------------------------------------------------------------===// + +def CoroutineOp : MooreOp<"coroutine", [ + IsolatedFromAbove, + FunctionOpInterface, + Symbol, + RegionKindInterface, + RecursiveMemoryEffects, +]> { + let summary = "Define a coroutine"; + let description = [{ + The `moore.coroutine` op represents a SystemVerilog task. Tasks differ from + functions in that they can suspend execution via timing controls such as + `@(posedge clk)` or `#10`. This makes them coroutine-like: a call to a + coroutine suspends the calling process until the coroutine returns. + + Coroutines are `IsolatedFromAbove` and capture any external variables + explicitly as additional arguments, just like `func.func`. Any signals or + variables from the enclosing module are passed as reference arguments. + + Example: + ```mlir + moore.coroutine private @waitForClk(%clk: !moore.ref) { + moore.wait_event { + %0 = moore.read %clk : + moore.detect_event posedge %0 : l1 + } + moore.return + } + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let results = (outs); + let regions = (region MinSizedRegion<1>:$body); + + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins "mlir::StringAttr":$sym_name, + "mlir::TypeAttr":$function_type), [{ + build($_builder, $_state, sym_name, function_type, + /*sym_visibility=*/mlir::StringAttr(), + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); + }]>, + OpBuilder<(ins "mlir::StringRef":$sym_name, + "mlir::FunctionType":$function_type), [{ + build($_builder, $_state, + $_builder.getStringAttr(sym_name), + mlir::TypeAttr::get(function_type), + /*sym_visibility=*/mlir::StringAttr(), + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); + }]>, + ]; + + let extraClassDeclaration = [{ + static mlir::RegionKind getRegionKind(unsigned index) { + return mlir::RegionKind::SSACFG; + } + + /// Returns the argument types of this coroutine. + mlir::ArrayRef getArgumentTypes() { + return getFunctionType().getInputs(); + } + + /// Returns the result types of this coroutine. + mlir::ArrayRef getResultTypes() { + return getFunctionType().getResults(); + } + + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + + mlir::Region *getCallableRegion() { return &getBody(); } + }]; +} + +def CallCoroutineOp : MooreOp<"call_coroutine", [ + CallOpInterface, + DeclareOpInterfaceMethods, +]> { + let summary = "Call a coroutine"; + let description = [{ + The `moore.call_coroutine` op calls a `moore.coroutine`, which represents a + SystemVerilog task call. The calling process suspends until the coroutine + returns. This is only valid inside a procedure or another coroutine. + + Example: + ```mlir + moore.call_coroutine @waitForClk(%clk) : (!moore.ref) -> () + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type(operands, results) + }]; + + let builders = [ + OpBuilder<(ins "CoroutineOp":$coroutine, + CArg<"mlir::ValueRange", "{}">:$operands), [{ + build($_builder, $_state, coroutine.getFunctionType().getResults(), + mlir::SymbolRefAttr::get(coroutine), operands); + }]>, + ]; + + let extraClassDeclaration = [{ + operand_range getArgOperands() { + return getOperands(); + } + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + mlir::CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + void setCalleeFromCallable(mlir::CallInterfaceCallable callee) { + (*this)->setAttr(getCalleeAttrName(), + llvm::cast(callee)); + } + + /// CallOpInterface requires ArgAndResultAttrsOpInterface, which needs + /// methods to get/set per-argument and per-result attributes. Call sites + /// don't carry these attributes, so we stub them out as no-ops. + mlir::ArrayAttr getArgAttrsAttr() { return nullptr; } + mlir::ArrayAttr getResAttrsAttr() { return nullptr; } + void setArgAttrsAttr(mlir::ArrayAttr args) {} + void setResAttrsAttr(mlir::ArrayAttr args) {} + mlir::Attribute removeArgAttrsAttr() { return nullptr; } + mlir::Attribute removeResAttrsAttr() { return nullptr; } + }]; +} + //===----------------------------------------------------------------------===// // Declarations //===----------------------------------------------------------------------===// diff --git a/lib/Conversion/MooreToCore/MooreToCore.cpp b/lib/Conversion/MooreToCore/MooreToCore.cpp index 07c248b1feb1..757b3f9153c2 100644 --- a/lib/Conversion/MooreToCore/MooreToCore.cpp +++ b/lib/Conversion/MooreToCore/MooreToCore.cpp @@ -516,6 +516,66 @@ struct ProcedureOpConversion : public OpConversionPattern { } }; +//===----------------------------------------------------------------------===// +// Coroutine Conversion +//===----------------------------------------------------------------------===// + +struct CoroutineOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CoroutineOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcType = op.getFunctionType(); + TypeConverter::SignatureConversion sigConversion(funcType.getNumInputs()); + for (auto [i, type] : llvm::enumerate(funcType.getInputs())) { + auto converted = typeConverter->convertType(type); + if (!converted) + return failure(); + sigConversion.addInputs(i, converted); + } + SmallVector resultTypes; + if (failed(typeConverter->convertTypes(funcType.getResults(), resultTypes))) + return failure(); + + auto newFuncType = FunctionType::get( + rewriter.getContext(), sigConversion.getConvertedTypes(), resultTypes); + auto newOp = llhd::CoroutineOp::create(rewriter, op.getLoc(), + op.getSymName(), newFuncType); + newOp.setSymVisibilityAttr(op.getSymVisibilityAttr()); + rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(), + newOp.getBody().end()); + if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *typeConverter, + &sigConversion))) + return failure(); + + // Replace moore.return with llhd.return inside the coroutine body. + for (auto returnOp : + llvm::make_early_inc_range(newOp.getBody().getOps())) { + rewriter.setInsertionPoint(returnOp); + rewriter.replaceOpWithNewOp(returnOp, ValueRange{}); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct CallCoroutineOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CallCoroutineOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector convResTypes; + if (failed(typeConverter->convertTypes(op.getResultTypes(), convResTypes))) + return failure(); + rewriter.replaceOpWithNewOp( + op, convResTypes, adaptor.getCallee(), adaptor.getOperands()); + return success(); + } +}; + struct WaitEventOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -3115,6 +3175,8 @@ static void populateOpConversion(ConversionPatternSet &patterns, SVModuleOpConversion, InstanceOpConversion, ProcedureOpConversion, + CoroutineOpConversion, + CallCoroutineOpConversion, WaitEventOpConversion, // Patterns of shifting operations. diff --git a/lib/Dialect/Moore/MooreOps.cpp b/lib/Dialect/Moore/MooreOps.cpp index 3cefb2150f16..ca6040f62188 100644 --- a/lib/Dialect/Moore/MooreOps.cpp +++ b/lib/Dialect/Moore/MooreOps.cpp @@ -16,6 +16,7 @@ #include "circt/Dialect/Moore/MooreAttributes.h" #include "circt/Support/CustomDirectiveImpl.h" #include "mlir/IR/Builders.h" +#include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/TypeSwitch.h" @@ -252,6 +253,67 @@ void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { } } +//===----------------------------------------------------------------------===// +// CoroutineOp +//===----------------------------------------------------------------------===// + +ParseResult CoroutineOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void CoroutineOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// CallCoroutineOp +//===----------------------------------------------------------------------===// + +LogicalResult +CallCoroutineOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto calleeName = getCalleeAttr(); + auto coroutine = + symbolTable.lookupNearestSymbolFrom(*this, calleeName); + if (!coroutine) + return emitOpError() << "'" << calleeName.getValue() + << "' does not reference a valid 'moore.coroutine'"; + + auto type = coroutine.getFunctionType(); + if (type.getNumInputs() != getNumOperands()) + return emitOpError() << "has " << getNumOperands() + << " operands, but callee expects " + << type.getNumInputs(); + + for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != type.getInput(i)) + return emitOpError() << "operand " << i << " type mismatch: expected " + << type.getInput(i) << ", got " + << getOperand(i).getType(); + + if (type.getNumResults() != getNumResults()) + return emitOpError() << "has " << getNumResults() + << " results, but callee returns " + << type.getNumResults(); + + for (unsigned i = 0, e = type.getNumResults(); i != e; ++i) + if (getResult(i).getType() != type.getResult(i)) + return emitOpError() << "result " << i << " type mismatch: expected " + << type.getResult(i) << ", got " + << getResult(i).getType(); + + return success(); +} + //===----------------------------------------------------------------------===// // VariableOp //===----------------------------------------------------------------------===// diff --git a/test/Conversion/MooreToCore/basic.mlir b/test/Conversion/MooreToCore/basic.mlir index 44b46d4fd5f9..8b03e6a0b483 100644 --- a/test/Conversion/MooreToCore/basic.mlir +++ b/test/Conversion/MooreToCore/basic.mlir @@ -1790,3 +1790,21 @@ func.func @QueuePopFront() -> !moore.i16 { // CHECK: return [[POPPED]] : i16 return %v : !moore.i16 } + +// CHECK-LABEL: llhd.coroutine private @myTask +// CHECK-SAME: (%arg0: !llhd.ref) +moore.coroutine private @myTask(%arg0: !moore.ref) { + // CHECK: llhd.return + moore.return +} + +// CHECK-LABEL: hw.module @CoroutineLowering +moore.module @CoroutineLowering() { + %clk = moore.variable : + moore.procedure initial { + // CHECK: llhd.call_coroutine @myTask(%clk) : (!llhd.ref) -> () + moore.call_coroutine @myTask(%clk) : (!moore.ref) -> () + moore.return + } + moore.output +} diff --git a/test/Dialect/Moore/basic.mlir b/test/Dialect/Moore/basic.mlir index 08e7dee46404..5d2606282b5a 100644 --- a/test/Dialect/Moore/basic.mlir +++ b/test/Dialect/Moore/basic.mlir @@ -562,3 +562,29 @@ func.func @StringOperations(%arg0 : !moore.string, %arg1 : !moore.string) { return } + +// CHECK-LABEL: moore.coroutine @myTask(%arg0: !moore.ref) +moore.coroutine @myTask(%arg0: !moore.ref) { + // CHECK: moore.wait_event + moore.wait_event { + %0 = moore.read %arg0 : + moore.detect_event posedge %0 : l1 + } + // CHECK: moore.return + moore.return +} + +// CHECK-LABEL: moore.coroutine private @privateTask() +moore.coroutine private @privateTask() { + moore.return +} + +moore.module @CoroutineCallTest() { + %clk = moore.variable : + moore.procedure initial { + // CHECK: moore.call_coroutine @myTask(%clk) : (!moore.ref) -> () + moore.call_coroutine @myTask(%clk) : (!moore.ref) -> () + moore.return + } + moore.output +} diff --git a/test/Dialect/Moore/errors.mlir b/test/Dialect/Moore/errors.mlir index 174d06c5963f..bc981fd01fba 100644 --- a/test/Dialect/Moore/errors.mlir +++ b/test/Dialect/Moore/errors.mlir @@ -190,3 +190,32 @@ moore.union_create %0 {fieldName = "x"} : !moore.i16 -> union<{x: i32, y: i32}> %0 = moore.constant 42 : i32 // expected-error @below {{op field 'z' not found in union type}} moore.union_create %0 {fieldName = "z"} : !moore.i32 -> union<{x: i32, y: i32}> + +// ----- + +// CallCoroutineOp: callee does not exist +moore.coroutine @caller() { + // expected-error @below {{'nonexistent' does not reference a valid 'moore.coroutine'}} + moore.call_coroutine @nonexistent() : () -> () + moore.return +} + +// ----- + +// CallCoroutineOp: callee is a func.func, not a coroutine +func.func @notACoroutine() { return } +moore.coroutine @callerCoroutine() { + // expected-error @below {{'notACoroutine' does not reference a valid 'moore.coroutine'}} + moore.call_coroutine @notACoroutine() : () -> () + moore.return +} + +// ----- + +// func.call cannot target a moore.coroutine +moore.coroutine @someCoroutine() { moore.return } +func.func @funcCallingCoroutine() { + // expected-error @below {{'someCoroutine' does not reference a valid function}} + func.call @someCoroutine() : () -> () + return +}