forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
caffe2
乔龙飞 edited this page Jun 15, 2017
·
6 revisions
// Factory for nets: builds the NetBase implementation requested by net_def.
// When net_def carries a type, the net is created through NetRegistry() under
// that type name; when it does not, a SimpleNet is returned.
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws) {
  if (net_def.has_type()) {
    return NetRegistry()->Create(net_def.type(), net_def, ws);
  }
  // No type given: fall back to the simple network, which just runs all
  // operators sequentially.
  return make_unique<SimpleNet>(net_def, ws);
}
// Network definition.
// Protobuf message describing a network. CreateNet() reads `type` from it and
// the SimpleNet constructor reads `op` from it.
message NetDef {
optional string name = 1; // the network's name
// Operator definitions; the SimpleNet constructor creates one operator per
// entry, in order.
repeated OperatorDef op = 2;
// Net type. CreateNet() passes it to NetRegistry()->Create(); when it is
// unset, CreateNet() returns a SimpleNet instead.
optional string type = 3;
// Marked deprecated in the schema.
optional int32 num_workers = 4 [deprecated=true];
optional DeviceOption device_option = 5;
repeated Argument arg = 6;
repeated string external_input = 7;
repeated string external_output = 8;
}class SimpleNet : public NetBase {
// SimpleNet: a NetBase that owns a list of operators and runs them one after
// another in the order they were created.
public:
// Creates one operator per OperatorDef in net_def.op().
SimpleNet(const NetDef& net_def, Workspace* ws);
// Runs every operator in order; returns false as soon as one fails.
bool Run() override;
// Same as Run(), but calls RunAsync() on each operator.
bool RunAsync() override;
vector<OperatorBase*> getOperators();
// NOTE(review): the '}' on the next line closes the class before
// 'protected:' and is followed by a second closing '};' further down. It looks
// like a stray brace in this excerpt -- confirm against the real header.
}
protected:
// Operators owned by this net, in execution order.
vector<unique_ptr<OperatorBase> > operators_;
};bool SimpleNet::Run() {
// Runs each operator in sequence. Stops at the first operator whose Run()
// returns false and reports failure; returns true when all of them succeed.
for (auto& op : operators_) {
if (!op->Run()) {
return false;
}
}
return true;
}bool SimpleNet::RunAsync() {
// Logs the net name at VLOG level 1, then calls RunAsync() on each operator in
// sequence. Stops at the first operator that returns false and reports
// failure; returns true when all of them succeed.
VLOG(1) << "Running net " << name_;
for (auto& op : operators_) {
if (!op->RunAsync()) {
return false;
}
}
return true;
}class DAGNetBase : public NetBase {
// DAGNetBase: a NetBase whose work is carried out by worker threads that pull
// job indices from a queue (see WorkerFunction). Subclasses supply RunAt() to
// execute one chain of operators.
public:
// Maps an int key to a list of ints. WorkerFunction() looks up the popped job
// index here and hands the resulting list to RunAt().
using ExecutionChains = std::unordered_map<int, std::vector<int>>;
DAGNetBase(const NetDef& net_def, Workspace* ws);
~DAGNetBase();
// Seeds the job queue from initial_frontier_, waits for the work to finish,
// and returns success_.
bool Run() override;
// WorkerFunction() is a function wrapper to allow us to run worker threads.
// It checks out one ready-to-run operator from the job queue, runs it,
// notifies all its children, and for any children that is ready, enqueues
// it to the job queue.
void WorkerFunction();
protected:
// Executes the given chain; returns whether it succeeded.
virtual bool RunAt(const std::vector<int>& chain) = 0;
// Per-operator graph nodes; WorkerFunction() indexes this by operator index.
vector<internal::OperatorNode> operator_nodes_;
ExecutionChains execution_chains_;
// Job indices pushed onto the queue at the start of Run().
vector<int> initial_frontier_;
// Queue of job indices consumed by WorkerFunction().
SimpleQueue<int> job_queue_;
// Threads running WorkerFunction().
std::vector<std::thread> workers_;
int num_workers_;
// Run() keeps waiting on cv_ while this is greater than zero.
int remaining_ops_;
// Value returned by Run().
bool success_;
// Presumably guards remaining_ops_ and is the mutex paired with cv_ -- the
// lock object itself is not shown in these excerpts; confirm.
std::mutex remaining_ops_mutex_;
// Run() waits on this; WorkerFunction() calls notify_one() on it.
std::condition_variable cv_;
// Presumably serializes concurrent Run() calls; not used in the excerpts
// shown here -- confirm.
std::mutex run_in_progress_;
};struct OperatorNode {
// OperatorNode: one operator together with its graph links.
unique_ptr<OperatorBase> operator_;
// Presumably indices into operator_nodes_ of dependent / prerequisite
// operators -- confirm.
vector<int> children_;
vector<int> parents_;
// Presumably the number of parents that have not finished yet in the current
// run -- confirm.
std::atomic<int> runtime_parent_count_;
// WorkerFunction() only enqueues nodes that have this flag set.
bool is_chain_start_ = false;
};class OperatorBase {
// OperatorBase: base class of all operators. Holds the operator definition
// plus its input and output blobs.
public:
explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
virtual ~OperatorBase() noexcept {}
// NOTE(review): declared to return bool but the body is empty, so nothing is
// returned. The excerpt looks abridged -- confirm against the real header.
virtual bool Run(int /* unused */ stream_id = 0) {}
private:
OperatorDef operator_def_;
vector<const Blob*> inputs_;
vector<Blob*> outputs_;
};bool DAGNetBase::Run() {
// First of two DAGNetBase::Run() definitions shown in this page; presumably
// two revisions of the same function shown side by side -- confirm.
// NOTE(review): as written, the wait loop (step 2) comes before the worker
// threads are started (step 3), and `mutex_lock` is not declared anywhere in
// this excerpt. Statements appear to have been dropped or reordered.
// 1. Initialize the job queue.
for (auto& value : initial_frontier_) {
job_queue_.Push(value);
}
// 2. Wait until all tasks have finished.
while (remaining_ops_ > 0) {
VLOG(2) << "Remaining ops to run: " << remaining_ops_;
cv_.wait(mutex_lock);
}
// 3. Start the worker threads.
// Only tops up the pool: starts at the current workers_.size() and adds
// threads until num_workers_ is reached.
for (auto i = workers_.size(); i < num_workers_; ++i) {
VLOG(1) << "Start worker #" << i;
workers_.push_back(std::thread(&DAGNetBase::WorkerFunction, this));
}
return success_;
}bool DAGNetBase::Run() {
// Second DAGNetBase::Run() definition shown in this page.
// NOTE(review): here job_queue_ is assigned from make_unique and used through
// '->', while the class excerpt declares it by value as SimpleQueue<int>.
// `num_workers_to_start` and `mutex_lock` are not declared in this excerpt.
// 1. Initialize the job queue.
job_queue_ = caffe2::make_unique<SimpleQueue<int>>();
// 2. Start the worker threads.
for (auto i = 0; i < num_workers_to_start; i++) {
workers_.push_back(std::thread(&DAGNetBase::WorkerFunction, this));
}
// 3. Start pushing tasks into the job queue.
for (auto& value : initial_frontier_) {
job_queue_->Push(value);
}
// 4. Wait for the tasks to complete.
// Leaves the loop once no operators remain or a failure has been recorded.
for (;;) {
if (remaining_ops_ == 0 || !success_) {
break;
}
cv_.wait(mutex_lock);
}
return success_;
}void DAGNetBase::WorkerFunction() {
// Worker thread body: repeatedly pops a job index, runs its chain, and
// enqueues follow-up work. Returns when Pop() reports failure.
// NOTE(review): this excerpt is abridged. `child` is never declared, the loop
// variable `idx` shadows the outer `idx`, `this_success` is never used, and
// there is one more '}' than '{'. A loop over the node's children appears to
// have been dropped -- confirm against the real source.
while (true) {
// 1. Take a task off the queue.
int idx = 0;
if (!job_queue_.Pop(&idx)) {
return;
}
const auto& chain = execution_chains_[idx];
// 2. Run the task.
bool this_success = RunAt(execution_chains_[idx]);
for (const auto idx : chain) {
if (operator_nodes_[child].is_chain_start_) {
// 3. Enqueue the new task.
job_queue_.Push(child);
}
}
}
cv_.notify_one();
}
}SimpleNet::SimpleNet(const NetDef& net_def, Workspace* ws)
: NetBase(net_def, ws) {
// Builds the operator list: one CreateOperator() call per OperatorDef in
// net_def.op(), appended in definition order.
for (const OperatorDef& operator_def : net_def.op()) {
operators_.emplace_back(CreateOperator(operator_def, ws));
}
}unique_ptr<OperatorBase> CreateOperator(
const OperatorDef& operator_def, Workspace* ws) {
// Creates the operator described by operator_def. Engine-specific
// registrations are tried first, in the order listed in operator_def.engine();
// the plain operator type is the fallback. Returns whatever
// TryCreateOperator() yields for the fallback key.
// 1. Check the operator definition against its OpSchema.
// NOTE(review): `schema` is dereferenced without a null check -- confirm that
// OpSchemaRegistry::Schema() cannot return null for an unknown type.
auto* schema = OpSchemaRegistry::Schema(operator_def.type());
schema->Verify(operator_def);
// 2. If the user specified an engine: for a GPU operator the engine is CUDA;
// other values such as `EIGEN` are also possible,
// e.g. REGISTER_CPU_OPERATOR_WITH_ENGINE(Conv, EIGEN, EigenConvOp<float>);
if (operator_def.engine().size()) {
// The engine field is a comma-separated preference list.
vector<string> engine_choices = split(',', operator_def.engine());
for (const string& engine : engine_choices) {
// Registry key format: <type>_ENGINE_<engine>.
string key = operator_def.type() + "_ENGINE_" + engine;
auto op = TryCreateOperator(key, operator_def, ws);
if (op) {
return op;
}
}
}
// 3. Otherwise create the operator with the default engine, selected via
// device_type.
auto op = TryCreateOperator(operator_def.type(), operator_def, ws);
return op;
}