@@ -20,48 +20,50 @@ limitations under the License.
2020#include < vector>
2121#include < stdexcept>
2222#include " osp/concepts/computational_dag_concept.hpp"
23+ #include " osp/graph_algorithms/directed_graph_edge_desc_util.hpp"
2324
2425namespace osp {
2526
27+ template <typename index_type = size_t , typename workw_type = int , typename memw_type = int , typename commw_type = int >
2628class Hypergraph {
2729
2830 public:
2931
3032 Hypergraph () = default ;
3133
32- Hypergraph (unsigned num_vertices_, unsigned num_hyperedges_)
34+ Hypergraph (index_type num_vertices_, index_type num_hyperedges_)
3335 : Num_vertices(num_vertices_), Num_hyperedges(num_hyperedges_), vertex_work_weights(num_vertices_, 1 ),
3436 vertex_memory_weights (num_vertices_, 1 ), hyperedge_weights(num_hyperedges_, 1 ),
3537 incident_hyperedges_to_vertex(num_vertices_), vertices_in_hyperedge(num_hyperedges_){}
3638
37- Hypergraph (const Hypergraph &other) = default;
38- Hypergraph &operator =(const Hypergraph &other) = default ;
39+ Hypergraph (const Hypergraph<index_type, workw_type, memw_type, commw_type> &other) = default;
40+ Hypergraph &operator =(const Hypergraph<index_type, workw_type, memw_type, commw_type> &other) = default ;
3941
4042 virtual ~Hypergraph () = default ;
4143
42- inline unsigned num_vertices () const { return Num_vertices; }
43- inline unsigned num_hyperedges () const { return Num_hyperedges; }
44- inline unsigned num_pins () const { return Num_pins; }
45- inline int get_vertex_work_weight (unsigned node) const { return vertex_work_weights[node]; }
46- inline int get_vertex_memory_weight (unsigned node) const { return vertex_memory_weights[node]; }
47- inline int get_hyperedge_weight (unsigned hyperedge) const { return hyperedge_weights[hyperedge]; }
44+ inline index_type num_vertices () const { return Num_vertices; }
45+ inline index_type num_hyperedges () const { return Num_hyperedges; }
46+ inline index_type num_pins () const { return Num_pins; }
47+ inline workw_type get_vertex_work_weight (index_type node) const { return vertex_work_weights[node]; }
48+ inline memw_type get_vertex_memory_weight (index_type node) const { return vertex_memory_weights[node]; }
49+ inline commw_type get_hyperedge_weight (index_type hyperedge) const { return hyperedge_weights[hyperedge]; }
4850
49- void add_pin (unsigned vertex_idx, unsigned hyperedge_idx);
50- void add_vertex (int work_weight = 1 , int memory_weight = 1 );
51- void add_empty_hyperedge (int weight = 1 );
52- void add_hyperedge (const std::vector<unsigned >& pins, int weight = 1 );
53- void set_vertex_work_weight (unsigned vertex_idx, int weight);
54- void set_vertex_memory_weight (unsigned vertex_idx, int weight);
55- void set_hyperedge_weight (unsigned hyperedge_idx, int weight);
51+ void add_pin (index_type vertex_idx, index_type hyperedge_idx);
52+ void add_vertex (workw_type work_weight = 1 , memw_type memory_weight = 1 );
53+ void add_empty_hyperedge (commw_type weight = 1 );
54+ void add_hyperedge (const std::vector<index_type >& pins, commw_type weight = 1 );
55+ void set_vertex_work_weight (index_type vertex_idx, workw_type weight);
56+ void set_vertex_memory_weight (index_type vertex_idx, memw_type weight);
57+ void set_hyperedge_weight (index_type hyperedge_idx, commw_type weight);
5658
57- int compute_total_vertex_work_weight () const ;
58- int compute_total_vertex_memory_weight () const ;
59+ workw_type compute_total_vertex_work_weight () const ;
60+ memw_type compute_total_vertex_memory_weight () const ;
5961
6062 void clear ();
61- void reset (unsigned num_vertices_, unsigned num_hyperedges_);
63+ void reset (index_type num_vertices_, index_type num_hyperedges_);
6264
63- inline const std::vector<unsigned > &get_incident_hyperedges (unsigned vertex) const { return incident_hyperedges_to_vertex[vertex]; }
64- inline const std::vector<unsigned > &get_vertices_in_hyperedge (unsigned hyperedge) const { return vertices_in_hyperedge[hyperedge]; }
65+ inline const std::vector<index_type > &get_incident_hyperedges (index_type vertex) const { return incident_hyperedges_to_vertex[vertex]; }
66+ inline const std::vector<index_type > &get_vertices_in_hyperedge (index_type hyperedge) const { return vertices_in_hyperedge[hyperedge]; }
6567
6668 template <typename Graph_t>
6769 void convert_from_cdag_as_dag (const Graph_t& dag);
@@ -70,17 +72,18 @@ class Hypergraph {
7072 void convert_from_cdag_as_hyperdag (const Graph_t& dag);
7173
7274 private:
73- unsigned Num_vertices = 0 , Num_hyperedges = 0 , Num_pins = 0 ;
75+ index_type Num_vertices = 0 , Num_hyperedges = 0 , Num_pins = 0 ;
7476
75- std::vector<int > vertex_work_weights;
76- std::vector<int > vertex_memory_weights;
77- std::vector<int > hyperedge_weights;
77+ std::vector<workw_type > vertex_work_weights;
78+ std::vector<memw_type > vertex_memory_weights;
79+ std::vector<commw_type > hyperedge_weights;
7880
79- std::vector<std::vector<unsigned >> incident_hyperedges_to_vertex;
80- std::vector<std::vector<unsigned >> vertices_in_hyperedge;
81+ std::vector<std::vector<index_type >> incident_hyperedges_to_vertex;
82+ std::vector<std::vector<index_type >> vertices_in_hyperedge;
8183};
8284
83- void Hypergraph::add_pin (unsigned vertex_idx, unsigned hyperedge_idx)
85+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
86+ void Hypergraph<index_type, workw_type, memw_type, commw_type>::add_pin(index_type vertex_idx, index_type hyperedge_idx)
8487{
8588 if (vertex_idx >= Num_vertices)
8689 {
@@ -97,72 +100,81 @@ void Hypergraph::add_pin(unsigned vertex_idx, unsigned hyperedge_idx)
97100 }
98101}
99102
100- void Hypergraph::add_vertex (int work_weight, int memory_weight)
103+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
104+ void Hypergraph<index_type, workw_type, memw_type, commw_type>::add_vertex(workw_type work_weight, memw_type memory_weight)
101105{
102106 vertex_work_weights.push_back (work_weight);
103107 vertex_memory_weights.push_back (memory_weight);
104108 incident_hyperedges_to_vertex.emplace_back ();
105109 ++Num_vertices;
106110}
107111
108- void Hypergraph::add_empty_hyperedge (int weight)
112+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
113+ void Hypergraph<index_type, workw_type, memw_type, commw_type>::add_empty_hyperedge(commw_type weight)
109114{
110115 vertices_in_hyperedge.emplace_back ();
111116 hyperedge_weights.push_back (weight);
112117 ++Num_hyperedges;
113118}
114119
115- void Hypergraph::add_hyperedge (const std::vector<unsigned >& pins, int weight)
120+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
121+ void Hypergraph<index_type, workw_type, memw_type, commw_type>::add_hyperedge(const std::vector<index_type>& pins, commw_type weight)
116122{
117123 vertices_in_hyperedge.emplace_back (pins);
118124 hyperedge_weights.push_back (weight);
119- for (unsigned vertex : pins)
125+ for (index_type vertex : pins)
120126 incident_hyperedges_to_vertex[vertex].push_back (Num_hyperedges);
121127 ++Num_hyperedges;
122- Num_pins += static_cast <unsigned >(pins.size ());
128+ Num_pins += static_cast <index_type >(pins.size ());
123129}
124130
125- void Hypergraph::set_vertex_work_weight (unsigned vertex_idx, int weight)
131+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
132+ void Hypergraph<index_type, workw_type, memw_type, commw_type>::set_vertex_work_weight(index_type vertex_idx, workw_type weight)
126133{
127134 if (vertex_idx >= Num_vertices)
128135 throw std::invalid_argument (" Invalid Argument while setting vertex weight: vertex index out of range." );
129136 else
130137 vertex_work_weights[vertex_idx] = weight;
131138}
132139
133- void Hypergraph::set_vertex_memory_weight (unsigned vertex_idx, int weight)
140+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
141+ void Hypergraph<index_type, workw_type, memw_type, commw_type>::set_vertex_memory_weight(index_type vertex_idx, memw_type weight)
134142{
135143 if (vertex_idx >= Num_vertices)
136144 throw std::invalid_argument (" Invalid Argument while setting vertex weight: vertex index out of range." );
137145 else
138146 vertex_memory_weights[vertex_idx] = weight;
139147}
140148
141- void Hypergraph::set_hyperedge_weight (unsigned hyperedge_idx, int weight)
149+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
150+ void Hypergraph<index_type, workw_type, memw_type, commw_type>::set_hyperedge_weight(index_type hyperedge_idx, commw_type weight)
142151{
143152 if (hyperedge_idx >= Num_hyperedges)
144153 throw std::invalid_argument (" Invalid Argument while setting hyperedge weight: hyepredge index out of range." );
145154 else
146155 hyperedge_weights[hyperedge_idx] = weight;
147156}
148157
149- int Hypergraph::compute_total_vertex_work_weight () const
158+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
159+ workw_type Hypergraph<index_type, workw_type, memw_type, commw_type>::compute_total_vertex_work_weight() const
150160{
151- int total = 0 ;
152- for (unsigned node = 0 ; node < Num_vertices; ++node)
161+ workw_type total = 0 ;
162+ for (index_type node = 0 ; node < Num_vertices; ++node)
153163 total += vertex_work_weights[node];
154164 return total;
155165}
156166
157- int Hypergraph::compute_total_vertex_memory_weight () const
167+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
168+ memw_type Hypergraph<index_type, workw_type, memw_type, commw_type>::compute_total_vertex_memory_weight() const
158169{
159- int total = 0 ;
160- for (unsigned node = 0 ; node < Num_vertices; ++node)
170+ memw_type total = 0 ;
171+ for (index_type node = 0 ; node < Num_vertices; ++node)
161172 total += vertex_memory_weights[node];
162173 return total;
163174}
164175
165- void Hypergraph::clear ()
176+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
177+ void Hypergraph<index_type, workw_type, memw_type, commw_type>::clear()
166178{
167179 Num_vertices = 0 ;
168180 Num_hyperedges = 0 ;
@@ -175,7 +187,8 @@ void Hypergraph::clear()
175187 vertices_in_hyperedge.clear ();
176188}
177189
178- void Hypergraph::reset (unsigned num_vertices_, unsigned num_hyperedges_)
190+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
191+ void Hypergraph<index_type, workw_type, memw_type, commw_type>::reset(index_type num_vertices_, index_type num_hyperedges_)
179192{
180193 clear ();
181194
@@ -189,33 +202,48 @@ void Hypergraph::reset(unsigned num_vertices_, unsigned num_hyperedges_)
189202 vertices_in_hyperedge.resize (num_hyperedges_);
190203}
191204
205+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
192206template <typename Graph_t>
193- void Hypergraph::convert_from_cdag_as_dag (const Graph_t& dag)
207+ void Hypergraph<index_type, workw_type, memw_type, commw_type> ::convert_from_cdag_as_dag(const Graph_t& dag)
194208{
195- reset (static_cast <unsigned >(dag.num_vertices ()), 0 );
209+ static_assert (std::is_same_v<vertex_idx_t <Graph_t>, index_type>, " Index type mismatch, cannot convert DAG to hypergraph." );
210+ static_assert (std::is_same_v<v_workw_t <Graph_t>, workw_type>, " Work weight type mismatch, cannot convert DAG to hypergraph." );
211+ static_assert (std::is_same_v<v_memw_t <Graph_t>, memw_type>, " Memory weight type mismatch, cannot convert DAG to hypergraph." );
212+ static_assert (!has_edge_weights_v<Graph_t> || std::is_same_v<e_commw_t <Graph_t>, commw_type>, " Communication weight type mismatch, cannot convert DAG to hypergraph." );
213+
214+ reset (dag.num_vertices (), 0 );
196215 for (const auto &node : dag.vertices ())
197216 {
198- set_vertex_work_weight (static_cast < unsigned >( node), static_cast < int >( dag.vertex_work_weight (node) ));
199- set_vertex_memory_weight (static_cast < unsigned >( node), static_cast < int >( dag.vertex_mem_weight (node) ));
217+ set_vertex_work_weight (node, dag.vertex_work_weight (node));
218+ set_vertex_memory_weight (node, dag.vertex_mem_weight (node));
200219 for (const auto &child : dag.children (node))
201- add_hyperedge ({static_cast <unsigned >(node), static_cast <unsigned >(child)}); // TODO add edge weights if present
220+ if constexpr (has_edge_weights_v<Graph_t>)
221+ add_hyperedge ({node, child}, dag.edge_comm_weight (edge_desc (node, child, dag).first ));
222+ else
223+ add_hyperedge ({node, child});
202224 }
203225}
204226
227+ template <typename index_type, typename workw_type, typename memw_type, typename commw_type>
205228template <typename Graph_t>
206- void Hypergraph::convert_from_cdag_as_hyperdag (const Graph_t& dag)
229+ void Hypergraph<index_type, workw_type, memw_type, commw_type> ::convert_from_cdag_as_hyperdag(const Graph_t& dag)
207230{
208- reset (static_cast <unsigned >(dag.num_vertices ()), 0 );
231+ static_assert (std::is_same_v<vertex_idx_t <Graph_t>, index_type>, " Index type mismatch, cannot convert DAG to hypergraph." );
232+ static_assert (std::is_same_v<v_workw_t <Graph_t>, workw_type>, " Work weight type mismatch, cannot convert DAG to hypergraph." );
233+ static_assert (std::is_same_v<v_memw_t <Graph_t>, memw_type>, " Memory weight type mismatch, cannot convert DAG to hypergraph." );
234+ static_assert (std::is_same_v<v_commw_t <Graph_t>, commw_type>, " Communication weight type mismatch, cannot convert DAG to hypergraph." );
235+
236+ reset (dag.num_vertices (), 0 );
209237 for (const auto &node : dag.vertices ())
210238 {
211- set_vertex_work_weight (static_cast < unsigned >( node), static_cast < int >( dag.vertex_work_weight (node) ));
212- set_vertex_memory_weight (static_cast < unsigned >( node), static_cast < int >( dag.vertex_mem_weight (node) ));
239+ set_vertex_work_weight (node, dag.vertex_work_weight (node));
240+ set_vertex_memory_weight (node, dag.vertex_mem_weight (node));
213241 if (dag.out_degree (node) == 0 )
214242 continue ;
215- std::vector<unsigned > new_hyperedge ({static_cast < unsigned >( node) });
243+ std::vector<index_type > new_hyperedge ({node});
216244 for (const auto &child : dag.children (node))
217- new_hyperedge.push_back (static_cast < unsigned >( child) );
218- add_hyperedge (new_hyperedge, static_cast < int >( dag.vertex_comm_weight (node) ));
245+ new_hyperedge.push_back (child);
246+ add_hyperedge (new_hyperedge, dag.vertex_comm_weight (node));
219247 }
220248}
221249
0 commit comments