Skip to content
乔龙飞 edited this page Jun 15, 2017 · 6 revisions
// Builds the network described by `net_def` inside workspace `ws`.
//
// A definition that names a type is dispatched through NetRegistry() under
// that type name; a definition without a type gets the default SimpleNet,
// which just runs all operators sequentially.
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws) {
  if (net_def.has_type()) {
    // An explicit type was requested: let the registry construct it.
    return NetRegistry()->Create(net_def.type(), net_def, ws);
  }
  // No type given: fall back to the simple sequential network.
  return make_unique<SimpleNet>(net_def, ws);
}
// Network definition.
message NetDef {
  optional string name = 1; // the network's name
  // The operators that make up the network.
  repeated OperatorDef op = 2;
  // Name of the network implementation to create. CreateNet() looks this up
  // in the net registry; when the field is unset it returns a SimpleNet.
  optional string type = 3;
  // Deprecated field.
  // NOTE(review): presumably the worker-thread count for parallel nets --
  // confirm before relying on it.
  optional int32 num_workers = 4 [deprecated=true];
  // NOTE(review): presumably a net-wide device setting; how it interacts with
  // per-operator settings is not visible here -- confirm.
  optional DeviceOption device_option = 5;
  // Additional arguments attached to the network.
  repeated Argument arg = 6;
  // NOTE(review): presumably names of blobs the network expects to exist
  // before it runs -- confirm.
  repeated string external_input = 7;
  // NOTE(review): presumably names of blobs the network produces for
  // consumers outside it -- confirm.
  repeated string external_output = 8;
}
// A network that holds a list of operators and runs them one after another.
class SimpleNet : public NetBase {
 public:
  // Constructs the network from its definition within workspace `ws`.
  SimpleNet(const NetDef& net_def, Workspace* ws);
  // Runs each operator in order via OperatorBase::Run(). Returns false as
  // soon as one operator reports failure, true if all of them succeed.
  bool Run() override;
  // Same traversal as Run(), but invokes OperatorBase::RunAsync() on each
  // operator.
  bool RunAsync() override;
  // Accessor for the network's operators.
  // NOTE(review): the definition is not shown here; presumably returns
  // non-owning pointers into operators_ -- confirm.
  vector<OperatorBase*> getOperators();

 protected:
  // Operators owned by this network, in execution order.
  vector<unique_ptr<OperatorBase> > operators_;
};
// Executes the operators in order. Stops at the first operator whose Run()
// reports failure and returns false; returns true when every operator
// succeeded (including when there are no operators).
bool SimpleNet::Run() {
  for (auto it = operators_.begin(); it != operators_.end(); ++it) {
    if ((*it)->Run()) {
      continue;
    }
    // First failing operator aborts the whole run.
    return false;
  }
  return true;
}
// Logs the network name at verbose level 1, then calls RunAsync() on each
// operator in order. Returns false at the first operator that reports
// failure, true otherwise.
bool SimpleNet::RunAsync() {
  VLOG(1) << "Running net " << name_;
  for (auto it = operators_.begin(); it != operators_.end(); ++it) {
    if ((*it)->RunAsync()) {
      continue;
    }
    // First failing operator aborts the whole run.
    return false;
  }
  return true;
}
// Abstract base for networks executed by a set of worker threads pulling
// ready operators from a job queue. Subclasses supply RunAt() to define how
// a chain of operators is actually executed.
class DAGNetBase : public NetBase {
 public:
  // Maps an int key to a list of operator indices.
  // NOTE(review): presumably keyed by the index of the chain's first
  // operator -- confirm against the implementation.
  using ExecutionChains = std::unordered_map<int, std::vector<int>>;
  // Constructs the network from its definition within workspace `ws`.
  DAGNetBase(const NetDef& net_def, Workspace* ws);
  ~DAGNetBase();
  // Runs the network; returns whether the run succeeded.
  bool Run() override;
  // WorkerFunction() is a function wrapper to allow us to run worker threads.
  // It checks out one ready-to-run operator from the job queue, runs it,
  // notifies all its children, and for any child that is ready, enqueues
  // it to the job queue.
  void WorkerFunction();
 protected:
  // Executes the given chain of operator indices; implemented by subclasses.
  // Returns whether the chain ran successfully.
  virtual bool RunAt(const std::vector<int>& chain) = 0;
  // Per-operator bookkeeping nodes for the network's operators.
  vector<internal::OperatorNode> operator_nodes_;
  // Chains of operator indices passed to RunAt().
  ExecutionChains execution_chains_;
  // NOTE(review): presumably the operators with no unmet dependencies at the
  // start of a run -- confirm.
  vector<int> initial_frontier_;
  // Queue of ready-to-run operator indices consumed by WorkerFunction().
  SimpleQueue<int> job_queue_;
  // Worker threads.
  std::vector<std::thread> workers_;
  // NOTE(review): presumably the number of threads in workers_ -- confirm.
  int num_workers_;
  // NOTE(review): presumably the count of operators not yet finished in the
  // current run -- confirm.
  int remaining_ops_;

  // NOTE(review): presumably the overall result of the current run -- confirm.
  bool success_;
  // NOTE(review): presumably guards remaining_ops_ (and is paired with cv_)
  // -- confirm.
  std::mutex remaining_ops_mutex_;
  std::condition_variable cv_;
  // NOTE(review): presumably held for the duration of Run() to prevent
  // overlapping runs -- confirm.
  std::mutex run_in_progress_;
};
Clone this wiki locally