diff --git a/MLModelRunner/ONNXModelRunner/ONNXModelRunner.cpp b/MLModelRunner/ONNXModelRunner/ONNXModelRunner.cpp old mode 100755 new mode 100644 index 2bb9450e..2474c335 --- a/MLModelRunner/ONNXModelRunner/ONNXModelRunner.cpp +++ b/MLModelRunner/ONNXModelRunner/ONNXModelRunner.cpp @@ -13,7 +13,10 @@ //===----------------------------------------------------------------------===// #include "MLModelRunner/ONNXModelRunner/ONNXModelRunner.h" +//#include "MLModelRunner/MLModelRunner.h" +#include "MLModelRunner/MLModelRunner.h" #include "SerDes/baseSerDes.h" +#include "llvm/Support/raw_ostream.h" using namespace llvm; namespace MLBridge { @@ -34,14 +37,26 @@ void ONNXModelRunner::addAgent(Agent *agent, std::string name) { } } +void passAgentInfo(std::string mode, std::string agentName, int action) { + std::error_code EC; + llvm::raw_fd_ostream fileStream("test-raw.txt", EC, llvm::sys::fs::OF_Append); + fileStream << mode << ": " << agentName << ": " << action << "\n"; +} + void ONNXModelRunner::computeAction(Observation &obs) { + // std::error_code EC; + // llvm::raw_fd_ostream fileStream("test-raw.txt", EC, + // llvm::sys::fs::OF_Append); while (true) { Action action; // current agent auto current_agent = this->agents[this->env->getNextAgent()]; action = current_agent->computeAction(obs); + passAgentInfo("input", this->env->getNextAgent(), action); this->env->step(action); + if (this->env->checkDone()) { + passAgentInfo("output", this->env->getNextAgent(), action); std::cout << "Done🎉\n"; break; } diff --git a/include/MLModelRunner/MLModelRunner.h b/include/MLModelRunner/MLModelRunner.h index d99a6c6f..2b87e692 100644 --- a/include/MLModelRunner/MLModelRunner.h +++ b/include/MLModelRunner/MLModelRunner.h @@ -39,19 +39,24 @@ #include "SerDes/baseSerDes.h" #include "SerDes/bitstreamSerDes.h" #include "SerDes/jsonSerDes.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/raw_ostream.h" #include +#include #include #include +#include +#include #include #include #ifndef C_LIBRARY #include "SerDes/protobufSerDes.h" #include "SerDes/tensorflowSerDes.h" +#include #endif namespace MLBridge { - /// MLModelRunner - The main interface for interacting with the ML models. class MLModelRunner { public: @@ -78,6 +83,28 @@ class MLModelRunner { memcpy(ret, res, SerDes->getMessageLength()); dataSize = SerDes->getMessageLength() / sizeof(BaseType); data = ret; + std::error_code EC; + llvm::raw_fd_ostream fileStream("test-raw.txt", EC, + llvm::sys::fs::OF_Append); + dumpOutput(fileStream, ret, dataSize); + } + + template + void dumpOutput(llvm::raw_ostream &OS, T output_vec, int DataSize) { + + OS << "Dumping output" + << ": "; + for (auto i = 0; i < DataSize; i++) { + OS << output_vec[i] << " "; + } + OS << "\n"; + } + + template + void dumpOuput(llvm::raw_ostream &OS, T &var1, int DataSize) { + OS << "Dumping output" + << ": "; + OS << var1 << "\n"; } /// Type of the MLModelRunner @@ -87,7 +114,37 @@ class MLModelRunner { BaseSerDes::Kind getSerDesKind() const { return SerDesType; } virtual void requestExit() = 0; + std::promise *exit_requested; + + template + void passMetaInfo(llvm::raw_ostream &OS, std::pair &var1, + std::pair &...var2) { + OS << var1.first << ": " << var1.second << "\n"; + passMetaInfo(var2...); + } + + template + void dumpFeature(llvm::raw_ostream &OS, std::pair &var1) { + OS << "Dumping input" + << ": "; + OS << var1.first << ": " << var1.second << "\n"; + } + + template + void dumpFeature(llvm::raw_ostream &OS, + std::pair> &var1) { + OS << "Dumping input" + << ": "; + OS << var1.first << ": "; + for (const auto &elem : var1.second) { + OS << elem << " "; + } + OS << "\n"; + } + void dumpFeature( + llvm::raw_ostream &OS, + std::pair> &var1) {} /// User-facing interface for setting the features to be sent to the model. /// The features are passed as a list of key-value pairs. /// The key is the name of the feature and the value is the value of the @@ -96,6 +153,10 @@ class MLModelRunner { void populateFeatures(std::pair &var1, std::pair &...var2) { SerDes->setFeature(var1.first, var1.second); + std::error_code EC; + llvm::raw_fd_ostream fileStream("test-raw.txt", EC, + llvm::sys::fs::OF_Append); + dumpFeature(fileStream, var1); populateFeatures(var2...); }