Skip to content
5 changes: 4 additions & 1 deletion llvm/lib/Analysis/models/saved-model-to-tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def main(argv):
src_json = os.path.join(sm_dir, json_file)
if tf.io.gfile.exists(src_json):
tf.io.gfile.copy(src_json, os.path.join(tfl_dir, json_file))

tf.mlir.experimental.tflite_to_tosa_bytecode(
tfl_path,
os.path.join(tfl_dir, 'model.tosa')
)

if __name__ == "__main__":
main(sys.argv)
5 changes: 3 additions & 2 deletions llvm/lib/ProfileData/PGOCtxProfReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,9 @@ Error PGOCtxProfileReader::loadFlatProfileList(CtxProfFlatProfile &P) {
while (canEnterBlockWithID(PGOCtxProfileBlockIDs::FlatProfileBlockID)) {
EXPECT_OR_RET(E, readProfile(PGOCtxProfileBlockIDs::FlatProfileBlockID));
auto Guid = E->second.guid();
if (!P.insert({Guid, std::move(E->second.counters())}).second)
return wrongValue("Duplicate flat profile entries");
if (!P.insert({Guid, std::move(E->second.counters())}).second) {
errs() << "Duplicate flat profile entries: " << Guid << "\n";
}
}
return Error::success();
}
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,7 @@ def TosaToArithPass : Pass<"tosa-to-arith"> {
let summary = "Lower TOSA to the Arith dialect";
let dependentDialects = [
"arith::ArithDialect",
"shape::ShapeDialect"
];
let description = [{
Pass that converts TOSA operations to the equivalent operations using the
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

Expand Down Expand Up @@ -222,6 +223,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {

void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
ConvertStore>(converter, patterns.getContext());
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal,
ConvertLoad, ConvertStore>(converter, patterns.getContext());
}
13 changes: 13 additions & 0 deletions mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/Conversion/TosaToArith/TosaToArith.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
Expand All @@ -33,6 +34,17 @@ class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
}
};

class ConstShapeOpConverter : public OpRewritePattern<tosa::ConstShapeOp> {
public:
using OpRewritePattern<tosa::ConstShapeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::ConstShapeOp op,
PatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<shape::ConstShapeOp>(op, op.getValues());
return success();
}
};

Type matchContainerType(Type element, Type container) {
if (auto shapedTy = dyn_cast<ShapedType>(container))
return shapedTy.clone(element);
Expand Down Expand Up @@ -251,6 +263,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
void mlir::tosa::populateTosaToArithConversionPatterns(
RewritePatternSet *patterns) {
patterns->add<ConstOpConverter>(patterns->getContext());
patterns->add<ConstShapeOpConverter>(patterns->getContext());
}

void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Conversion/TosaToArith/TosaToArith.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
validationOptions.extension = {"none"};
validationOptions.strictOpSpecAlignment = false;
validationOptions.allowInvalidOpDatatypeCombinations = false;
validationOptions.level = tosa::TosaLevelEnum::EightK;
validationOptions.level = tosa::TosaLevelEnum::None;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
tosaToLinalgNamedOptions,
validationOptions);
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Target/Cpp/TranslateRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Target/Cpp/CppEmitter.h"
Expand Down Expand Up @@ -45,6 +47,8 @@ void registerToCppTranslation() {
// clang-format off
registry.insert<cf::ControlFlowDialect,
emitc::EmitCDialect,
memref::MemRefDialect,
math::MathDialect,
func::FuncDialect>();
// clang-format on
});
Expand Down
49 changes: 45 additions & 4 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
//
//===----------------------------------------------------------------------===//

#include "mlir-c/BuiltinAttributes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
Expand Down Expand Up @@ -1138,8 +1141,46 @@ static LogicalResult printOperation(CppEmitter &emitter,
"with multiple blocks needs variables declared at top");
}

CppEmitter::Scope scope(emitter);
CppEmitter::Scope classScope(emitter);
raw_indented_ostream &os = emitter.ostream();
os << "class MyClass final {\n";
auto argAttrs = functionOp.getArgAttrs();

std::map<std::string, Value> fields;
if (argAttrs)
for (auto [a,v] : zip(*argAttrs, functionOp.getArguments())) {
if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
auto nv = da.getNamed("tf_saved_model.index_path")->getValue();
auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
fields[name] = v;
os << " ";
if (failed(emitter.emitType(functionOp.getLoc(), v.getType())))
return failure();
os << " " << emitter.getOrCreateName(v) << ";\n";
}
}

for (auto & r : functionOp->getRegions())
for (auto &b : r.getBlocks())
for (auto &opt : b.getOperations())
if (auto alloc = dyn_cast<memref::AllocOp>(opt)) {
auto name = emitter.getOrCreateName(alloc).str();
fields[name] = alloc;
if (failed(emitter.emitType(alloc.getLoc(), alloc.getType().getElementType())))
return failure();
os << " [" << alloc.getType().getNumElements() <<"] ";
os << " " << name << ";\n";
}
os << " std::map<std::string, char*> _buffer_map {";
for (auto &[n,v]:fields)
os << "{ \"" << n << "\"" << ", reinterpret_cast<char*>(" << emitter.getOrCreateName(v) << ") },";
os << " };\n";
os << " char* getBufferForName(const std::string& name) const {\n";
os << " auto it = _buffer_map.find(name);\n";
os << " return (it == _buffer_map.end()) ? nullptr : it->second;\n";
os << " }\n";
CppEmitter::Scope scope(emitter);

if (functionOp.getSpecifiers()) {
for (Attribute specifier : functionOp.getSpecifiersAttr()) {
os << cast<StringAttr>(specifier).str() << " ";
Expand All @@ -1160,13 +1201,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
os << ");";
return success();
}
if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
return failure();
// if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
// return failure();
os << ") {\n";
if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
return failure();
os << "}\n";

os << "};\n";
return success();
}

Expand Down
Loading