Caffe2 (wiki page)
Forked from PaddlePaddle/Paddle · last edited by 乔龙飞 on Jun 15, 2017 · 6 revisions
// Builds the network described by `net_def` inside workspace `ws`.
//
// If `net_def` names a net type, construction is delegated to whatever is
// registered under that name in NetRegistry(). Without an explicit type, a
// SimpleNet is returned, which runs the operators one after another.
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws) {
  if (net_def.has_type()) {
    return NetRegistry()->Create(net_def.type(), net_def, ws);
  }
  // Default: plain sequential execution of all operators.
  return make_unique<SimpleNet>(net_def, ws);
}
// Network definition.
message NetDef {
  optional string name = 1; // the network's name
  repeated OperatorDef op = 2;
  // If unset, CreateNet() builds a SimpleNet; otherwise this string is the
  // key passed to NetRegistry()->Create().
  optional string type = 3;
  optional int32 num_workers = 4 [deprecated=true];
  optional DeviceOption device_option = 5;
  repeated Argument arg = 6;
  repeated string external_input = 7;
  repeated string external_output = 8;
}

// Network that executes the operators stored in operators_ one after
// another, in vector order.
class SimpleNet : public NetBase {
 public:
  SimpleNet(const NetDef& net_def, Workspace* ws);
  // Calls Run() on each operator in order. Returns false as soon as one
  // operator fails (the remaining operators are not run); true otherwise.
  bool Run() override;
  // Same contract as Run(), but calls RunAsync() on each operator and logs
  // the net name at VLOG level 1 first.
  bool RunAsync() override;
  vector<OperatorBase*> getOperators();

 protected:
  vector<unique_ptr<OperatorBase>> operators_;
};

// Runs every operator synchronously; stops at and reports the first failure.
bool SimpleNet::Run() {
  for (auto& op : operators_) {
    if (!op->Run()) {
      return false;
    }
  }
  return true;
}

// Launches every operator via RunAsync(); stops at and reports the first
// operator whose RunAsync() returns false.
bool SimpleNet::RunAsync() {
  VLOG(1) << "Running net " << name_;
  for (auto& op : operators_) {
    if (!op->RunAsync()) {
      return false;
    }
  }
  return true;
}

// Base for nets that schedule operators through a job queue drained by
// worker threads (see WorkerFunction()). Subclasses define how a single
// chain of operators is executed by implementing RunAt().
class DAGNetBase : public NetBase {
 public:
  using ExecutionChains = std::unordered_map<int, std::vector<int>>;
  DAGNetBase(const NetDef& net_def, Workspace* ws);
  ~DAGNetBase();
  bool Run() override;
  // WorkerFunction() is a function wrapper to allow us to run worker threads.
  // It checks out one ready-to-run operator from the job queue, runs it,
  // notifies all its children, and enqueues any child that has become ready
  // to the job queue.
  void WorkerFunction();

 protected:
  // Executes one chain. NOTE(review): the ints are presumably indices into
  // operator_nodes_ — confirm against the implementation.
  virtual bool RunAt(const std::vector<int>& chain) = 0;
  vector<internal::OperatorNode> operator_nodes_;
  ExecutionChains execution_chains_;
  vector<int> initial_frontier_;
  SimpleQueue<int> job_queue_;
  std::vector<std::thread> workers_;
  // Given defaults so these scalars never hold indeterminate values; any
  // constructor initializer list still takes precedence.
  int num_workers_ = 0;
  int remaining_ops_ = 0;
  bool success_ = false;
  std::mutex remaining_ops_mutex_;
  std::condition_variable cv_;
  std::mutex run_in_progress_;
};