Skip to content
乔龙飞 edited this page Jun 15, 2017 · 6 revisions
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws) {
  // A NetDef that names its type is dispatched through the net registry.
  if (net_def.has_type()) {
    return NetRegistry()->Create(net_def.type(), net_def, ws);
  }
  // No type given: fall back to SimpleNet, which runs every operator
  // sequentially in declaration order.
  return make_unique<SimpleNet>(net_def, ws);
}
// Network definition: the serialized description of a net that CreateNet()
// turns into a runnable NetBase.
message NetDef {
  optional string name = 1; // the network's name
  // Operators of the net. SimpleNet instantiates them in this order.
  repeated OperatorDef op = 2;
  // Net implementation to build. When unset, CreateNet() builds a SimpleNet;
  // otherwise the value is used as the key into NetRegistry().
  optional string type = 3;
  // Deprecated (see the field option); not read by any code on this page.
  optional int32 num_workers = 4 [deprecated=true];
  // Device placement options for the net (presumably the default for its
  // operators -- confirm).
  optional DeviceOption device_option = 5;
  // Net-level arguments.
  repeated Argument arg = 6;
  // Names of blobs presumably supplied to / produced for the caller of the
  // net -- confirm.
  repeated string external_input = 7;
  repeated string external_output = 8;
}
// A net that runs its operators one after another, in the order in which
// they appear in the NetDef.
class SimpleNet : public NetBase {
 public:
  SimpleNet(const NetDef& net_def, Workspace* ws);
  // Runs every operator in order; returns false at the first failure.
  bool Run() override;
  // Calls RunAsync() on every operator in order; returns false at the first
  // failure.
  bool RunAsync() override;
  vector<OperatorBase*> getOperators();

 protected:
  // Operators in execution order, owned by the net.
  vector<unique_ptr<OperatorBase> > operators_;
};
// Runs each operator in order and stops at the first one that reports
// failure; operators after it are not run.
bool SimpleNet::Run() {
  for (const auto& current : operators_) {
    const bool ok = current->Run();
    if (!ok) {
      return false;
    }
  }
  return true;
}
// Asynchronous counterpart of Run(): calls RunAsync() on each operator in
// order and gives up as soon as one of them reports failure.
bool SimpleNet::RunAsync() {
  VLOG(1) << "Running net " << name_;
  bool all_ok = true;
  for (const auto& current : operators_) {
    if (!current->RunAsync()) {
      all_ok = false;
      break;
    }
  }
  return all_ok;
}
// Base class for nets that execute their operators as a DAG on a pool of
// worker threads. Operators are grouped into chains; worker threads pop the
// index of a ready chain from job_queue_ and execute it through RunAt().
class DAGNetBase : public NetBase {
 public:
  // Maps a chain's key (presumably the index of its first operator -- see
  // OperatorNode::is_chain_start_) to the operator indices in that chain.
  using ExecutionChains = std::unordered_map<int, std::vector<int>>;
  DAGNetBase(const NetDef& net_def, Workspace* ws);
  ~DAGNetBase();
  bool Run() override;
  // WorkerFunction() is a function wrapper to allow us to run worker threads.
  // It checks out one ready-to-run operator from the job queue, runs it,
  // notifies all its children, and for any children that is ready, enqueues
  // it to the job queue.
  void WorkerFunction();
 protected:
  // Executes every operator of `chain`; returns false on failure.
  virtual bool RunAt(const std::vector<int>& chain) = 0;
  // One node per operator of the net.
  vector<internal::OperatorNode> operator_nodes_;
  ExecutionChains execution_chains_;
  // Chain indices pushed onto the job queue when a run starts.
  vector<int> initial_frontier_;
  // Indices of chains that are ready to run.
  SimpleQueue<int> job_queue_;
  std::vector<std::thread> workers_;
  // Number of worker threads Run() makes sure are running.
  int num_workers_ = 0;
  // Operators not yet finished in the current run; guarded by
  // remaining_ops_mutex_.
  int remaining_ops_ = 0;

  // False once any chain of the current run has failed; guarded by
  // remaining_ops_mutex_.
  bool success_ = true;
  std::mutex remaining_ops_mutex_;
  // Signaled by workers after they update remaining_ops_ / success_.
  std::condition_variable cv_;
  // Presumably held for the duration of Run() to serialize runs -- confirm.
  std::mutex run_in_progress_;
};
// Per-operator node of the DAG built by DAGNetBase.
struct OperatorNode {
  // The operator itself, owned by the node.
  unique_ptr<OperatorBase> operator_;
  // Indices of operators that depend on this one.
  vector<int> children_;
  // Indices of operators this one depends on.
  vector<int> parents_;
  // Parents that have not finished yet in the current run. Atomic because
  // several worker threads may decrement it concurrently. Value-initialized
  // so it never holds an indeterminate value (a default-constructed
  // std::atomic<int> is uninitialized before C++20).
  std::atomic<int> runtime_parent_count_{0};
  // True if this operator starts an execution chain, i.e. its index is a
  // key that may be pushed onto the job queue.
  bool is_chain_start_ = false;
};
// Base class of all operators. Holds the OperatorDef the operator was built
// from together with pointers to its input and output blobs.
class OperatorBase {
 public:
  explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
  virtual ~OperatorBase() noexcept = default;

  // Runs the operator and reports success. The base implementation performs
  // no work and reports failure; concrete operators override it. An explicit
  // return is required: flowing off the end of a non-void function is
  // undefined behavior.
  virtual bool Run(int /* unused */ stream_id = 0) {
    return false;
  }

 private:
  OperatorDef operator_def_;
  // Blobs read by the operator (const pointees: read-only access).
  vector<const Blob*> inputs_;
  // Blobs written by the operator. Presumably owned by the Workspace --
  // confirm.
  vector<Blob*> outputs_;
};
bool DAGNetBase::Run() {
  // 1. 初始化job queue
  for (auto& value : initial_frontier_) {
    job_queue_.Push(value);
  }

  //2. 等待执行完所有的任务
  while (remaining_ops_ > 0) {
    VLOG(2) << "Remaining ops to run: " << remaining_ops_;
    cv_.wait(mutex_lock);
  }

  // 3. 启动任务线程
  for (auto i = workers_.size(); i < num_workers_; ++i) {
    VLOG(1) << "Start worker #" << i;
    workers_.push_back(std::thread(&DAGNetBase::WorkerFunction, this));
  }

  return success_;
}
bool DAGNetBase::Run() {
  // 1, 初始化 job queue.
  job_queue_ = caffe2::make_unique<SimpleQueue<int>>();
  // 2, 启动工作线程
  for (auto i = 0; i < num_workers_to_start; i++) {
    workers_.push_back(std::thread(&DAGNetBase::WorkerFunction, this));
  }
  // 3. 开始向job queue塞任务。
  for (auto& value : initial_frontier_) {
    job_queue_->Push(value);
  }
  // 4. 等待任务完成
  for (;;) {
      if (remaining_ops_ == 0 || !success_) {
        break;
      }
      cv_.wait(mutex_lock);
  }
  return success_;
}
void DAGNetBase::WorkerFunction() {
  while (true) {
    // 1. 取出任务
    int idx = 0;
    if (!job_queue_.Pop(&idx)) {
      return;
    }
    const auto& chain = execution_chains_[idx];
    // 2. 运行任务
    bool this_success = RunAt(execution_chains_[idx]);
    for (const auto idx : chain) {
        if (operator_nodes_[child].is_chain_start_) {
          // 3. 放入新任务
          job_queue_.Push(child);
        }
      }
    }
    cv_.notify_one();
  }
}
// Instantiates one operator per OperatorDef of `net_def`, keeping the
// declaration order so that Run() executes them sequentially.
SimpleNet::SimpleNet(const NetDef& net_def, Workspace* ws)
    : NetBase(net_def, ws) {
  for (const auto& def : net_def.op()) {
    operators_.push_back(CreateOperator(def, ws));
  }
}
// Creates the operator described by `operator_def` in workspace `ws`.
//
// Returns the created operator, or nullptr when no registered implementation
// could be created: the final TryCreateOperator() result is returned
// unchanged, and the engine loop shows that it can be empty.
unique_ptr<OperatorBase> CreateOperator(
    const OperatorDef& operator_def, Workspace* ws) {
  // 1. Check the def against its OpSchema. Operators without a registered
  //    schema are allowed; Schema() is assumed to return nullptr for them
  //    (as upstream Caffe2 does), so guard before dereferencing.
  //    NOTE(review): Verify()'s result is ignored, so a def that fails the
  //    check is still instantiated -- confirm this is intended.
  auto* schema = OpSchemaRegistry::Schema(operator_def.type());
  if (schema != nullptr) {
    schema->Verify(operator_def);
  }

  // 2. If engines are requested, try them in the order given. `engine` is a
  //    comma-separated list, and each entry is looked up under the key
  //    "<type>_ENGINE_<engine>". Example registration:
  //    REGISTER_CPU_OPERATOR_WITH_ENGINE(Conv, EIGEN, EigenConvOp<float>);
  if (!operator_def.engine().empty()) {
    const vector<string> engine_choices = split(',', operator_def.engine());
    for (const string& engine : engine_choices) {
      const string key = operator_def.type() + "_ENGINE_" + engine;
      auto op = TryCreateOperator(key, operator_def, ws);
      if (op) {
        return op;
      }
    }
  }

  // 3. Fall back to the default implementation registered under the plain
  //    operator type (presumably selected by the def's device type).
  return TryCreateOperator(operator_def.type(), operator_def, ws);
}
Clone this wiki locally