Skip to content

Commit 0b5aea7

Browse files
committed
Avoid unnecessary checks
1 parent 9c97014 commit 0b5aea7

File tree

6 files changed

+82
-130
lines changed

6 files changed

+82
-130
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,18 +1607,13 @@ def EmitC_ClassOp
16071607
MLIR region, similar to how C++ classes encapsulate their internals.
16081608

16091609
Example:
1610+
16101611
```mlir
1611-
emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = ["input_tensor"]}) attributes { } {
1612-
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
1613-
%1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
1614-
return
1615-
}
1616-
// becomes
16171612
emitc.class @modelClass {
1618-
emitc.field @input_tensor : !emitc.array<1xf32> = {emitc.opaque = ["input_tensor"]}
1613+
emitc.field @fieldName0 : !emitc.array<1xf32> = {emitc.opaque = "input_tensor"}
16191614
emitc.func @execute() {
16201615
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
1621-
%1 = get_field @input_tensor : !emitc.array<1xf32>
1616+
%1 = get_field @fieldName0 : !emitc.array<1xf32>
16221617
%2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
16231618
return
16241619
}
@@ -1630,8 +1625,6 @@ def EmitC_ClassOp
16301625

16311626
let regions = (region AnyRegion:$body);
16321627

1633-
let builders = [];
1634-
16351628
let extraClassDeclaration = [{
16361629
// Returns the body block containing class members and methods.
16371630
Block &getBlock();
@@ -1647,29 +1640,21 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
16471640
let description = [{
16481641
The `emitc.field` operation declares a named field within an `emitc.class`
16491642
operation. The field's type must be an EmitC type.
1650-
If the corresponding function argument has attributes (accessed via `argAttrs`),
1651-
these attributes are attached to the field operation.
1652-
Otherwise, the field is created without additional attributes.
16531643

1654-
Example of func argument with attributes:
1655-
```mlir
1656-
%arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]}
1657-
// becomes
1658-
emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
1659-
```
1660-
Example of func argument without attributes:
1644+
Example:
1645+
16611646
```mlir
1662-
%arg0 : !emitc.array<1xf32>
1663-
// becomes
1647+
// Example with an attribute:
1648+
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"}
1649+
// Example with no attribute:
16641650
emitc.field @fieldName0 : !emitc.array<1xf32>
16651651
```
16661652
}];
16671653

16681654
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
1669-
OptionalAttr<AnyAttr>:$initial_value);
1655+
OptionalAttr<AnyAttr>:$attrs);
16701656

1671-
let assemblyFormat =
1672-
[{ $sym_name `:` $type (`=` $initial_value^)? attr-dict}];
1657+
let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}];
16731658

16741659
let hasVerifier = 1;
16751660
}

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,29 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
2626
This pass transforms `emitc.func` operations into `emitc.class` operations.
2727
Function arguments become fields of the class, and the function body is moved
2828
to a new `execute` method within the class.
29+
If the corresponding function argument has attributes (accessed via `argAttrs`),
30+
these attributes are attached to the field operation.
31+
Otherwise, the field is created without additional attributes.
32+
33+
Example:
34+
35+
```mlir
36+
emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}) attributes { } {
37+
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
38+
%1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
39+
return
40+
}
41+
// becomes
42+
emitc.class @modelClass {
43+
emitc.field @input_tensor : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
44+
emitc.func @execute() {
45+
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
46+
%1 = get_field @input_tensor : !emitc.array<1xf32>
47+
%2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
48+
return
49+
}
50+
}
51+
```
2952
}];
3053
let dependentDialects = ["emitc::EmitCDialect"];
3154
let options = [Option<

mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
3030

3131
/// Populates 'patterns' with func-related patterns.
3232
void populateFuncPatterns(RewritePatternSet &patterns,
33-
const std::string &namedAttribute);
33+
StringRef namedAttribute);
3434

3535
} // namespace emitc
3636
} // namespace mlir

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 16 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,39 +1404,19 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
14041404
// FieldOp
14051405
//===----------------------------------------------------------------------===//
14061406
LogicalResult FieldOp::verify() {
1407-
if (!isSupportedEmitCType(getType())) {
1407+
if (!isSupportedEmitCType(getType()))
14081408
return emitOpError("expected valid emitc type");
1409-
}
14101409

1411-
if (!getInitialValue()) {
1412-
return success();
1413-
}
1410+
Operation *parentOp = getOperation()->getParentOp();
1411+
if (!parentOp || !isa<emitc::ClassOp>(parentOp))
1412+
return emitOpError("field must be nested within an emitc.class operation");
14141413

1415-
Attribute initValue = *getInitialValue();
1416-
// Check that the type of the initial value is compatible with the type of
1417-
// the global variable.
1418-
if (ElementsAttr elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1419-
Type initialValueType = elementsAttr.getType();
1420-
if (!initialValueType) {
1421-
return emitOpError("initial value attribute must have a type");
1422-
}
1423-
Type fieldType = getType();
1424-
if (initialValueType != fieldType) {
1425-
if (LValueType lvalueType = dyn_cast<LValueType>(fieldType)) {
1426-
Type innerFieldType = lvalueType.getValueType();
1427-
if (innerFieldType != initialValueType) {
1428-
return emitOpError("initial value type ")
1429-
<< initialValueType << " is not compatible with field type '"
1430-
<< fieldType << "' its inner type '" << innerFieldType << "'";
1431-
}
1432-
1433-
} else {
1434-
return emitOpError("initial value type '")
1435-
<< initialValueType << "' is not compatible with field type '"
1436-
<< fieldType << "'";
1437-
}
1438-
}
1439-
}
1414+
StringAttr symName = getSymNameAttr();
1415+
if (!symName || symName.getValue().empty())
1416+
return emitOpError("field must have a non-empty symbol name");
1417+
1418+
if (!getAttrs())
1419+
return success();
14401420

14411421
return success();
14421422
}
@@ -1446,27 +1426,19 @@ LogicalResult FieldOp::verify() {
14461426
//===----------------------------------------------------------------------===//
14471427
LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
14481428
mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
1449-
if (!fieldNameAttr) {
1450-
return emitError("field name attribute is mandatory");
1451-
}
1452-
1453-
StringRef fieldName = fieldNameAttr.getValue();
1454-
14551429
FieldOp fieldOp =
1456-
symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, getFieldNameAttr());
1457-
1458-
if (!fieldOp) {
1459-
return emitOpError("field '") << fieldName << "' not found in the class '";
1460-
}
1430+
symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, fieldNameAttr);
1431+
if (!fieldOp)
1432+
return emitOpError("field '")
1433+
<< fieldNameAttr << "' not found in the class";
14611434

14621435
Type getFieldResultType = getResult().getType();
14631436
Type fieldType = fieldOp.getType();
14641437

1465-
if (fieldType != getFieldResultType) {
1438+
if (fieldType != getFieldResultType)
14661439
return emitOpError("result type ")
1467-
<< getFieldResultType << " does not match field '" << fieldName
1440+
<< getFieldResultType << " does not match field '" << fieldNameAttr
14681441
<< "' type " << fieldType;
1469-
}
14701442

14711443
return success();
14721444
}
Lines changed: 22 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,56 @@
1-
//===- ConvertFuncToClass.cpp - Convert functions to classes -------------===//
1+
//===- WrapFuncInClass.cpp - Wrap Emitc Funcs in classes -------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir-c/Rewrite.h"
109
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1110
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
1211
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
1312
#include "mlir/IR/Attributes.h"
1413
#include "mlir/IR/Builders.h"
1514
#include "mlir/IR/BuiltinAttributes.h"
1615
#include "mlir/IR/PatternMatch.h"
17-
#include "mlir/IR/TypeRange.h"
18-
#include "mlir/IR/Value.h"
19-
#include "mlir/Pass/Pass.h"
20-
#include "mlir/Transforms/DialectConversion.h"
21-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22-
#include "llvm/ADT/StringRef.h"
23-
#include "llvm/Support/GraphWriter.h"
24-
#include "llvm/Support/LogicalResult.h"
25-
#include <string>
16+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
17+
18+
using namespace mlir;
19+
using namespace emitc;
2620

2721
namespace mlir {
2822
namespace emitc {
29-
3023
#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
3124
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
3225

3326
namespace {
34-
3527
struct WrapFuncInClassPass
3628
: public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
3729
using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
3830
void runOnOperation() override {
3931
Operation *rootOp = getOperation();
40-
MLIRContext *context = rootOp->getContext();
4132

42-
RewritePatternSet patterns(context);
33+
RewritePatternSet patterns(&getContext());
4334
populateFuncPatterns(patterns, namedAttribute);
4435

45-
if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
46-
return signalPassFailure();
47-
}
48-
void getDependentDialects(DialectRegistry &registry) const override {
49-
registry.insert<emitc::EmitCDialect>();
36+
walkAndApplyPatterns(rootOp, std::move(patterns));
5037
}
5138
};
5239

5340
} // namespace
54-
5541
} // namespace emitc
5642
} // namespace mlir
5743

58-
using namespace mlir;
59-
using namespace mlir::emitc;
60-
6144
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
62-
private:
63-
std::string attributeName;
64-
6545
public:
66-
WrapFuncInClass(MLIRContext *context, const std::string &attrName)
46+
WrapFuncInClass(MLIRContext *context, StringRef attrName)
6747
: OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
6848

6949
LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
7050
PatternRewriter &rewriter) const override {
71-
if (funcOp->getParentOfType<emitc::ClassOp>()) {
72-
return failure();
73-
}
51+
7452
auto className = funcOp.getSymNameAttr().str() + "Class";
75-
mlir::emitc::ClassOp newClassOp =
76-
rewriter.create<emitc::ClassOp>(funcOp.getLoc(), className);
53+
ClassOp newClassOp = rewriter.create<ClassOp>(funcOp.getLoc(), className);
7754

7855
SmallVector<std::pair<StringAttr, TypeAttr>> fields;
7956
rewriter.createBlock(&newClassOp.getBody());
@@ -84,19 +61,11 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
8461
StringAttr fieldName;
8562
Attribute argAttr = nullptr;
8663

64+
fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
8765
if (argAttrs && idx < argAttrs->size()) {
8866
if (DictionaryAttr dictAttr =
89-
dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx])) {
90-
if (auto namedAttr = dictAttr.getNamed(attributeName)) {
91-
Attribute nv = namedAttr->getValue();
92-
fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
93-
argAttr = (*argAttrs)[idx];
94-
}
95-
}
96-
}
97-
98-
if (!fieldName) {
99-
fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
67+
dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx]))
68+
argAttr = (*argAttrs)[idx];
10069
}
10170

10271
TypeAttr typeAttr = TypeAttr::get(val.getType());
@@ -106,19 +75,17 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
10675
}
10776

10877
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
109-
MLIRContext *funcContext = funcOp.getContext();
110-
ArrayRef<Type> inputTypes = funcOp.getFunctionType().getInputs();
111-
ArrayRef<Type> results = funcOp.getFunctionType().getResults();
112-
FunctionType funcType = FunctionType::get(funcContext, inputTypes, results);
78+
FunctionType funcType = funcOp.getFunctionType();
11379
Location loc = funcOp.getLoc();
114-
FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
115-
loc, rewriter.getStringAttr("execute"), funcType);
80+
FuncOp newFuncOp =
81+
rewriter.create<emitc::FuncOp>(loc, ("execute"), funcType);
11682

11783
rewriter.createBlock(&newFuncOp.getBody());
11884
newFuncOp.getBody().takeBody(funcOp.getBody());
11985

12086
rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
12187
std::vector<Value> newArguments;
88+
newArguments.reserve(fields.size());
12289
for (auto &[fieldName, attr] : fields) {
12390
GetFieldOp arg =
12491
rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
@@ -132,15 +99,18 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
13299

133100
llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
134101
if (failed(newFuncOp.eraseArguments(argsToErase))) {
135-
newFuncOp->emitOpError("Failed to erase all arguments using BitVector.");
102+
newFuncOp->emitOpError("failed to erase all arguments using BitVector");
136103
}
137104

138105
rewriter.replaceOp(funcOp, newClassOp);
139106
return success();
140107
}
108+
109+
private:
110+
StringRef attributeName;
141111
};
142112

143113
void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
144-
const std::string &namedAttribute) {
114+
StringRef namedAttribute) {
145115
patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
146116
}

mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s
1+
// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.name_hint' %s | FileCheck %s
22

33
module attributes { } {
4-
emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]}, %arg1: !emitc.array<1xf32> {emitc.opaque = ["some_feature"]}, %arg2: !emitc.array<1xf32> {emitc.opaque = ["output_0"]}) attributes { } {
4+
emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"},
5+
%arg1: !emitc.array<1xf32> {emitc.name_hint = "some_feature"},
6+
%arg2: !emitc.array<1xf32> {emitc.name_hint = "output_0"}) attributes { } {
57
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
68
%1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
79
%2 = load %1 : <f32>
@@ -17,14 +19,14 @@ module attributes { } {
1719

1820
// CHECK: module {
1921
// CHECK-NEXT: emitc.class @modelClass {
20-
// CHECK-NEXT: emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
21-
// CHECK-NEXT: emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
22-
// CHECK-NEXT: emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
22+
// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
23+
// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"}
24+
// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"}
2325
// CHECK-NEXT: emitc.func @execute() {
26+
// CHECK-NEXT: get_field @fieldName0 : !emitc.array<1xf32>
27+
// CHECK-NEXT: get_field @fieldName1 : !emitc.array<1xf32>
28+
// CHECK-NEXT: get_field @fieldName2 : !emitc.array<1xf32>
2429
// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
25-
// CHECK-NEXT: get_field @another_feature : !emitc.array<1xf32>
26-
// CHECK-NEXT: get_field @some_feature : !emitc.array<1xf32>
27-
// CHECK-NEXT: get_field @output_0 : !emitc.array<1xf32>
2830
// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
2931
// CHECK-NEXT: load {{.*}} : <f32>
3032
// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>

0 commit comments

Comments
 (0)