Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace emitc {

#define GEN_PASS_DECL_FORMEXPRESSIONSPASS
#define GEN_PASS_DECL_WRAPFUNCINCLASSPASS
#define GEN_PASS_DECL_ADDREFLECTIONMAPPASS
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
Expand Down
38 changes: 38 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,42 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
let dependentDialects = ["emitc::EmitCDialect"];
}

def AddReflectionMapPass : Pass<"add-reflection-map", "ModuleOp"> {
let summary =
"Add a reflection map function to EmitC classes for runtime field lookup";
let description = [{
This pass adds a `getBufferForName` function to EmitC classes that enables
runtime lookup of field buffers by their string names.
This requires that the class has fields with attributes and a function named `execute`.
The `fieldop` attribute is expected to be a dictionary where:
- The keys are `namedAttribute`.
- The values are arrays containing a single string attribute.


Example:

```mlir
emitc.class @MyClass {
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
emitc.func @execute() { ... }
}

// becomes:

emitc.class @MyClass {
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
emitc.field @reflectionMap : !emitc.opaque<"const std::map<std::string, char*>"> = #emitc.opaque<"{ { \22another_feature\22, reinterpret_cast<char*>(&fieldName0) }, { \22some_feature\22, reinterpret_cast<char*>(&fieldName1) } }">
emitc.func @execute() { ... }
}
```
}];
let dependentDialects = ["mlir::emitc::EmitCDialect"];
let options = [Option<"namedAttribute", "named-attribute", "std::string",
/*default=*/"",
"Attribute key used to extract field names from fields "
"dictionary attributes">];
}

#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
/// Populates 'patterns' with func-related patterns.
void populateFuncPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns to add reflection map for EmitC classes.
void populateAddReflectionMapPatterns(RewritePatternSet &patterns,
StringRef namedAttribute);

} // namespace emitc
} // namespace mlir

Expand Down
136 changes: 136 additions & 0 deletions mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
//===- AddReflectionMap.cpp - Add a reflection map to a class -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace emitc;

namespace mlir {
namespace emitc {
#define GEN_PASS_DEF_ADDREFLECTIONMAPPASS
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"

namespace {
constexpr const char *mapLibraryHeader = "map";
constexpr const char *stringLibraryHeader = "string";

IncludeOp addHeader(OpBuilder &builder, ModuleOp module, StringRef headerName) {
StringAttr includeAttr = builder.getStringAttr(headerName);
return builder.create<emitc::IncludeOp>(
module.getLoc(), includeAttr,
/*is_standard_include=*/builder.getUnitAttr());
}

class AddReflectionMapPass
: public impl::AddReflectionMapPassBase<AddReflectionMapPass> {
using AddReflectionMapPassBase::AddReflectionMapPassBase;
void runOnOperation() override {
mlir::ModuleOp module = getOperation();

RewritePatternSet patterns(&getContext());
populateAddReflectionMapPatterns(patterns, namedAttribute);

walkAndApplyPatterns(module, std::move(patterns));
bool hasMapHdr = false;
bool hasStringHdr = false;
for (auto &op : *module.getBody()) {
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
if (!includeOp)
continue;
if (includeOp.getIsStandardInclude()) {
if (includeOp.getInclude() == mapLibraryHeader)
hasMapHdr = true;
if (includeOp.getInclude() == stringLibraryHeader)
hasStringHdr = true;
}
if (hasMapHdr && hasStringHdr)
return;
}

mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
if (!hasMapHdr)
addHeader(builder, module, mapLibraryHeader);

if (!hasStringHdr)
addHeader(builder, module, stringLibraryHeader);
}
};

} // namespace
} // namespace emitc
} // namespace mlir

class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
public:
AddReflectionMapClass(MLIRContext *context, StringRef attrName)
: OpRewritePattern<emitc::ClassOp>(context), attributeName(attrName) {}

LogicalResult matchAndRewrite(mlir::emitc::ClassOp classOp,
PatternRewriter &rewriter) const override {
MLIRContext *context = rewriter.getContext();

emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
context, "const std::map<std::string, char*>");

// Collect all field names
std::vector<std::pair<std::string, std::string>> fieldNames;
classOp.walk([&](mlir::emitc::FieldOp fieldOp) {
if (ArrayAttr arrayAttr = cast<mlir::ArrayAttr>(
fieldOp->getAttrDictionary().get(attributeName))) {
StringAttr stringAttr = cast<mlir::StringAttr>(arrayAttr[0]);
fieldNames.emplace_back(stringAttr.getValue().str(),
fieldOp.getName().str());

} else {
fieldOp.emitError()
<< "FieldOp must have a dictionary attribute named '"
<< attributeName << "'"
<< "with an array containing a string attribute";
}
});

std::string mapString;
mapString += "{ ";
for (size_t i = 0; i < fieldNames.size(); ++i) {
mapString += llvm::formatv(
"{ \"{0}\", reinterpret_cast<char*>(&{1}) }{2}", fieldNames[i].first,
fieldNames[i].second, (i < fieldNames.size() - 1) ? ", " : "");
}
mapString += " }";

if (emitc::FuncOp executeFunc =
classOp.lookupSymbol<mlir::emitc::FuncOp>("execute"))
rewriter.setInsertionPoint(executeFunc);
else {
classOp.emitError() << "ClassOp must contain a function named 'execute' "
"to add reflection map";
return failure();
}

rewriter.create<emitc::FieldOp>(
classOp.getLoc(), rewriter.getStringAttr("reflectionMap"),
TypeAttr::get(mapType), emitc::OpaqueAttr::get(context, mapString));
return success();
}

private:
StringRef attributeName;
};

void mlir::emitc::populateAddReflectionMapPatterns(RewritePatternSet &patterns,
StringRef namedAttribute) {
patterns.add<AddReflectionMapClass>(patterns.getContext(), namedAttribute);
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIREmitCTransforms
FormExpressions.cpp
TypeConversions.cpp
WrapFuncInClass.cpp
AddReflectionMap.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
Expand Down
51 changes: 51 additions & 0 deletions mlir/test/Dialect/EmitC/add_reflection_map.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: mlir-opt --add-reflection-map="named-attribute=emitc.field_ref" %s | FileCheck %s

emitc.class @mainClass {
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.field_ref = ["output_0"]}
emitc.func @execute() {
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
%1 = get_field @fieldName0 : !emitc.array<1xf32>
%2 = get_field @fieldName1 : !emitc.array<1xf32>
%3 = get_field @fieldName2 : !emitc.array<1xf32>
%4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
%5 = load %4 : <f32>
%6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
%7 = load %6 : <f32>
%8 = add %5, %7 : (f32, f32) -> f32
%9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
assign %8 : f32 to %9 : <f32>
return
}
}

// CHECK: module {
// CHECK-NEXT: emitc.include <"map">
// CHECK-NEXT: emitc.include <"string">
// CHECK-NEXT: emitc.class @mainClass {
// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.field_ref = ["output_0"]}
// CHECK-NEXT: emitc.field @reflectionMap : !emitc.opaque<"const std::map<std::string, char*>"> =
// CHECK-SAME: #emitc.opaque<"{ { \22another_feature\22, reinterpret_cast<char*>(&fieldName0) },
// CHECK-SAME: { \22some_feature\22, reinterpret_cast<char*>(&fieldName1) },
// CHECK-SAME: { \22output_0\22, reinterpret_cast<char*>(&fieldName2) } }">
// CHECK-NEXT: emitc.func @execute() {
// CHECK-NEXT: %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
// CHECK-NEXT: %1 = get_field @fieldName0 : !emitc.array<1xf32>
// CHECK-NEXT: %2 = get_field @fieldName1 : !emitc.array<1xf32>
// CHECK-NEXT: %3 = get_field @fieldName2 : !emitc.array<1xf32>
// CHECK-NEXT: %4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
// CHECK-NEXT: %5 = load %4 : <f32>
// CHECK-NEXT: %6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
// CHECK-NEXT: %7 = load %6 : <f32>
// CHECK-NEXT: %8 = add %5, %7 : (f32, f32) -> f32
// CHECK-NEXT: %9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
// CHECK-NEXT: assign %8 : f32 to %9 : <f32>
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }