diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h index 1af4aa06fa811..259d6c24cd5fc 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h @@ -16,6 +16,7 @@ namespace emitc { #define GEN_PASS_DECL_FORMEXPRESSIONSPASS #define GEN_PASS_DECL_WRAPFUNCINCLASSPASS +#define GEN_PASS_DECL_ADDREFLECTIONMAPPASS #include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td index 1893c101e735b..96efbfe3dc3a8 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td @@ -53,4 +53,42 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> { let dependentDialects = ["emitc::EmitCDialect"]; } +def AddReflectionMapPass : Pass<"add-reflection-map", "ModuleOp"> { + let summary = + "Add a reflection map function to EmitC classes for runtime field lookup"; + let description = [{ + This pass adds a `getBufferForName` function to EmitC classes that enables + runtime lookup of field buffers by their string names. + This requires that the class has fields with attributes and a function named `execute`. + The `fieldop` attribute is expected to be a dictionary where: + - The keys are `namedAttribute`. + - The values are arrays containing a single string attribute. + + + Example: + + ```mlir + emitc.class @MyClass { + emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]} + emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]} + emitc.func @execute() { ... } + } + + // becomes: + + emitc.class @MyClass { + emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]} + emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]} + emitc.field @reflectionMap : !emitc.opaque<"const std::map"> = #emitc.opaque<"{ { \22another_feature\22, reinterpret_cast(&fieldName0) }, { \22some_feature\22, reinterpret_cast(&fieldName1) } }"> + emitc.func @execute() { ... } + } + ``` + }]; + let dependentDialects = ["mlir::emitc::EmitCDialect"]; + let options = [Option<"namedAttribute", "named-attribute", "std::string", + /*default=*/"", + "Attribute key used to extract field names from fields " + "dictionary attributes">]; +} + #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h index bdf6d0985e6db..7abc430347dc3 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h @@ -31,6 +31,10 @@ void populateExpressionPatterns(RewritePatternSet &patterns); /// Populates 'patterns' with func-related patterns. void populateFuncPatterns(RewritePatternSet &patterns); +/// Populates `patterns` with patterns to add reflection map for EmitC classes. +void populateAddReflectionMapPatterns(RewritePatternSet &patterns, + StringRef namedAttribute); + } // namespace emitc } // namespace mlir diff --git a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp new file mode 100644 index 0000000000000..b0dc84d70b0f6 --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp @@ -0,0 +1,136 @@ +//===- AddReflectionMap.cpp - Add a reflection map to a class -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/Passes.h" +#include "mlir/Dialect/EmitC/Transforms/Transforms.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; +using namespace emitc; + +namespace mlir { +namespace emitc { +#define GEN_PASS_DEF_ADDREFLECTIONMAPPASS +#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" + +namespace { +constexpr const char *mapLibraryHeader = "map"; +constexpr const char *stringLibraryHeader = "string"; + +IncludeOp addHeader(OpBuilder &builder, ModuleOp module, StringRef headerName) { + StringAttr includeAttr = builder.getStringAttr(headerName); + return builder.create( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); +} + +class AddReflectionMapPass + : public impl::AddReflectionMapPassBase { + using AddReflectionMapPassBase::AddReflectionMapPassBase; + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + + RewritePatternSet patterns(&getContext()); + populateAddReflectionMapPatterns(patterns, namedAttribute); + + walkAndApplyPatterns(module, std::move(patterns)); + bool hasMapHdr = false; + bool hasStringHdr = false; + for (auto &op : *module.getBody()) { + emitc::IncludeOp includeOp = llvm::dyn_cast(op); + if (!includeOp) + continue; + if (includeOp.getIsStandardInclude()) { + if (includeOp.getInclude() == mapLibraryHeader) + hasMapHdr = true; + if (includeOp.getInclude() == stringLibraryHeader) + hasStringHdr = true; + } + if (hasMapHdr && hasStringHdr) + return; + } + + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + if (!hasMapHdr) + addHeader(builder, module, mapLibraryHeader); + + if (!hasStringHdr) + addHeader(builder, module, stringLibraryHeader); + } +}; + +} // namespace +} // namespace emitc +} // namespace mlir + +class AddReflectionMapClass : public OpRewritePattern { +public: + AddReflectionMapClass(MLIRContext *context, StringRef attrName) + : OpRewritePattern(context), attributeName(attrName) {} + + LogicalResult matchAndRewrite(mlir::emitc::ClassOp classOp, + PatternRewriter &rewriter) const override { + MLIRContext *context = rewriter.getContext(); + + emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get( + context, "const std::map"); + + // Collect all field names + std::vector> fieldNames; + classOp.walk([&](mlir::emitc::FieldOp fieldOp) { + if (ArrayAttr arrayAttr = cast( + fieldOp->getAttrDictionary().get(attributeName))) { + StringAttr stringAttr = cast(arrayAttr[0]); + fieldNames.emplace_back(stringAttr.getValue().str(), + fieldOp.getName().str()); + + } else { + fieldOp.emitError() + << "FieldOp must have a dictionary attribute named '" + << attributeName << "'" + << "with an array containing a string attribute"; + } + }); + + std::string mapString; + mapString += "{ "; + for (size_t i = 0; i < fieldNames.size(); ++i) { + mapString += llvm::formatv( + "{ \"{0}\", reinterpret_cast(&{1}) }{2}", fieldNames[i].first, + fieldNames[i].second, (i < fieldNames.size() - 1) ? ", " : ""); + } + mapString += " }"; + + if (emitc::FuncOp executeFunc = + classOp.lookupSymbol("execute")) + rewriter.setInsertionPoint(executeFunc); + else { + classOp.emitError() << "ClassOp must contain a function named 'execute' " + "to add reflection map"; + return failure(); + } + + rewriter.create( + classOp.getLoc(), rewriter.getStringAttr("reflectionMap"), + TypeAttr::get(mapType), emitc::OpaqueAttr::get(context, mapString)); + return success(); + } + +private: + StringRef attributeName; +}; + +void mlir::emitc::populateAddReflectionMapPatterns(RewritePatternSet &patterns, + StringRef namedAttribute) { + patterns.add(patterns.getContext(), namedAttribute); +} diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt index baf67afc30072..dd8f014dc4737 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIREmitCTransforms FormExpressions.cpp TypeConversions.cpp WrapFuncInClass.cpp + AddReflectionMap.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms diff --git a/mlir/test/Dialect/EmitC/add_reflection_map.mlir b/mlir/test/Dialect/EmitC/add_reflection_map.mlir new file mode 100644 index 0000000000000..ac22580140d58 --- /dev/null +++ b/mlir/test/Dialect/EmitC/add_reflection_map.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt --add-reflection-map="named-attribute=emitc.field_ref" %s | FileCheck %s + +emitc.class @mainClass { + emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]} + emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]} + emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.field_ref = ["output_0"]} + emitc.func @execute() { + %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t + %1 = get_field @fieldName0 : !emitc.array<1xf32> + %2 = get_field @fieldName1 : !emitc.array<1xf32> + %3 = get_field @fieldName2 : !emitc.array<1xf32> + %4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue + %5 = load %4 : + %6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue + %7 = load %6 : + %8 = add %5, %7 : (f32, f32) -> f32 + %9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue + assign %8 : f32 to %9 : + return + } +} + +// CHECK: module { +// CHECK-NEXT: emitc.include <"map"> +// CHECK-NEXT: emitc.include <"string"> +// CHECK-NEXT: emitc.class @mainClass { +// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]} +// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]} +// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.field_ref = ["output_0"]} +// CHECK-NEXT: emitc.field @reflectionMap : !emitc.opaque<"const std::map"> = +// CHECK-SAME: #emitc.opaque<"{ { \22another_feature\22, reinterpret_cast(&fieldName0) }, +// CHECK-SAME: { \22some_feature\22, reinterpret_cast(&fieldName1) }, +// CHECK-SAME: { \22output_0\22, reinterpret_cast(&fieldName2) } }"> +// CHECK-NEXT: emitc.func @execute() { +// CHECK-NEXT: %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t +// CHECK-NEXT: %1 = get_field @fieldName0 : !emitc.array<1xf32> +// CHECK-NEXT: %2 = get_field @fieldName1 : !emitc.array<1xf32> +// CHECK-NEXT: %3 = get_field @fieldName2 : !emitc.array<1xf32> +// CHECK-NEXT: %4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue +// CHECK-NEXT: %5 = load %4 : +// CHECK-NEXT: %6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue +// CHECK-NEXT: %7 = load %6 : +// CHECK-NEXT: %8 = add %5, %7 : (f32, f32) -> f32 +// CHECK-NEXT: %9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue +// CHECK-NEXT: assign %8 : f32 to %9 : +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +