forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
nnvm graph
乔龙飞 edited this page Jun 15, 2017
·
8 revisions
/*! \brief Forward declaration: Graph refers to IndexedGraph before its definition. */
class IndexedGraph;

/*!
 * \brief Computation graph, described by its output entries and a
 *  string-keyed attribute map.
 */
class Graph {
 public:
  /*! \brief outputs of the computation graph. */
  std::vector<NodeEntry> outputs;
  /*! \brief graph attributes keyed by name; each value is a shared `any`. */
  std::unordered_map<std::string, std::shared_ptr<any> > attrs;

 private:
  /*! \brief internal structure of indexed graph */
  std::shared_ptr<const IndexedGraph> indexed_graph_;
};

/*!
 * \brief Index-addressed representation of a Graph.
 *
 *  Nodes are addressed by an integer node id and every (node, output index)
 *  pair is addressed by an integer entry id. Instances can only be
 *  constructed by Graph (private constructor + friend) and cannot be copied.
 */
class IndexedGraph {
 public:
  /*! \brief represents a data in the graph */
  struct NodeEntry {
    /*! \brief the source node id in the computation graph */
    uint32_t node_id;
    /*! \brief index of output from the source. */
    uint32_t index;
    /*! \brief version of the node */
    uint32_t version;
  };

  /*! \brief Node data structure in IndexedGraph */
  struct Node {
    /*! \brief pointer to the source node */
    const nnvm::Node* source;
    /*! \brief inputs to the node */
    array_view<NodeEntry> inputs;
    /*! \brief control flow dependencies to the node */
    array_view<uint32_t> control_deps;
  };

  /*! \return number of nodes in the graph */
  inline size_t num_nodes() const {
    return nodes_.size();
  }

  /*!
   * \brief Total number of NodeEntry in the graph.
   * \note Reads the last element of the entry row-pointer array, so that
   *  array must be non-empty (calling back() on an empty vector is undefined).
   *  Presumably the constructor always appends a terminating element --
   *  confirm against the constructor definition.
   * \return total number of NodeEntry in the graph
   */
  inline size_t num_node_entries() const {
    return entry_rptr_.back();
  }

  /*!
   * \brief Get a unique entry id between 0 to num_node_entries()
   *  for a given IndexedGraph::NodeEntry
   * \param node_id The node index
   * \param index the output index
   * \return the unique index: the node's first entry id plus the output index.
   */
  inline uint32_t entry_id(uint32_t node_id, uint32_t index) const {
    // entry_rptr_ stores size_t; the narrowing to uint32_t is made explicit.
    return static_cast<uint32_t>(entry_rptr_[node_id] + index);
  }

  /*!
   * \brief Get a unique entry id between 0 to num_node_entries()
   *  for a given IndexedGraph::NodeEntry
   * \param e The entry to query for index.
   * \return the unique index.
   */
  inline uint32_t entry_id(const NodeEntry& e) const {
    return entry_id(e.node_id, e.index);
  }

  /*!
   * \brief Get a unique entry id between 0 to num_node_entries()
   *  for a given NodeEntry.
   * \param e The entry to query for index.
   * \return the unique index.
   * \note The node lookup goes through node_id(), which throws
   *  std::out_of_range when the node is not part of this graph.
   */
  inline uint32_t entry_id(const nnvm::NodeEntry& e) const {
    return static_cast<uint32_t>(entry_rptr_[node_id(e.node.get())] + e.index);
  }

  /*!
   * \brief Get the corresponding node id for a given Node in the IndexedGraph.
   * \param node The Node to query for index.
   * \return the node index.
   * \note Uses unordered_map::at, so an unknown node raises std::out_of_range;
   *  use exist() first when membership is not guaranteed.
   */
  inline uint32_t node_id(const nnvm::Node* node) const {
    return node2index_.at(node);
  }

  /*!
   * \brief Get the corresponding Node structure for a given node_id.
   * \param node_id The node id (not bounds-checked)
   * \return const reference to the corresponding IndexedGraph::Node
   */
  inline const Node& operator[](uint32_t node_id) const {
    return nodes_[node_id];
  }

  /*!
   * \brief Get the corresponding Node structure
   * \param node The pointer to the Node structure
   * \return const reference to the corresponding IndexedGraph::Node
   */
  inline const Node& operator[](const nnvm::Node* node) const {
    return nodes_[node_id(node)];
  }

  /*! \return list of argument nodes */
  inline const std::vector<uint32_t>& input_nodes() const {
    return input_nodes_;
  }

  /*! \return list of mutable nodes */
  inline const std::unordered_set<uint32_t>& mutable_input_nodes() const {
    return mutable_input_nodes_;
  }

  /*! \return list of output entries */
  inline const std::vector<NodeEntry>& outputs() const {
    return outputs_;
  }

  /*! \return whether a node exists in the indexed graph */
  inline bool exist(const nnvm::Node* node) const {
    return node2index_.count(node) != 0;
  }

  // Disallow copy construction and copy assignment.
  // Deleting only the copy constructor would leave the implicitly declared
  // copy assignment operator usable.
  // NOTE(review): presumably the array_views held by nodes_ refer to storage
  // owned by this object, so a memberwise copy would alias the source --
  // confirm against the constructor definition.
  IndexedGraph(const IndexedGraph&) = delete;
  IndexedGraph& operator=(const IndexedGraph&) = delete;

 private:
  friend class Graph;
  /*!
   * \brief Construct an IndexedGraph from normal Graph
   * \param other The source graph.
   */
  explicit IndexedGraph(const Graph& other);
  // Node pointers in CSR structure.
  std::vector<Node> nodes_;
  // Index to all input nodes.
  std::vector<uint32_t> input_nodes_;
  // Index to all mutable input nodes.
  std::unordered_set<uint32_t> mutable_input_nodes_;
  // space to store the outputs entries
  std::vector<NodeEntry> outputs_;
  // mapping from node to index.
  std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
  // CSR pointer of node entries: entry_rptr_[i] is the entry id of
  // output 0 of node i (see entry_id), and back() is the total entry count.
  std::vector<size_t> entry_rptr_;
  // space to store input entries of each node
  std::vector<NodeEntry> input_entries_;
  // control flow dependencies
  std::vector<uint32_t> control_deps_;
};