Skip to content

Commit f3acc4f

Browse files
committed
rewritten
1 parent 8ba35d7 commit f3acc4f

File tree

10 files changed

+214
-194
lines changed

10 files changed

+214
-194
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,46 +1572,45 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
15721572
let hasVerifier = 1;
15731573
}
15741574

1575-
def EmitC_ClassOp : EmitC_Op<"class", [AutomaticAllocationScope,
1576-
IsolatedFromAbove, OpAsmOpInterface]> {
1575+
def EmitC_ClassOp
1576+
: EmitC_Op<"class", [AutomaticAllocationScope, IsolatedFromAbove,
1577+
OpAsmOpInterface, SymbolTable,
1578+
Symbol]#GraphRegionNoTerminator.traits> {
15771579
let summary =
15781580
"Represents a C++ class definition, encapsulating fields and methods.";
15791581

1582+
// FIX WORDING
15801583
let description = [{
15811584
The `emitc.class` operation defines a C++ class, acting as a container
15821585
for its data fields (`emitc.variable`) and methods (`emitc.func`).
15831586
It creates a distinct scope, isolating its contents from the surrounding
15841587
MLIR region, similar to how C++ classes encapsulate their internals.
1588+
All the class memebrs need to be default initalizable.
15851589

15861590
Example:
15871591
```mlir
1588-
emitc.class @MyModelClass {
1589-
emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
1590-
emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>>
1591-
emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
1592-
1593-
emitc.func @main() attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
1594-
%c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
1595-
1596-
%some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
1597-
%another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
1598-
%output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
1599-
1600-
%v1 = subscript %some_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
1601-
%v1_val = load %v1 : !emitc.lvalue<f32> -> f32
1602-
1603-
%v2 = subscript %another_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
1604-
%v2_val = load %v2 : !emitc.lvalue<f32> -> f32
1605-
1606-
%v3_val = add %v1_val, %v2_val : (f32, f32) -> f32
1607-
1608-
%output_lvalue = subscript %output_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
1609-
assign %v3_val, %output_lvalue : (f32, !emitc.lvalue<f32>) -> ()
1610-
1611-
return
1612-
}
1592+
emitc.class @MymainClass {
1593+
emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
1594+
emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
1595+
emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
1596+
1597+
emitc.func @execute() {
1598+
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
1599+
1600+
%1 = get_field @another_feature : !emitc.array<1xf32>
1601+
%2 = get_field @some_feature : !emitc.array<1xf32>
1602+
%3 = get_field @output_0 : !emitc.array<1xf32>
1603+
1604+
%4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
1605+
%5 = load %4 : <f32>
1606+
%6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
1607+
%7 = load %6 : <f32>
1608+
%8 = add %5, %7 : (f32, f32) -> f32
1609+
%9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
1610+
assign %8 : f32 to %9 : <f32>
1611+
return
16131612
}
1614-
}
1613+
}
16151614

16161615
```
16171616
}];
@@ -1629,7 +1628,7 @@ def EmitC_ClassOp : EmitC_Op<"class", [AutomaticAllocationScope,
16291628

16301629
let hasCustomAssemblyFormat = 1;
16311630

1632-
let assemblyFormat = "`class` $sym_name attr-dict-with-keyword $body";
1631+
let assemblyFormat = [{ $sym_name attr-dict-with-keyword $body }];
16331632
}
16341633

16351634
def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
@@ -1642,23 +1641,18 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
16421641

16431642
```mlir
16441643
emitc.class @MyModelClass {
1645-
emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<0.0> : !emitc.f32
1646-
emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<1.0> : !emitc.f32
1647-
emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
1648-
}
1649-
```
1650-
Example without initial value:
1651-
```mlir
1652-
emitc.class @MyModelClass {
1653-
emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
1644+
emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
1645+
emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
1646+
emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
16541647
}
16551648
```
16561649
}];
16571650

16581651
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
16591652
OptionalAttr<AnyAttr>:$initial_value);
16601653

1661-
let assemblyFormat = "$sym_name `:` $type (`=` $initial_value^)? attr-dict";
1654+
let assemblyFormat =
1655+
[{ $sym_name `:` $type (`=` $initial_value^)? attr-dict}];
16621656

16631657
let hasVerifier = 1;
16641658
}
@@ -1674,18 +1668,15 @@ def EmitC_GetFieldOp
16741668
Example:
16751669

16761670
```mlir
1677-
%some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
1678-
%another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
1679-
%output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
1671+
%some_ptr = emitc.get_field @some_feature : !emitc.array<1xf32>
1672+
%another_ptr = emitc.get_field @another_feature : !emitc.array<1xf32>
1673+
%output_ptr = emitc.get_field @output_0 : !emitc.array<1xf32>
16801674
```
16811675
}];
16821676

1683-
let arguments = (ins AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$base,
1684-
FlatSymbolRefAttr:$class_name, FlatSymbolRefAttr:$field_name);
1685-
1686-
let results = (outs AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$result);
1687-
let assemblyFormat = "$base `:` type($base) $class_name `,` $field_name `->` "
1688-
"type($result) attr-dict";
1677+
let arguments = (ins FlatSymbolRefAttr:$field_name);
1678+
let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
1679+
let assemblyFormat = "$field_name `:` type($result) attr-dict";
16891680
}
16901681

16911682
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace mlir {
1515
namespace emitc {
1616

1717
#define GEN_PASS_DECL_FORMEXPRESSIONSPASS
18-
#define GEN_PASS_DECL_CONVERTFUNCTOCLASSPASS
18+
#define GEN_PASS_DECL_WRAPFUNCINCLASSPASS
1919
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
2020

2121
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ def FormExpressionsPass : Pass<"form-expressions"> {
2020
let dependentDialects = ["emitc::EmitCDialect"];
2121
}
2222

23-
def ConvertFuncToClassPass
24-
: Pass<"convert-emitc-func-to-class", "mlir::emitc::FuncOp"> {
25-
let summary = "Convert functions to classes, using arguments as fields.";
23+
def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
24+
let summary = "Wrap functions in classes, using arguments as fields.";
2625
let description = [{
2726
This pass transforms `emitc.func` operations into `emitc.class` operations.
2827
Function arguments become fields of the class, and the function body is moved

mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,8 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
2828
/// Populates `patterns` with expression-related patterns.
2929
void populateExpressionPatterns(RewritePatternSet &patterns);
3030

31-
//===----------------------------------------------------------------------===//
32-
// Convert Func to Class Transform
33-
//===----------------------------------------------------------------------===//
34-
35-
ClassOp createClass(FuncOp funcOp, OpBuilder &builder);
31+
/// Populates 'patterns' with func-related patterns.
32+
void populateFuncPatterns(RewritePatternSet &patterns);
3633

3734
} // namespace emitc
3835
} // namespace mlir

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,32 +1408,32 @@ LogicalResult FieldOp::verify() {
14081408
return emitOpError("expected valid emitc type");
14091409
}
14101410

1411-
if (!getInitialValue().has_value()) {
1411+
if (!getInitialValue()) {
14121412
return success();
14131413
}
14141414

1415-
Attribute initValue = getInitialValue().value();
1415+
Attribute initValue = *getInitialValue();
14161416
// Check that the type of the initial value is compatible with the type of
14171417
// the global variable.
1418-
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1419-
auto initialValueType = elementsAttr.getType();
1418+
if (ElementsAttr elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1419+
Type initialValueType = elementsAttr.getType();
14201420
if (!initialValueType) {
14211421
return emitOpError("initial value attribute must have a type");
14221422
}
1423-
auto fieldType = getType();
1423+
Type fieldType = getType();
14241424
if (initialValueType != fieldType) {
1425-
if (auto lvalueType = dyn_cast<LValueType>(fieldType)) {
1426-
auto innerFieldType = lvalueType.getValueType();
1425+
if (LValueType lvalueType = dyn_cast<LValueType>(fieldType)) {
1426+
Type innerFieldType = lvalueType.getValueType();
14271427
if (innerFieldType != initialValueType) {
14281428
return emitOpError("initial value type ")
1429-
<< initialValueType << " is not compatible with field type "
1430-
<< fieldType << " its inner type " << innerFieldType;
1429+
<< initialValueType << " is not compatible with field type '"
1430+
<< fieldType << "' its inner type '" << innerFieldType << "'";
14311431
}
14321432

14331433
} else {
1434-
return emitOpError("initial value type ")
1435-
<< initialValueType << " is not compatible with field type "
1436-
<< fieldType;
1434+
return emitOpError("initial value type '")
1435+
<< initialValueType << "' is not compatible with field type '"
1436+
<< fieldType << "'";
14371437
}
14381438
}
14391439
}
@@ -1445,20 +1445,18 @@ LogicalResult FieldOp::verify() {
14451445
// GetFieldOp
14461446
//===----------------------------------------------------------------------===//
14471447
LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1448-
auto classNameAttr = getClassNameAttr();
1449-
auto fieldNameAttr = getFieldNameAttr();
1450-
if (!classNameAttr || !fieldNameAttr) {
1451-
return emitError("class and field name attributes are mandatory");
1448+
mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
1449+
if (!fieldNameAttr) {
1450+
return emitError("field name attribute is mandatory");
14521451
}
1453-
StringRef className = classNameAttr.getValue();
1452+
14541453
StringRef fieldName = fieldNameAttr.getValue();
14551454

1456-
auto fieldOp =
1455+
FieldOp fieldOp =
14571456
symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, getFieldNameAttr());
14581457

14591458
if (!fieldOp) {
1460-
return emitOpError("field '")
1461-
<< fieldName << "' not found in class '" << className << "'";
1459+
return emitOpError("field '") << fieldName << "' not found in the class '";
14621460
}
14631461

14641462
Type getFieldResultType = getResult().getType();

mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIREmitCTransforms
22
Transforms.cpp
33
FormExpressions.cpp
44
TypeConversions.cpp
5-
ConvertFuncToClass.cpp
5+
WrapFuncInClass.cpp
66

77
ADDITIONAL_HEADER_DIRS
88
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms

mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp

Lines changed: 0 additions & 40 deletions
This file was deleted.

mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1111
#include "mlir/IR/IRMapping.h"
1212
#include "mlir/IR/PatternMatch.h"
13-
#include "llvm/Support/Debug.h"
1413

1514
namespace mlir {
1615
namespace emitc {
@@ -43,80 +42,6 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
4342
return expressionOp;
4443
}
4544

46-
ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
47-
builder.setInsertionPoint(funcOp);
48-
49-
auto classOp = builder.create<emitc::ClassOp>(
50-
funcOp.getLoc(), builder.getStringAttr("MyModelClass"));
51-
52-
builder.createBlock(&classOp.getBody());
53-
builder.setInsertionPointToStart(&classOp.getBody().front());
54-
55-
SmallVector<std::pair<StringRef, Type>> fields;
56-
llvm::SmallDenseMap<Value, Value> argToFieldMap;
57-
58-
auto argAttrs = funcOp.getArgAttrs();
59-
if (argAttrs) {
60-
for (const auto [arg, val] : zip(*argAttrs, funcOp.getArguments())) {
61-
if (auto da = dyn_cast<mlir::DictionaryAttr>(arg)) {
62-
auto nv = da.getNamed("tf_saved_model.index_path")->getValue();
63-
auto fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
64-
auto fieldType = emitc::LValueType::get(emitc::PointerType::get(
65-
dyn_cast_or_null<emitc::ArrayType>(val.getType())
66-
.getElementType()));
67-
fields.push_back({fieldName.str(), fieldType});
68-
69-
auto typeAttr = TypeAttr::get(val.getType());
70-
mlir::Attribute emptyAttr = builder.getAttr<mlir::UnitAttr>();
71-
auto dictAttr = DictionaryAttr::get(
72-
builder.getContext(),
73-
{builder.getNamedAttr(fieldName.str(), emptyAttr)});
74-
builder.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
75-
/* attributes*/ dictAttr);
76-
77-
// TODO: From my current understanding, we need to instantiate a class
78-
// so we can get the pointers from .field but we can't do that in here
79-
// so I'm unsure how I can rewrite the following line to ensure
80-
// GetFieldOp works correctly. auto pointer =
81-
// emitc::PointerType::get(dyn_cast_or_null<emitc::ArrayType>(val.getType()).getElementType());
82-
// auto ptr = builder.create<emitc::GetFieldOp>(funcOp.getLoc(),
83-
// pointer, val, "MyModelClass", fieldName);
84-
argToFieldMap[val] = nullptr;
85-
}
86-
}
87-
}
88-
89-
auto funcContext = funcOp.getContext();
90-
auto inputTypes = funcOp.getFunctionType().getInputs();
91-
auto results = funcOp.getFunctionType().getResults();
92-
auto funcType = FunctionType::get(funcContext, inputTypes, results);
93-
auto loc = funcOp.getLoc();
94-
auto newFuncOp = builder.create<emitc::FuncOp>(
95-
loc, builder.getStringAttr("execute"), funcType);
96-
97-
builder.createBlock(&newFuncOp.getBody());
98-
builder.setInsertionPointToStart(&newFuncOp.getBody().front());
99-
100-
IRMapping mapper;
101-
102-
auto body = llvm::make_early_inc_range(funcOp.getBody().front());
103-
for (Operation &opToClone : body) {
104-
if (isa<emitc::ConstantOp>(opToClone) ||
105-
isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
106-
isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
107-
isa<emitc::ReturnOp>(opToClone)) {
108-
builder.clone(opToClone, mapper);
109-
} else {
110-
opToClone.emitOpError("Unsupported operation found");
111-
}
112-
}
113-
114-
// TODO: Need to erase the funcOp after all this. Using funcOp->erase raises
115-
// errors:
116-
117-
return classOp;
118-
}
119-
12045
} // namespace emitc
12146
} // namespace mlir
12247

0 commit comments

Comments
 (0)