Skip to content

Commit 20d9f7a

Browse files
committed
Adding ClassOp, FieldOp, GetFieldOp to allow for a transfrom from func to class
1 parent e7f2084 commit 20d9f7a

File tree

9 files changed

+370
-0
lines changed

9 files changed

+370
-0
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,4 +1572,120 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
15721572
let hasVerifier = 1;
15731573
}
15741574

1575+
def EmitC_ClassOp : EmitC_Op<"class", [AutomaticAllocationScope,
1576+
IsolatedFromAbove, OpAsmOpInterface]> {
1577+
let summary =
1578+
"Represents a C++ class definition, encapsulating fields and methods.";
1579+
1580+
let description = [{
1581+
The `emitc.class` operation defines a C++ class, acting as a container
1582+
for its data fields (`emitc.variable`) and methods (`emitc.func`).
1583+
It creates a distinct scope, isolating its contents from the surrounding
1584+
MLIR region, similar to how C++ classes encapsulate their internals.
1585+
1586+
Example:
1587+
```mlir
1588+
emitc.class @MyModelClass {
1589+
emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
1590+
emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>>
1591+
emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
1592+
1593+
emitc.func @main() attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
1594+
%c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
1595+
1596+
%some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
1597+
%another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
1598+
%output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
1599+
1600+
%v1 = subscript %some_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
1601+
%v1_val = load %v1 : !emitc.lvalue<f32> -> f32
1602+
1603+
%v2 = subscript %another_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
1604+
%v2_val = load %v2 : !emitc.lvalue<f32> -> f32
1605+
1606+
%v3_val = add %v1_val, %v2_val : (f32, f32) -> f32
1607+
1608+
%output_lvalue = subscript %output_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
1609+
assign %v3_val, %output_lvalue : (f32, !emitc.lvalue<f32>) -> ()
1610+
1611+
return
1612+
}
1613+
}
1614+
}
1615+
1616+
```
1617+
}];
1618+
1619+
let arguments = (ins SymbolNameAttr:$sym_name);
1620+
1621+
let regions = (region AnyRegion:$body);
1622+
1623+
let builders = [];
1624+
1625+
let extraClassDeclaration = [{
1626+
// Returns the body block containing class members and methods.
1627+
Block &getBlock();
1628+
}];
1629+
1630+
let hasCustomAssemblyFormat = 1;
1631+
1632+
let assemblyFormat = "`class` $sym_name attr-dict-with-keyword $body";
1633+
}
1634+
1635+
def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
1636+
let summary = "A field within a class";
1637+
let description = [{
1638+
The `emitc.field` operation declares a named field within an `emitc.class`
1639+
operation. The field's type must be an EmitC type. An optional initial value can be provided.
1640+
1641+
Example with initial values:
1642+
1643+
```mlir
1644+
emitc.class @MyModelClass {
1645+
emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<0.0> : !emitc.f32
1646+
emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<1.0> : !emitc.f32
1647+
emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
1648+
}
1649+
```
1650+
Example without initial value:
1651+
```mlir
1652+
emitc.class @MyModelClass {
1653+
emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
1654+
}
1655+
```
1656+
}];
1657+
1658+
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
1659+
OptionalAttr<AnyAttr>:$initial_value);
1660+
1661+
let assemblyFormat = "$sym_name `:` $type (`=` $initial_value^)? attr-dict";
1662+
1663+
let hasVerifier = 1;
1664+
}
1665+
1666+
def EmitC_GetFieldOp
1667+
: EmitC_Op<"get_field", [Pure, DeclareOpInterfaceMethods<
1668+
SymbolUserOpInterface>]> {
1669+
let summary = "Obtain access to a field within a class instance";
1670+
let description = [{
1671+
The `emitc.get_field` operation retrieves the lvalue of a
1672+
named field from a given class instance.
1673+
1674+
Example:
1675+
1676+
```mlir
1677+
%some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
1678+
%another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
1679+
%output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
1680+
```
1681+
}];
1682+
1683+
let arguments = (ins AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$base,
1684+
FlatSymbolRefAttr:$class_name, FlatSymbolRefAttr:$field_name);
1685+
1686+
let results = (outs AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$result);
1687+
let assemblyFormat = "$base `:` type($base) $class_name `,` $field_name `->` "
1688+
"type($result) attr-dict";
1689+
}
1690+
15751691
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ namespace mlir {
1515
namespace emitc {
1616

1717
#define GEN_PASS_DECL_FORMEXPRESSIONSPASS
18+
#define GEN_PASS_DECL_CONVERTFUNCTOCLASSPASS
1819
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
1920

2021
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,15 @@ def FormExpressionsPass : Pass<"form-expressions"> {
2020
let dependentDialects = ["emitc::EmitCDialect"];
2121
}
2222

23+
def ConvertFuncToClassPass
24+
: Pass<"convert-emitc-func-to-class", "mlir::emitc::FuncOp"> {
25+
let summary = "Convert functions to classes, using arguments as fields.";
26+
let description = [{
27+
This pass transforms `emitc.func` operations into `emitc.class` operations.
28+
Function arguments become fields of the class, and the function body is moved
29+
to a new `execute` method within the class.
30+
}];
31+
let dependentDialects = ["emitc::EmitCDialect"];
32+
}
33+
2334
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES

mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
2828
/// Populates `patterns` with expression-related patterns.
2929
void populateExpressionPatterns(RewritePatternSet &patterns);
3030

31+
//===----------------------------------------------------------------------===//
32+
// Convert Func to Class Transform
33+
//===----------------------------------------------------------------------===//
34+
35+
ClassOp createClass(FuncOp funcOp, OpBuilder &builder);
36+
3137
} // namespace emitc
3238
} // namespace mlir
3339

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,6 +1400,79 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
14001400
builder.getNamedAttr("id", builder.getStringAttr(id)));
14011401
}
14021402

1403+
//===----------------------------------------------------------------------===//
1404+
// FieldOp
1405+
//===----------------------------------------------------------------------===//
1406+
LogicalResult FieldOp::verify() {
1407+
if (!isSupportedEmitCType(getType())) {
1408+
return emitOpError("expected valid emitc type");
1409+
}
1410+
1411+
if (!getInitialValue().has_value()) {
1412+
return success();
1413+
}
1414+
1415+
Attribute initValue = getInitialValue().value();
1416+
// Check that the type of the initial value is compatible with the type of
1417+
// the global variable.
1418+
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1419+
auto initialValueType = elementsAttr.getType();
1420+
if (!initialValueType) {
1421+
return emitOpError("initial value attribute must have a type");
1422+
}
1423+
auto fieldType = getType();
1424+
if (initialValueType != fieldType) {
1425+
if (auto lvalueType = dyn_cast<LValueType>(fieldType)) {
1426+
auto innerFieldType = lvalueType.getValueType();
1427+
if (innerFieldType != initialValueType) {
1428+
return emitOpError("initial value type ")
1429+
<< initialValueType << " is not compatible with field type "
1430+
<< fieldType << " its inner type " << innerFieldType;
1431+
}
1432+
1433+
} else {
1434+
return emitOpError("initial value type ")
1435+
<< initialValueType << " is not compatible with field type "
1436+
<< fieldType;
1437+
}
1438+
}
1439+
}
1440+
1441+
return success();
1442+
}
1443+
1444+
//===----------------------------------------------------------------------===//
1445+
// GetFieldOp
1446+
//===----------------------------------------------------------------------===//
1447+
LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1448+
auto classNameAttr = getClassNameAttr();
1449+
auto fieldNameAttr = getFieldNameAttr();
1450+
if (!classNameAttr || !fieldNameAttr) {
1451+
return emitError("class and field name attributes are mandatory");
1452+
}
1453+
StringRef className = classNameAttr.getValue();
1454+
StringRef fieldName = fieldNameAttr.getValue();
1455+
1456+
auto fieldOp =
1457+
symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, getFieldNameAttr());
1458+
1459+
if (!fieldOp) {
1460+
return emitOpError("field '")
1461+
<< fieldName << "' not found in class '" << className << "'";
1462+
}
1463+
1464+
Type getFieldResultType = getResult().getType();
1465+
Type fieldType = fieldOp.getType();
1466+
1467+
if (fieldType != getFieldResultType) {
1468+
return emitOpError("result type ")
1469+
<< getFieldResultType << " does not match field '" << fieldName
1470+
<< "' type " << fieldType;
1471+
}
1472+
1473+
return success();
1474+
}
1475+
14031476
//===----------------------------------------------------------------------===//
14041477
// TableGen'd op method definitions
14051478
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIREmitCTransforms
22
Transforms.cpp
33
FormExpressions.cpp
44
TypeConversions.cpp
5+
ConvertFuncToClass.cpp
56

67
ADDITIONAL_HEADER_DIRS
78
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
//===- ConvertFuncToClass.cpp - Convert functions to classes -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
10+
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
11+
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
12+
#include "mlir/IR/Builders.h"
13+
#include "mlir/IR/BuiltinOps.h"
14+
#include "mlir/IR/IRMapping.h"
15+
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Transforms/DialectConversion.h"
17+
18+
namespace mlir {
19+
namespace emitc {
20+
21+
#define GEN_PASS_DEF_CONVERTFUNCTOCLASSPASS
22+
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
23+
24+
namespace {
25+
26+
struct ConvertFuncToClassPass
27+
: public impl::ConvertFuncToClassPassBase<ConvertFuncToClassPass> {
28+
void runOnOperation() override {
29+
emitc::FuncOp funcOp = getOperation();
30+
MLIRContext *context = funcOp->getContext();
31+
32+
// Wrap each C operator op with an expression op.
33+
OpBuilder builder(context);
34+
createClass(funcOp, builder);
35+
36+
// // Create the new function inside the class
37+
// auto funcType = FunctionType::get(funcOp.getContext(),
38+
// funcOp.getFunctionType().getInputs(),
39+
// funcOp.getFunctionType().getResults()); auto newFuncOp =
40+
// builder.create<emitc::FuncOp>(
41+
// funcOp.getLoc(),builder.getStringAttr("execute"), funcType );
42+
43+
// builder.createBlock(&newFuncOp.getBody());
44+
// builder.setInsertionPointToStart(&newFuncOp.getBody().front());
45+
46+
// // 7. Remap original arguments to field pointers
47+
// IRMapping mapper;
48+
49+
// // 8. move or clone operations from original function
50+
// for (Operation &opToClone :
51+
// llvm::make_early_inc_range(funcOp.getBody().front())) {
52+
// if (isa<emitc::ConstantOp>(opToClone) ||
53+
// isa<emitc::SubscriptOp>(opToClone) ||
54+
// isa<emitc::LoadOp>(opToClone) ||
55+
// isa<emitc::AddOp>(opToClone) ||
56+
// isa<emitc::AssignOp>(opToClone) ||
57+
// isa<emitc::ReturnOp>(opToClone )) {
58+
// builder.clone(opToClone, mapper);
59+
// } else {
60+
// opToClone.emitOpError("Unsupported operation found");
61+
// }
62+
// }
63+
// if (funcOp->use_empty()) funcOp->erase();
64+
}
65+
};
66+
67+
} // namespace
68+
69+
} // namespace emitc
70+
} // namespace mlir

0 commit comments

Comments
 (0)