forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
caffe2
乔龙飞 edited this page Jun 15, 2017
·
6 revisions
// Builds the network described by `net_def`, bound to workspace `ws`.
//
// A definition that names a type is handed to the net registry, which creates
// the implementation registered under that type. A definition without a type
// yields a SimpleNet.
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws) {
  if (net_def.has_type()) {
    return NetRegistry()->Create(net_def.type(), net_def, ws);
  }
  // No explicit type: default to the simple network.
  return make_unique<SimpleNet>(net_def, ws);
}

// Network definition.
// Protobuf schema (not C++) for a network.
message NetDef {
  optional string name = 1; // the network's name
  repeated OperatorDef op = 2;
  // Selects which registered NetBase implementation CreateNet() builds; when
  // unset, CreateNet() falls back to SimpleNet.
  optional string type = 3;
  optional int32 num_workers = 4 [deprecated=true];
  optional DeviceOption device_option = 5;
  repeated Argument arg = 6;
  repeated string external_input = 7;
  repeated string external_output = 8;
}

// A network that runs its operators one after another on the calling thread.
class SimpleNet : public NetBase {
 public:
  SimpleNet(const NetDef& net_def, Workspace* ws);
  bool Run() override;
  bool RunAsync() override;
  vector<OperatorBase*> getOperators();

 protected:
  // Operators owned by this net, in the order Run()/RunAsync() visit them.
  vector<unique_ptr<OperatorBase>> operators_;
};

// Runs every operator in order. Stops at the first operator whose Run()
// returns false and reports failure; returns true if all of them succeed.
bool SimpleNet::Run() {
  for (auto& op : operators_) {
    if (!op->Run()) {
      return false;
    }
  }
  return true;
}

// Same traversal as Run(), but invokes each operator's RunAsync().
bool SimpleNet::RunAsync() {
  VLOG(1) << "Running net " << name_;
  for (auto& op : operators_) {
    if (!op->RunAsync()) {
      return false;
    }
  }
  return true;
}

// Base class for networks that execute operator chains on worker threads.
// Subclasses supply RunAt(), which executes one chain.
class DAGNetBase : public NetBase {
 public:
  using ExecutionChains = std::unordered_map<int, std::vector<int>>;
  DAGNetBase(const NetDef& net_def, Workspace* ws);
  ~DAGNetBase();
  bool Run() override;
  // WorkerFunction() is a function wrapper to allow us to run worker threads.
  // It checks out one ready-to-run chain from the job queue, runs it,
  // notifies all its children, and for any child that is ready, enqueues
  // it to the job queue.
  void WorkerFunction();

 protected:
  // Executes the operators listed in `chain`; returns false on failure.
  virtual bool RunAt(const std::vector<int>& chain) = 0;

  vector<internal::OperatorNode> operator_nodes_;
  // Looked up by the index popped from job_queue_; the mapped vector is the
  // chain passed to RunAt().
  ExecutionChains execution_chains_;
  // Indices pushed to job_queue_ at the start of every Run().
  vector<int> initial_frontier_;
  SimpleQueue<int> job_queue_;
  std::vector<std::thread> workers_;
  int num_workers_;
  // Operators not yet accounted for in the current Run(); guarded by
  // remaining_ops_mutex_, as is success_.
  int remaining_ops_;
  bool success_;
  std::mutex remaining_ops_mutex_;
  // Signalled by workers each time a chain has been accounted for.
  std::condition_variable cv_;
  // Held for the duration of Run() so that two runs never overlap.
  std::mutex run_in_progress_;
};

// One operator plus its dependency edges. children_ and parents_ hold indices
// into DAGNetBase::operator_nodes_.
// NOTE(review): DAGNetBase refers to this type as internal::OperatorNode; the
// enclosing namespace is not shown in this excerpt.
struct OperatorNode {
  unique_ptr<OperatorBase> operator_;
  vector<int> children_;
  vector<int> parents_;
  // Parents that have not finished yet in the current run; reset by
  // DAGNetBase::Run() and decremented by the workers.
  std::atomic<int> runtime_parent_count_;
  bool is_chain_start_ = false;
};

class OperatorBase {
 public:
  explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
  virtual ~OperatorBase() noexcept {}
  // The base implementation does no work and reports failure; concrete
  // operators override it. (It previously fell off the end of a non-void
  // function, which is undefined behaviour.)
  virtual bool Run(int /* unused */ /* stream_id */ = 0) { return false; }
  // SimpleNet::RunAsync() calls this on every operator. By default it simply
  // forwards to Run().
  virtual bool RunAsync(int stream_id = 0) { return Run(stream_id); }

 private:
  OperatorDef operator_def_;
  vector<const Blob*> inputs_;
  vector<Blob*> outputs_;
};

// Executes the whole graph once and returns whether every chain succeeded.
bool DAGNetBase::Run() {
  // Only one Run() may be in flight: the bookkeeping below is per-run state.
  std::lock_guard<std::mutex> run_lock(run_in_progress_);

  // Reset the per-run bookkeeping before any worker can observe it.
  // Assumes the execution chains together cover every operator exactly once,
  // so that the workers' decrements bring remaining_ops_ back to zero.
  remaining_ops_ = static_cast<int>(operator_nodes_.size());
  success_ = true;
  for (auto& node : operator_nodes_) {
    node.runtime_parent_count_ = static_cast<int>(node.parents_.size());
  }

  // 1. Start worker threads until num_workers_ of them exist. This must
  //    happen before waiting, otherwise nothing would ever drain the queue.
  //    Comparing as int means a non-positive num_workers_ starts no threads.
  while (static_cast<int>(workers_.size()) < num_workers_) {
    VLOG(1) << "Start worker #" << workers_.size();
    workers_.emplace_back(&DAGNetBase::WorkerFunction, this);
  }

  // 2. Seed the job queue with the chains that are ready immediately.
  for (const auto& value : initial_frontier_) {
    job_queue_.Push(value);
  }

  // 3. Wait until the workers have accounted for every operator.
  std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
  while (remaining_ops_ > 0) {
    VLOG(2) << "Remaining ops to run: " << remaining_ops_;
    cv_.wait(mutex_lock);
  }
  return success_;
}

void DAGNetBase::WorkerFunction() {
  // WorkerFunction() is an infinite loop until there are no more jobs to run.
  while (true) {
    int idx = 0;
    if (!job_queue_.Pop(&idx)) {
      return;
    }

    // at() rather than operator[]: this runs on several threads at once and
    // must never insert into the shared map.
    // Assumes every enqueued index is a key of execution_chains_.
    const auto& chain = execution_chains_.at(idx);
    const bool this_success = RunAt(chain);

    // Tell the children of every operator in the chain that one more parent
    // has finished; a child whose last parent just finished and which starts
    // a chain becomes runnable.
    for (const auto op_idx : chain) {
      for (const auto child : operator_nodes_[op_idx].children_) {
        const int pending = --operator_nodes_[child].runtime_parent_count_;
        if (pending != 0) {
          continue;
        }
        if (operator_nodes_[child].is_chain_start_) {
          job_queue_.Push(child);
        }
      }
    }

    // Record progress (even for a failed chain, so Run() cannot wait forever)
    // and wake up Run().
    {
      std::lock_guard<std::mutex> guard(remaining_ops_mutex_);
      remaining_ops_ -= static_cast<int>(chain.size());
      success_ = success_ && this_success;
    }
    cv_.notify_one();
  }
}