diff --git a/llvm/lib/Analysis/models/saved-model-to-tflite.py b/llvm/lib/Analysis/models/saved-model-to-tflite.py index 9c83718732945..29fa0dfdcb747 100644 --- a/llvm/lib/Analysis/models/saved-model-to-tflite.py +++ b/llvm/lib/Analysis/models/saved-model-to-tflite.py @@ -31,7 +31,10 @@ def main(argv): src_json = os.path.join(sm_dir, json_file) if tf.io.gfile.exists(src_json): tf.io.gfile.copy(src_json, os.path.join(tfl_dir, json_file)) - + tf.mlir.experimental.tflite_to_tosa_bytecode( + tfl_path, + os.path.join(tfl_dir, 'model.tosa') + ) if __name__ == "__main__": main(sys.argv) diff --git a/llvm/lib/ProfileData/PGOCtxProfReader.cpp b/llvm/lib/ProfileData/PGOCtxProfReader.cpp index ec801d43c8588..5e1be66432f35 100644 --- a/llvm/lib/ProfileData/PGOCtxProfReader.cpp +++ b/llvm/lib/ProfileData/PGOCtxProfReader.cpp @@ -227,8 +227,9 @@ Error PGOCtxProfileReader::loadFlatProfileList(CtxProfFlatProfile &P) { while (canEnterBlockWithID(PGOCtxProfileBlockIDs::FlatProfileBlockID)) { EXPECT_OR_RET(E, readProfile(PGOCtxProfileBlockIDs::FlatProfileBlockID)); auto Guid = E->second.guid(); - if (!P.insert({Guid, std::move(E->second.counters())}).second) - return wrongValue("Duplicate flat profile entries"); + if (!P.insert({Guid, std::move(E->second.counters())}).second) { + errs() << "Duplicate flat profile entries: " << Guid << "\n"; + } } return Error::success(); } diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 10557658d5d7d..9bfd66f611567 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1161,6 +1161,7 @@ def TosaToArithPass : Pass<"tosa-to-arith"> { let summary = "Lower TOSA to the Arith dialect"; let dependentDialects = [ "arith::ArithDialect", + "shape::ShapeDialect" ]; let description = [{ Pass that converts TOSA operations to the equivalent operations using the diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index db244d1d1cac8..36b1fed29b8cc 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -222,6 +223,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add(converter, patterns.getContext()); + patterns.add(converter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp index 9dea12355a519..da4655e41c2cd 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -12,6 +12,7 @@ #include "mlir/Conversion/TosaToArith/TosaToArith.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -33,6 +34,17 @@ class ConstOpConverter : public OpRewritePattern { } }; +class ConstShapeOpConverter : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ConstShapeOp op, + PatternRewriter &rewriter) const final { + rewriter.replaceOpWithNewOp(op, op.getValues()); + return success(); + } + }; + Type matchContainerType(Type element, Type container) { if (auto shapedTy = dyn_cast(container)) return shapedTy.clone(element); @@ -251,6 +263,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern { void mlir::tosa::populateTosaToArithConversionPatterns( RewritePatternSet *patterns) { patterns->add(patterns->getContext()); + patterns->add(patterns->getContext()); } void mlir::tosa::populateTosaRescaleToArithConversionPatterns( diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp index ede3c9e0040fd..ea0a19ca3fcc6 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/TosaToArith/TosaToArith.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/IR/PatternMatch.h" diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index 01a7cd7ac94db..23e444d579ce9 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -121,7 +121,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() { validationOptions.extension = {"none"}; validationOptions.strictOpSpecAlignment = false; validationOptions.allowInvalidOpDatatypeCombinations = false; - validationOptions.level = tosa::TosaLevelEnum::EightK; + validationOptions.level = tosa::TosaLevelEnum::None; tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, tosaToLinalgNamedOptions, validationOptions); diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp index 2108ffd414c56..a6afa6ae22302 100644 --- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp +++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp @@ -9,6 +9,8 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/Target/Cpp/CppEmitter.h" @@ -45,6 +47,8 @@ void registerToCppTranslation() { // clang-format off registry.insert(); // clang-format on }); diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 0c4975a13d301..1ed68a90f0e4b 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -6,9 +6,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir-c/BuiltinAttributes.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -1138,8 +1141,46 @@ static LogicalResult printOperation(CppEmitter &emitter, "with multiple blocks needs variables declared at top"); } - CppEmitter::Scope scope(emitter); + CppEmitter::Scope classScope(emitter); raw_indented_ostream &os = emitter.ostream(); + os << "class MyClass final {\n"; + auto argAttrs = functionOp.getArgAttrs(); + + std::map fields; + if (argAttrs) + for (auto [a,v] : zip(*argAttrs, functionOp.getArguments())) { + if (auto da = dyn_cast(a)) { + auto nv = da.getNamed("tf_saved_model.index_path")->getValue(); + auto name = cast(cast(nv)[0]).str(); + fields[name] = v; + os << " "; + if (failed(emitter.emitType(functionOp.getLoc(), v.getType()))) + return failure(); + os << " " << emitter.getOrCreateName(v) << ";\n"; + } + } + + for (auto & r : functionOp->getRegions()) + for (auto &b : r.getBlocks()) + for (auto &opt : b.getOperations()) + if (auto alloc = dyn_cast(opt)) { + auto name = emitter.getOrCreateName(alloc).str(); + fields[name] = alloc; + if (failed(emitter.emitType(alloc.getLoc(), alloc.getType().getElementType()))) + return failure(); + os << " [" << alloc.getType().getNumElements() <<"] "; + os << " " << name << ";\n"; + } + os << " std::map _buffer_map {"; + for (auto &[n,v]:fields) + os << "{ \"" << n << "\"" << ", reinterpret_cast(" << emitter.getOrCreateName(v) << ") },"; + os << " };\n"; + os << " char* getBufferForName(const std::string& name) const {\n"; + os << " auto it = _buffer_map.find(name);\n"; + os << " return (it == _buffer_map.end()) ? nullptr : it->second;\n"; + os << " }\n"; + CppEmitter::Scope scope(emitter); + if (functionOp.getSpecifiers()) { for (Attribute specifier : functionOp.getSpecifiersAttr()) { os << cast(specifier).str() << " "; @@ -1160,13 +1201,13 @@ static LogicalResult printOperation(CppEmitter &emitter, os << ");"; return success(); } - if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) - return failure(); + // if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) + // return failure(); os << ") {\n"; if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) return failure(); os << "}\n"; - + os << "};\n"; return success(); }