Skip to content

Commit e615544

Browse files
authored
[mlir][EmitC] Add pass to wrap a func in class (#141158)
Goal: Enable using C++ classes to AOT compile models for MLGO. This commit introduces a transformation pass that converts standalone `emitc.func` operations into `emitc.class `structures to support class-based C++ code generation for MLGO. Transformation details: - Wrap `emitc.func @func_name` into `emitc.class @Myfunc_nameClass` - Converts function arguments to class fields with preserved attributes - Transforms function body into an `execute()` method with no arguments - Replaces argument references with `get_field` operations Before: emitc.func @model(%arg0, %arg1, %arg2) with direct argument access After: emitc.class with fields and execute() method using get_field operations This enables generating C++ classes that can be instantiated and executed as self-contained model objects for AOT compilation workflows.
1 parent e1cd450 commit e615544

File tree

9 files changed

+342
-0
lines changed

9 files changed

+342
-0
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,4 +1593,90 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
15931593
let hasVerifier = 1;
15941594
}
15951595

1596+
def EmitC_ClassOp
1597+
: EmitC_Op<"class", [AutomaticAllocationScope, IsolatedFromAbove,
1598+
OpAsmOpInterface, SymbolTable,
1599+
Symbol]#GraphRegionNoTerminator.traits> {
1600+
let summary =
1601+
"Represents a C++ class definition, encapsulating fields and methods.";
1602+
1603+
let description = [{
1604+
The `emitc.class` operation defines a C++ class, acting as a container
1605+
for its data fields (`emitc.field`) and methods (`emitc.func`).
1606+
It creates a distinct scope, isolating its contents from the surrounding
1607+
MLIR region, similar to how C++ classes encapsulate their internals.
1608+
1609+
Example:
1610+
1611+
```mlir
1612+
emitc.class @modelClass {
1613+
emitc.field @fieldName0 : !emitc.array<1xf32> = {emitc.opaque = "input_tensor"}
1614+
emitc.func @execute() {
1615+
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
1616+
%1 = get_field @fieldName0 : !emitc.array<1xf32>
1617+
%2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
1618+
return
1619+
}
1620+
}
1621+
```
1622+
}];
1623+
1624+
let arguments = (ins SymbolNameAttr:$sym_name);
1625+
1626+
let regions = (region AnyRegion:$body);
1627+
1628+
let extraClassDeclaration = [{
1629+
// Returns the body block containing class members and methods.
1630+
Block &getBlock();
1631+
}];
1632+
1633+
let hasCustomAssemblyFormat = 1;
1634+
1635+
let assemblyFormat = [{ $sym_name attr-dict-with-keyword $body }];
1636+
}
1637+
1638+
def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
1639+
let summary = "A field within a class";
1640+
let description = [{
1641+
The `emitc.field` operation declares a named field within an `emitc.class`
1642+
operation. The field's type must be an EmitC type.
1643+
1644+
Example:
1645+
1646+
```mlir
1647+
// Example with an attribute:
1648+
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"}
1649+
// Example with no attribute:
1650+
emitc.field @fieldName0 : !emitc.array<1xf32>
1651+
```
1652+
}];
1653+
1654+
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
1655+
OptionalAttr<AnyAttr>:$attrs);
1656+
1657+
let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}];
1658+
1659+
let hasVerifier = 1;
1660+
}
1661+
1662+
def EmitC_GetFieldOp
1663+
: EmitC_Op<"get_field", [Pure, DeclareOpInterfaceMethods<
1664+
SymbolUserOpInterface>]> {
1665+
let summary = "Obtain access to a field within a class instance";
1666+
let description = [{
1667+
The `emitc.get_field` operation retrieves the lvalue of a
1668+
named field from a given class instance.
1669+
1670+
Example:
1671+
1672+
```mlir
1673+
%0 = get_field @fieldName0 : !emitc.array<1xf32>
1674+
```
1675+
}];
1676+
1677+
let arguments = (ins FlatSymbolRefAttr:$field_name);
1678+
let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
1679+
let assemblyFormat = "$field_name `:` type($result) attr-dict";
1680+
}
1681+
15961682
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ namespace mlir {
1515
namespace emitc {
1616

1717
#define GEN_PASS_DECL_FORMEXPRESSIONSPASS
18+
#define GEN_PASS_DECL_WRAPFUNCINCLASSPASS
1819
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
1920

2021
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,42 @@ def FormExpressionsPass : Pass<"form-expressions"> {
2020
let dependentDialects = ["emitc::EmitCDialect"];
2121
}
2222

23+
def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
24+
let summary = "Wrap functions in classes, using arguments as fields.";
25+
let description = [{
26+
This pass transforms `emitc.func` operations into `emitc.class` operations.
27+
Function arguments become fields of the class, and the function body is moved
28+
to a new `execute` method within the class.
29+
If the corresponding function argument has attributes (accessed via `argAttrs`),
30+
these attributes are attached to the field operation.
31+
Otherwise, the field is created without additional attributes.
32+
33+
Example:
34+
35+
```mlir
36+
emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}) attributes { } {
37+
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
38+
%1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
39+
return
40+
}
41+
// becomes
42+
emitc.class @modelClass {
43+
emitc.field @input_tensor : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
44+
emitc.func @execute() {
45+
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
46+
%1 = get_field @input_tensor : !emitc.array<1xf32>
47+
%2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
48+
return
49+
}
50+
}
51+
```
52+
}];
53+
let dependentDialects = ["emitc::EmitCDialect"];
54+
let options = [Option<
55+
"namedAttribute", "named-attribute", "std::string",
56+
/*default=*/"",
57+
"Attribute key used to extract field names from function argument's "
58+
"dictionary attributes">];
59+
}
60+
2361
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES

mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
2828
/// Populates `patterns` with expression-related patterns.
2929
void populateExpressionPatterns(RewritePatternSet &patterns);
3030

31+
/// Populates 'patterns' with func-related patterns.
32+
void populateFuncPatterns(RewritePatternSet &patterns,
33+
StringRef namedAttribute);
34+
3135
} // namespace emitc
3236
} // namespace mlir
3337

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,6 +1400,49 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
14001400
builder.getNamedAttr("id", builder.getStringAttr(id)));
14011401
}
14021402

1403+
//===----------------------------------------------------------------------===//
1404+
// FieldOp
1405+
//===----------------------------------------------------------------------===//
1406+
LogicalResult FieldOp::verify() {
1407+
if (!isSupportedEmitCType(getType()))
1408+
return emitOpError("expected valid emitc type");
1409+
1410+
Operation *parentOp = getOperation()->getParentOp();
1411+
if (!parentOp || !isa<emitc::ClassOp>(parentOp))
1412+
return emitOpError("field must be nested within an emitc.class operation");
1413+
1414+
StringAttr symName = getSymNameAttr();
1415+
if (!symName || symName.getValue().empty())
1416+
return emitOpError("field must have a non-empty symbol name");
1417+
1418+
if (!getAttrs())
1419+
return success();
1420+
1421+
return success();
1422+
}
1423+
1424+
//===----------------------------------------------------------------------===//
1425+
// GetFieldOp
1426+
//===----------------------------------------------------------------------===//
1427+
LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1428+
mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
1429+
FieldOp fieldOp =
1430+
symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, fieldNameAttr);
1431+
if (!fieldOp)
1432+
return emitOpError("field '")
1433+
<< fieldNameAttr << "' not found in the class";
1434+
1435+
Type getFieldResultType = getResult().getType();
1436+
Type fieldType = fieldOp.getType();
1437+
1438+
if (fieldType != getFieldResultType)
1439+
return emitOpError("result type ")
1440+
<< getFieldResultType << " does not match field '" << fieldNameAttr
1441+
<< "' type " << fieldType;
1442+
1443+
return success();
1444+
}
1445+
14031446
//===----------------------------------------------------------------------===//
14041447
// TableGen'd op method definitions
14051448
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIREmitCTransforms
22
Transforms.cpp
33
FormExpressions.cpp
44
TypeConversions.cpp
5+
WrapFuncInClass.cpp
56

67
ADDITIONAL_HEADER_DIRS
78
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
//===- WrapFuncInClass.cpp - Wrap Emitc Funcs in classes -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
10+
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
11+
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
12+
#include "mlir/IR/Attributes.h"
13+
#include "mlir/IR/Builders.h"
14+
#include "mlir/IR/BuiltinAttributes.h"
15+
#include "mlir/IR/PatternMatch.h"
16+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
17+
18+
using namespace mlir;
19+
using namespace emitc;
20+
21+
namespace mlir {
22+
namespace emitc {
23+
#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
24+
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
25+
26+
namespace {
27+
struct WrapFuncInClassPass
28+
: public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
29+
using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
30+
void runOnOperation() override {
31+
Operation *rootOp = getOperation();
32+
33+
RewritePatternSet patterns(&getContext());
34+
populateFuncPatterns(patterns, namedAttribute);
35+
36+
walkAndApplyPatterns(rootOp, std::move(patterns));
37+
}
38+
};
39+
40+
} // namespace
41+
} // namespace emitc
42+
} // namespace mlir
43+
44+
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
45+
public:
46+
WrapFuncInClass(MLIRContext *context, StringRef attrName)
47+
: OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
48+
49+
LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
50+
PatternRewriter &rewriter) const override {
51+
52+
auto className = funcOp.getSymNameAttr().str() + "Class";
53+
ClassOp newClassOp = rewriter.create<ClassOp>(funcOp.getLoc(), className);
54+
55+
SmallVector<std::pair<StringAttr, TypeAttr>> fields;
56+
rewriter.createBlock(&newClassOp.getBody());
57+
rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
58+
59+
auto argAttrs = funcOp.getArgAttrs();
60+
for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
61+
StringAttr fieldName;
62+
Attribute argAttr = nullptr;
63+
64+
fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
65+
if (argAttrs && idx < argAttrs->size())
66+
argAttr = (*argAttrs)[idx];
67+
68+
TypeAttr typeAttr = TypeAttr::get(val.getType());
69+
fields.push_back({fieldName, typeAttr});
70+
rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
71+
argAttr);
72+
}
73+
74+
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
75+
FunctionType funcType = funcOp.getFunctionType();
76+
Location loc = funcOp.getLoc();
77+
FuncOp newFuncOp =
78+
rewriter.create<emitc::FuncOp>(loc, ("execute"), funcType);
79+
80+
rewriter.createBlock(&newFuncOp.getBody());
81+
newFuncOp.getBody().takeBody(funcOp.getBody());
82+
83+
rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
84+
std::vector<Value> newArguments;
85+
newArguments.reserve(fields.size());
86+
for (auto &[fieldName, attr] : fields) {
87+
GetFieldOp arg =
88+
rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
89+
newArguments.push_back(arg);
90+
}
91+
92+
for (auto [oldArg, newArg] :
93+
llvm::zip(newFuncOp.getArguments(), newArguments)) {
94+
rewriter.replaceAllUsesWith(oldArg, newArg);
95+
}
96+
97+
llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
98+
if (failed(newFuncOp.eraseArguments(argsToErase)))
99+
newFuncOp->emitOpError("failed to erase all arguments using BitVector");
100+
101+
rewriter.replaceOp(funcOp, newClassOp);
102+
return success();
103+
}
104+
105+
private:
106+
StringRef attributeName;
107+
};
108+
109+
void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
110+
StringRef namedAttribute) {
111+
patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
112+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.name_hint' %s | FileCheck %s
2+
3+
module attributes { } {
4+
emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"},
5+
%arg1: !emitc.array<1xf32> {emitc.name_hint = "some_feature"},
6+
%arg2: !emitc.array<1xf32> {emitc.name_hint = "output_0"}) attributes { } {
7+
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
8+
%1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
9+
%2 = load %1 : <f32>
10+
%3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
11+
%4 = load %3 : <f32>
12+
%5 = add %2, %4 : (f32, f32) -> f32
13+
%6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
14+
assign %5 : f32 to %6 : <f32>
15+
return
16+
}
17+
}
18+
19+
20+
// CHECK: module {
21+
// CHECK-NEXT: emitc.class @modelClass {
22+
// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
23+
// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"}
24+
// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"}
25+
// CHECK-NEXT: emitc.func @execute() {
26+
// CHECK-NEXT: get_field @fieldName0 : !emitc.array<1xf32>
27+
// CHECK-NEXT: get_field @fieldName1 : !emitc.array<1xf32>
28+
// CHECK-NEXT: get_field @fieldName2 : !emitc.array<1xf32>
29+
// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
30+
// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
31+
// CHECK-NEXT: load {{.*}} : <f32>
32+
// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
33+
// CHECK-NEXT: load {{.*}} : <f32>
34+
// CHECK-NEXT: add {{.*}}, {{.*}} : (f32, f32) -> f32
35+
// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
36+
// CHECK-NEXT: assign {{.*}} : f32 to {{.*}} : <f32>
37+
// CHECK-NEXT: return
38+
// CHECK-NEXT: }
39+
// CHECK-NEXT: }
40+
// CHECK-NEXT: }
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s
2+
3+
emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
4+
emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
5+
emitc.return
6+
}
7+
8+
// CHECK: module {
9+
// CHECK-NEXT: emitc.class @fooClass {
10+
// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32>
11+
// CHECK-NEXT: emitc.func @execute() {
12+
// CHECK-NEXT: %0 = get_field @fieldName0 : !emitc.array<1xf32>
13+
// CHECK-NEXT: call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> ()
14+
// CHECK-NEXT: return
15+
// CHECK-NEXT: }
16+
// CHECK-NEXT: }
17+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)