-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][EmitC] Add pass to wrap a func in class #141158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c386787
7a97f17
23e8435
0c55dcf
ff237e5
2dc3148
bf4f1cd
e7f2084
20d9f7a
8ba35d7
f3acc4f
f47f211
49f202c
d108bf2
86a057e
b064a9c
4780ab3
009f137
3bb6799
1272578
bbf6775
08f76bb
e2bcd3a
d69ca14
c278c1d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1593,4 +1593,90 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects, | |
| let hasVerifier = 1; | ||
| } | ||
|
|
||
| def EmitC_ClassOp | ||
| : EmitC_Op<"class", [AutomaticAllocationScope, IsolatedFromAbove, | ||
| OpAsmOpInterface, SymbolTable, | ||
| Symbol]#GraphRegionNoTerminator.traits> { | ||
| let summary = | ||
| "Represents a C++ class definition, encapsulating fields and methods."; | ||
|
|
||
| let description = [{ | ||
| The `emitc.class` operation defines a C++ class, acting as a container | ||
| for its data fields (`emitc.field`) and methods (`emitc.func`). | ||
| It creates a distinct scope, isolating its contents from the surrounding | ||
| MLIR region, similar to how C++ classes encapsulate their internals. | ||
|
|
||
| Example: | ||
|
||
|
|
||
| ```mlir | ||
| emitc.class @modelClass { | ||
| emitc.field @fieldName0 : !emitc.array<1xf32> = {emitc.opaque = "input_tensor"} | ||
| emitc.func @execute() { | ||
| %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t | ||
| %1 = get_field @fieldName0 : !emitc.array<1xf32> | ||
| %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32> | ||
| return | ||
| } | ||
| } | ||
| ``` | ||
| }]; | ||
|
|
||
| let arguments = (ins SymbolNameAttr:$sym_name); | ||
|
|
||
| let regions = (region AnyRegion:$body); | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| // Returns the body block containing class members and methods. | ||
| Block &getBlock(); | ||
| }]; | ||
|
|
||
| let hasCustomAssemblyFormat = 1; | ||
|
|
||
| let assemblyFormat = [{ $sym_name attr-dict-with-keyword $body }]; | ||
| } | ||
|
|
||
| def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> { | ||
| let summary = "A field within a class"; | ||
| let description = [{ | ||
| The `emitc.field` operation declares a named field within an `emitc.class` | ||
| operation. The field's type must be an EmitC type. | ||
|
|
||
| Example: | ||
|
|
||
| ```mlir | ||
| // Example with an attribute: | ||
| emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"} | ||
| // Example with no attribute: | ||
| emitc.field @fieldName0 : !emitc.array<1xf32> | ||
| ``` | ||
| }]; | ||
|
|
||
| let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type, | ||
| OptionalAttr<AnyAttr>:$attrs); | ||
|
|
||
| let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}]; | ||
|
|
||
| let hasVerifier = 1; | ||
| } | ||
|
|
||
| def EmitC_GetFieldOp | ||
| : EmitC_Op<"get_field", [Pure, DeclareOpInterfaceMethods< | ||
| SymbolUserOpInterface>]> { | ||
| let summary = "Obtain access to a field within a class instance"; | ||
| let description = [{ | ||
| The `emitc.get_field` operation retrieves the lvalue of a | ||
| named field from a given class instance. | ||
|
|
||
| Example: | ||
|
|
||
| ```mlir | ||
| %0 = get_field @fieldName0 : !emitc.array<1xf32> | ||
| ``` | ||
| }]; | ||
|
|
||
| let arguments = (ins FlatSymbolRefAttr:$field_name); | ||
| let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result); | ||
|
||
| let assemblyFormat = "$field_name `:` type($result) attr-dict"; | ||
| } | ||
|
|
||
| #endif // MLIR_DIALECT_EMITC_IR_EMITC | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,4 +20,42 @@ def FormExpressionsPass : Pass<"form-expressions"> { | |
| let dependentDialects = ["emitc::EmitCDialect"]; | ||
| } | ||
|
|
||
| def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> { | ||
| let summary = "Wrap functions in classes, using arguments as fields."; | ||
| let description = [{ | ||
| This pass transforms `emitc.func` operations into `emitc.class` operations. | ||
| Function arguments become fields of the class, and the function body is moved | ||
| to a new `execute` method within the class. | ||
| If the corresponding function argument has attributes (accessed via `argAttrs`), | ||
| these attributes are attached to the field operation. | ||
| Otherwise, the field is created without additional attributes. | ||
|
|
||
| Example: | ||
|
|
||
| ```mlir | ||
| emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}) attributes { } { | ||
| %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t | ||
| %1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32> | ||
| return | ||
| } | ||
| // becomes | ||
| emitc.class @modelClass { | ||
| emitc.field @input_tensor : !emitc.array<1xf32> {emitc.opaque = "input_tensor"} | ||
| emitc.func @execute() { | ||
| %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t | ||
| %1 = get_field @input_tensor : !emitc.array<1xf32> | ||
| %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32> | ||
| return | ||
| } | ||
| } | ||
| ``` | ||
| }]; | ||
| let dependentDialects = ["emitc::EmitCDialect"]; | ||
| let options = [Option< | ||
| "namedAttribute", "named-attribute", "std::string", | ||
| /*default=*/"", | ||
| "Attribute key used to extract field names from function argument's " | ||
|
||
| "dictionary attributes">]; | ||
| } | ||
|
|
||
| #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1400,6 +1400,49 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) { | |
| builder.getNamedAttr("id", builder.getStringAttr(id))); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // FieldOp | ||
| //===----------------------------------------------------------------------===// | ||
| LogicalResult FieldOp::verify() { | ||
|
||
| if (!isSupportedEmitCType(getType())) | ||
| return emitOpError("expected valid emitc type"); | ||
|
|
||
| Operation *parentOp = getOperation()->getParentOp(); | ||
| if (!parentOp || !isa<emitc::ClassOp>(parentOp)) | ||
| return emitOpError("field must be nested within an emitc.class operation"); | ||
|
|
||
| StringAttr symName = getSymNameAttr(); | ||
| if (!symName || symName.getValue().empty()) | ||
| return emitOpError("field must have a non-empty symbol name"); | ||
|
|
||
| if (!getAttrs()) | ||
| return success(); | ||
|
|
||
| return success(); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // GetFieldOp | ||
| //===----------------------------------------------------------------------===// | ||
| LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) { | ||
| mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr(); | ||
| FieldOp fieldOp = | ||
| symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, fieldNameAttr); | ||
| if (!fieldOp) | ||
| return emitOpError("field '") | ||
| << fieldNameAttr << "' not found in the class"; | ||
|
|
||
| Type getFieldResultType = getResult().getType(); | ||
| Type fieldType = fieldOp.getType(); | ||
|
|
||
| if (fieldType != getFieldResultType) | ||
| return emitOpError("result type ") | ||
| << getFieldResultType << " does not match field '" << fieldNameAttr | ||
| << "' type " << fieldType; | ||
|
|
||
| return success(); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // TableGen'd op method definitions | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| //===- WrapFuncInClass.cpp - Wrap Emitc Funcs in classes -------------===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
| #include "mlir/Dialect/EmitC/Transforms/Passes.h" | ||
| #include "mlir/Dialect/EmitC/Transforms/Transforms.h" | ||
| #include "mlir/IR/Attributes.h" | ||
| #include "mlir/IR/Builders.h" | ||
| #include "mlir/IR/BuiltinAttributes.h" | ||
| #include "mlir/IR/PatternMatch.h" | ||
| #include "mlir/Transforms/WalkPatternRewriteDriver.h" | ||
|
|
||
| using namespace mlir; | ||
| using namespace emitc; | ||
|
|
||
| namespace mlir { | ||
| namespace emitc { | ||
| #define GEN_PASS_DEF_WRAPFUNCINCLASSPASS | ||
| #include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" | ||
|
|
||
| namespace { | ||
| struct WrapFuncInClassPass | ||
| : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> { | ||
| using WrapFuncInClassPassBase::WrapFuncInClassPassBase; | ||
| void runOnOperation() override { | ||
| Operation *rootOp = getOperation(); | ||
|
|
||
| RewritePatternSet patterns(&getContext()); | ||
| populateFuncPatterns(patterns, namedAttribute); | ||
|
|
||
| walkAndApplyPatterns(rootOp, std::move(patterns)); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace | ||
| } // namespace emitc | ||
| } // namespace mlir | ||
|
|
||
| class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> { | ||
| public: | ||
| WrapFuncInClass(MLIRContext *context, StringRef attrName) | ||
| : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {} | ||
|
||
|
|
||
| LogicalResult matchAndRewrite(emitc::FuncOp funcOp, | ||
| PatternRewriter &rewriter) const override { | ||
|
|
||
| auto className = funcOp.getSymNameAttr().str() + "Class"; | ||
| ClassOp newClassOp = rewriter.create<ClassOp>(funcOp.getLoc(), className); | ||
|
|
||
| SmallVector<std::pair<StringAttr, TypeAttr>> fields; | ||
| rewriter.createBlock(&newClassOp.getBody()); | ||
| rewriter.setInsertionPointToStart(&newClassOp.getBody().front()); | ||
|
|
||
| auto argAttrs = funcOp.getArgAttrs(); | ||
| for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) { | ||
| StringAttr fieldName; | ||
| Attribute argAttr = nullptr; | ||
|
|
||
| fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx)); | ||
| if (argAttrs && idx < argAttrs->size()) | ||
| argAttr = (*argAttrs)[idx]; | ||
|
|
||
| TypeAttr typeAttr = TypeAttr::get(val.getType()); | ||
| fields.push_back({fieldName, typeAttr}); | ||
| rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr, | ||
| argAttr); | ||
|
||
| } | ||
|
|
||
| rewriter.setInsertionPointToEnd(&newClassOp.getBody().front()); | ||
| FunctionType funcType = funcOp.getFunctionType(); | ||
| Location loc = funcOp.getLoc(); | ||
| FuncOp newFuncOp = | ||
| rewriter.create<emitc::FuncOp>(loc, ("execute"), funcType); | ||
|
|
||
| rewriter.createBlock(&newFuncOp.getBody()); | ||
| newFuncOp.getBody().takeBody(funcOp.getBody()); | ||
|
|
||
| rewriter.setInsertionPointToStart(&newFuncOp.getBody().front()); | ||
| std::vector<Value> newArguments; | ||
|
||
| newArguments.reserve(fields.size()); | ||
| for (auto &[fieldName, attr] : fields) { | ||
| GetFieldOp arg = | ||
| rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName); | ||
| newArguments.push_back(arg); | ||
| } | ||
|
|
||
| for (auto [oldArg, newArg] : | ||
| llvm::zip(newFuncOp.getArguments(), newArguments)) { | ||
| rewriter.replaceAllUsesWith(oldArg, newArg); | ||
| } | ||
|
|
||
| llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true); | ||
| if (failed(newFuncOp.eraseArguments(argsToErase))) | ||
| newFuncOp->emitOpError("failed to erase all arguments using BitVector"); | ||
|
|
||
| rewriter.replaceOp(funcOp, newClassOp); | ||
| return success(); | ||
| } | ||
|
|
||
| private: | ||
| StringRef attributeName; | ||
| }; | ||
|
|
||
| void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns, | ||
| StringRef namedAttribute) { | ||
| patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute); | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| // RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.name_hint' %s | FileCheck %s | ||
|
|
||
| module attributes { } { | ||
| emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"}, | ||
| %arg1: !emitc.array<1xf32> {emitc.name_hint = "some_feature"}, | ||
| %arg2: !emitc.array<1xf32> {emitc.name_hint = "output_0"}) attributes { } { | ||
| %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t | ||
| %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32> | ||
| %2 = load %1 : <f32> | ||
| %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32> | ||
| %4 = load %3 : <f32> | ||
| %5 = add %2, %4 : (f32, f32) -> f32 | ||
| %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32> | ||
| assign %5 : f32 to %6 : <f32> | ||
| return | ||
| } | ||
| } | ||
|
|
||
|
|
||
| // CHECK: module { | ||
| // CHECK-NEXT: emitc.class @modelClass { | ||
| // CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"} | ||
| // CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"} | ||
| // CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"} | ||
| // CHECK-NEXT: emitc.func @execute() { | ||
| // CHECK-NEXT: get_field @fieldName0 : !emitc.array<1xf32> | ||
| // CHECK-NEXT: get_field @fieldName1 : !emitc.array<1xf32> | ||
| // CHECK-NEXT: get_field @fieldName2 : !emitc.array<1xf32> | ||
| // CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t | ||
| // CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32> | ||
| // CHECK-NEXT: load {{.*}} : <f32> | ||
| // CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32> | ||
| // CHECK-NEXT: load {{.*}} : <f32> | ||
| // CHECK-NEXT: add {{.*}}, {{.*}} : (f32, f32) -> f32 | ||
| // CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32> | ||
| // CHECK-NEXT: assign {{.*}} : f32 to {{.*}} : <f32> | ||
| // CHECK-NEXT: return | ||
| // CHECK-NEXT: } | ||
| // CHECK-NEXT: } | ||
| // CHECK-NEXT: } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| // RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s | ||
|
|
||
| emitc.func @foo(%arg0 : !emitc.array<1xf32>) { | ||
| emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> () | ||
| emitc.return | ||
| } | ||
|
|
||
| // CHECK: module { | ||
| // CHECK-NEXT: emitc.class @fooClass { | ||
| // CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> | ||
| // CHECK-NEXT: emitc.func @execute() { | ||
| // CHECK-NEXT: %0 = get_field @fieldName0 : !emitc.array<1xf32> | ||
| // CHECK-NEXT: call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> () | ||
| // CHECK-NEXT: return | ||
| // CHECK-NEXT: } | ||
| // CHECK-NEXT: } | ||
| // CHECK-NEXT: } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: space around
#to make the concat easier to see.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-format gets rid of the space around
#.do we need to ignore clang-format on this?