Skip to content

Commit 5610f24

Browse files
committed
rewritten
1 parent 764ffad commit 5610f24

File tree

10 files changed

+214
-194
lines changed

10 files changed

+214
-194
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,46 +1593,45 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
15931593
let hasVerifier = 1;
15941594
}
15951595

1596-
def EmitC_ClassOp : EmitC_Op<"class", [AutomaticAllocationScope,
1597-
IsolatedFromAbove, OpAsmOpInterface]> {
1596+
def EmitC_ClassOp
1597+
: EmitC_Op<"class", [AutomaticAllocationScope, IsolatedFromAbove,
1598+
OpAsmOpInterface, SymbolTable,
1599+
Symbol]#GraphRegionNoTerminator.traits> {
15981600
let summary =
15991601
"Represents a C++ class definition, encapsulating fields and methods.";
16001602

1603+
// FIX WORDING
16011604
let description = [{
16021605
The `emitc.class` operation defines a C++ class, acting as a container
16031606
for its data fields (`emitc.variable`) and methods (`emitc.func`).
16041607
It creates a distinct scope, isolating its contents from the surrounding
16051608
MLIR region, similar to how C++ classes encapsulate their internals.
1609+
All the class memebrs need to be default initalizable.
16061610

16071611
Example:
16081612
```mlir
1609-
emitc.class @MyModelClass {
1610-
emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
1611-
emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>>
1612-
emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
1613-
1614-
emitc.func @main() attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
1615-
%c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
1616-
1617-
%some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
1618-
%another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
1619-
%output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
1620-
1621-
%v1 = subscript %some_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
1622-
%v1_val = load %v1 : !emitc.lvalue<f32> -> f32
1623-
1624-
%v2 = subscript %another_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
1625-
%v2_val = load %v2 : !emitc.lvalue<f32> -> f32
1626-
1627-
%v3_val = add %v1_val, %v2_val : (f32, f32) -> f32
1628-
1629-
%output_lvalue = subscript %output_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
1630-
assign %v3_val, %output_lvalue : (f32, !emitc.lvalue<f32>) -> ()
1631-
1632-
return
1633-
}
1613+
emitc.class @MymainClass {
1614+
emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
1615+
emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
1616+
emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
1617+
1618+
emitc.func @execute() {
1619+
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
1620+
1621+
%1 = get_field @another_feature : !emitc.array<1xf32>
1622+
%2 = get_field @some_feature : !emitc.array<1xf32>
1623+
%3 = get_field @output_0 : !emitc.array<1xf32>
1624+
1625+
%4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
1626+
%5 = load %4 : <f32>
1627+
%6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
1628+
%7 = load %6 : <f32>
1629+
%8 = add %5, %7 : (f32, f32) -> f32
1630+
%9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
1631+
assign %8 : f32 to %9 : <f32>
1632+
return
16341633
}
1635-
}
1634+
}
16361635

16371636
```
16381637
}];
@@ -1650,7 +1649,7 @@ def EmitC_ClassOp : EmitC_Op<"class", [AutomaticAllocationScope,
16501649

16511650
let hasCustomAssemblyFormat = 1;
16521651

1653-
let assemblyFormat = "`class` $sym_name attr-dict-with-keyword $body";
1652+
let assemblyFormat = [{ $sym_name attr-dict-with-keyword $body }];
16541653
}
16551654

16561655
def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
@@ -1663,23 +1662,18 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
16631662

16641663
```mlir
16651664
emitc.class @MyModelClass {
1666-
emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<0.0> : !emitc.f32
1667-
emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<1.0> : !emitc.f32
1668-
emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
1669-
}
1670-
```
1671-
Example without initial value:
1672-
```mlir
1673-
emitc.class @MyModelClass {
1674-
emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
1665+
emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
1666+
emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
1667+
emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
16751668
}
16761669
```
16771670
}];
16781671

16791672
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
16801673
OptionalAttr<AnyAttr>:$initial_value);
16811674

1682-
let assemblyFormat = "$sym_name `:` $type (`=` $initial_value^)? attr-dict";
1675+
let assemblyFormat =
1676+
[{ $sym_name `:` $type (`=` $initial_value^)? attr-dict}];
16831677

16841678
let hasVerifier = 1;
16851679
}
@@ -1695,18 +1689,15 @@ def EmitC_GetFieldOp
16951689
Example:
16961690

16971691
```mlir
1698-
%some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
1699-
%another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
1700-
%output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
1692+
%some_ptr = emitc.get_field @some_feature : !emitc.array<1xf32>
1693+
%another_ptr = emitc.get_field @another_feature : !emitc.array<1xf32>
1694+
%output_ptr = emitc.get_field @output_0 : !emitc.array<1xf32>
17011695
```
17021696
}];
17031697

1704-
let arguments = (ins AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$base,
1705-
FlatSymbolRefAttr:$class_name, FlatSymbolRefAttr:$field_name);
1706-
1707-
let results = (outs AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$result);
1708-
let assemblyFormat = "$base `:` type($base) $class_name `,` $field_name `->` "
1709-
"type($result) attr-dict";
1698+
let arguments = (ins FlatSymbolRefAttr:$field_name);
1699+
let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
1700+
let assemblyFormat = "$field_name `:` type($result) attr-dict";
17101701
}
17111702

17121703
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace mlir {
1515
namespace emitc {
1616

1717
#define GEN_PASS_DECL_FORMEXPRESSIONSPASS
18-
#define GEN_PASS_DECL_CONVERTFUNCTOCLASSPASS
18+
#define GEN_PASS_DECL_WRAPFUNCINCLASSPASS
1919
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
2020

2121
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ def FormExpressionsPass : Pass<"form-expressions"> {
2020
let dependentDialects = ["emitc::EmitCDialect"];
2121
}
2222

23-
def ConvertFuncToClassPass
24-
: Pass<"convert-emitc-func-to-class", "mlir::emitc::FuncOp"> {
25-
let summary = "Convert functions to classes, using arguments as fields.";
23+
def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
24+
let summary = "Wrap functions in classes, using arguments as fields.";
2625
let description = [{
2726
This pass transforms `emitc.func` operations into `emitc.class` operations.
2827
Function arguments become fields of the class, and the function body is moved

mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,8 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
2828
/// Populates `patterns` with expression-related patterns.
2929
void populateExpressionPatterns(RewritePatternSet &patterns);
3030

31-
//===----------------------------------------------------------------------===//
32-
// Convert Func to Class Transform
33-
//===----------------------------------------------------------------------===//
34-
35-
ClassOp createClass(FuncOp funcOp, OpBuilder &builder);
31+
/// Populates 'patterns' with func-related patterns.
32+
void populateFuncPatterns(RewritePatternSet &patterns);
3633

3734
} // namespace emitc
3835
} // namespace mlir

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,32 +1408,32 @@ LogicalResult FieldOp::verify() {
14081408
return emitOpError("expected valid emitc type");
14091409
}
14101410

1411-
if (!getInitialValue().has_value()) {
1411+
if (!getInitialValue()) {
14121412
return success();
14131413
}
14141414

1415-
Attribute initValue = getInitialValue().value();
1415+
Attribute initValue = *getInitialValue();
14161416
// Check that the type of the initial value is compatible with the type of
14171417
// the global variable.
1418-
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1419-
auto initialValueType = elementsAttr.getType();
1418+
if (ElementsAttr elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1419+
Type initialValueType = elementsAttr.getType();
14201420
if (!initialValueType) {
14211421
return emitOpError("initial value attribute must have a type");
14221422
}
1423-
auto fieldType = getType();
1423+
Type fieldType = getType();
14241424
if (initialValueType != fieldType) {
1425-
if (auto lvalueType = dyn_cast<LValueType>(fieldType)) {
1426-
auto innerFieldType = lvalueType.getValueType();
1425+
if (LValueType lvalueType = dyn_cast<LValueType>(fieldType)) {
1426+
Type innerFieldType = lvalueType.getValueType();
14271427
if (innerFieldType != initialValueType) {
14281428
return emitOpError("initial value type ")
1429-
<< initialValueType << " is not compatible with field type "
1430-
<< fieldType << " its inner type " << innerFieldType;
1429+
<< initialValueType << " is not compatible with field type '"
1430+
<< fieldType << "' its inner type '" << innerFieldType << "'";
14311431
}
14321432

14331433
} else {
1434-
return emitOpError("initial value type ")
1435-
<< initialValueType << " is not compatible with field type "
1436-
<< fieldType;
1434+
return emitOpError("initial value type '")
1435+
<< initialValueType << "' is not compatible with field type '"
1436+
<< fieldType << "'";
14371437
}
14381438
}
14391439
}
@@ -1445,20 +1445,18 @@ LogicalResult FieldOp::verify() {
14451445
// GetFieldOp
14461446
//===----------------------------------------------------------------------===//
14471447
LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1448-
auto classNameAttr = getClassNameAttr();
1449-
auto fieldNameAttr = getFieldNameAttr();
1450-
if (!classNameAttr || !fieldNameAttr) {
1451-
return emitError("class and field name attributes are mandatory");
1448+
mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
1449+
if (!fieldNameAttr) {
1450+
return emitError("field name attribute is mandatory");
14521451
}
1453-
StringRef className = classNameAttr.getValue();
1452+
14541453
StringRef fieldName = fieldNameAttr.getValue();
14551454

1456-
auto fieldOp =
1455+
FieldOp fieldOp =
14571456
symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, getFieldNameAttr());
14581457

14591458
if (!fieldOp) {
1460-
return emitOpError("field '")
1461-
<< fieldName << "' not found in class '" << className << "'";
1459+
return emitOpError("field '") << fieldName << "' not found in the class '";
14621460
}
14631461

14641462
Type getFieldResultType = getResult().getType();

mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIREmitCTransforms
22
Transforms.cpp
33
FormExpressions.cpp
44
TypeConversions.cpp
5-
ConvertFuncToClass.cpp
5+
WrapFuncInClass.cpp
66

77
ADDITIONAL_HEADER_DIRS
88
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms

mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp

Lines changed: 0 additions & 40 deletions
This file was deleted.

mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1111
#include "mlir/IR/IRMapping.h"
1212
#include "mlir/IR/PatternMatch.h"
13-
#include "llvm/Support/Debug.h"
1413

1514
namespace mlir {
1615
namespace emitc {
@@ -42,80 +41,6 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
4241
return expressionOp;
4342
}
4443

45-
ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
46-
builder.setInsertionPoint(funcOp);
47-
48-
auto classOp = builder.create<emitc::ClassOp>(
49-
funcOp.getLoc(), builder.getStringAttr("MyModelClass"));
50-
51-
builder.createBlock(&classOp.getBody());
52-
builder.setInsertionPointToStart(&classOp.getBody().front());
53-
54-
SmallVector<std::pair<StringRef, Type>> fields;
55-
llvm::SmallDenseMap<Value, Value> argToFieldMap;
56-
57-
auto argAttrs = funcOp.getArgAttrs();
58-
if (argAttrs) {
59-
for (const auto [arg, val] : zip(*argAttrs, funcOp.getArguments())) {
60-
if (auto da = dyn_cast<mlir::DictionaryAttr>(arg)) {
61-
auto nv = da.getNamed("tf_saved_model.index_path")->getValue();
62-
auto fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
63-
auto fieldType = emitc::LValueType::get(emitc::PointerType::get(
64-
dyn_cast_or_null<emitc::ArrayType>(val.getType())
65-
.getElementType()));
66-
fields.push_back({fieldName.str(), fieldType});
67-
68-
auto typeAttr = TypeAttr::get(val.getType());
69-
mlir::Attribute emptyAttr = builder.getAttr<mlir::UnitAttr>();
70-
auto dictAttr = DictionaryAttr::get(
71-
builder.getContext(),
72-
{builder.getNamedAttr(fieldName.str(), emptyAttr)});
73-
builder.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
74-
/* attributes*/ dictAttr);
75-
76-
// TODO: From my current understanding, we need to instantiate a class
77-
// so we can get the pointers from .field but we can't do that in here
78-
// so I'm unsure how I can rewrite the following line to ensure
79-
// GetFieldOp works correctly. auto pointer =
80-
// emitc::PointerType::get(dyn_cast_or_null<emitc::ArrayType>(val.getType()).getElementType());
81-
// auto ptr = builder.create<emitc::GetFieldOp>(funcOp.getLoc(),
82-
// pointer, val, "MyModelClass", fieldName);
83-
argToFieldMap[val] = nullptr;
84-
}
85-
}
86-
}
87-
88-
auto funcContext = funcOp.getContext();
89-
auto inputTypes = funcOp.getFunctionType().getInputs();
90-
auto results = funcOp.getFunctionType().getResults();
91-
auto funcType = FunctionType::get(funcContext, inputTypes, results);
92-
auto loc = funcOp.getLoc();
93-
auto newFuncOp = builder.create<emitc::FuncOp>(
94-
loc, builder.getStringAttr("execute"), funcType);
95-
96-
builder.createBlock(&newFuncOp.getBody());
97-
builder.setInsertionPointToStart(&newFuncOp.getBody().front());
98-
99-
IRMapping mapper;
100-
101-
auto body = llvm::make_early_inc_range(funcOp.getBody().front());
102-
for (Operation &opToClone : body) {
103-
if (isa<emitc::ConstantOp>(opToClone) ||
104-
isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
105-
isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
106-
isa<emitc::ReturnOp>(opToClone)) {
107-
builder.clone(opToClone, mapper);
108-
} else {
109-
opToClone.emitOpError("Unsupported operation found");
110-
}
111-
}
112-
113-
// TODO: Need to erase the funcOp after all this. Using funcOp->erase raises
114-
// errors:
115-
116-
return classOp;
117-
}
118-
11944
} // namespace emitc
12045
} // namespace mlir
12146

0 commit comments

Comments
 (0)