Skip to content

Commit d901d26

Browse files
templating the partitioning part
1 parent da686e2 commit d901d26

12 files changed

+335
-281
lines changed

include/osp/auxiliary/io/mtx_hypergraph_file_reader.hpp

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ namespace osp {
3333
namespace file_reader {
3434

3535
// reads a matrix into Hypergraph format, where nonzeros are vertices, and rows/columns are hyperedges
36-
bool readHypergraphMartixMarketFormat(std::ifstream& infile, Hypergraph& hgraph) {
36+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
37+
bool readHypergraphMartixMarketFormat(std::ifstream& infile, Hypergraph<index_type, workw_type, memw_type, commw_type>& hgraph) {
3738

3839
std::string line;
3940

@@ -68,23 +69,16 @@ bool readHypergraphMartixMarketFormat(std::ifstream& infile, Hypergraph& hgraph)
6869
return false;
6970
}
7071

71-
const unsigned num_nodes = static_cast<unsigned>(nEntries);
72-
if (num_nodes > std::numeric_limits<unsigned>::max()) {
73-
std::cerr << "Error: Matrix dimension too large for vertex type.\n";
74-
return false;
75-
}
76-
77-
std::vector<int> node_work_wts(num_nodes, 0);
78-
std::vector<int> node_comm_wts(num_nodes, 1);
72+
const index_type num_nodes = static_cast<index_type>(nEntries);
7973

8074
hgraph.reset(num_nodes, 0);
81-
for (unsigned node = 0; node < num_nodes; ++node) {
82-
hgraph.set_vertex_work_weight(node, 1);
83-
hgraph.set_vertex_memory_weight(node, 1);
75+
for (index_type node = 0; node < num_nodes; ++node) {
76+
hgraph.set_vertex_work_weight(node, static_cast<workw_type>(1));
77+
hgraph.set_vertex_memory_weight(node, static_cast<memw_type>(1));
8478
}
8579

86-
std::vector<std::vector<unsigned>> row_hyperedges(static_cast<unsigned>(M_row));
87-
std::vector<std::vector<unsigned>> column_hyperedges(static_cast<unsigned>(M_col));
80+
std::vector<std::vector<index_type>> row_hyperedges(static_cast<index_type>(M_row));
81+
std::vector<std::vector<index_type>> column_hyperedges(static_cast<index_type>(M_col));
8882

8983
int entries_read = 0;
9084
while (entries_read < nEntries && std::getline(infile, line)) {
@@ -110,13 +104,13 @@ bool readHypergraphMartixMarketFormat(std::ifstream& infile, Hypergraph& hgraph)
110104
return false;
111105
}
112106

113-
if (static_cast<unsigned>(row) >= num_nodes || static_cast<unsigned>(col) >= num_nodes) {
107+
if (static_cast<index_type>(row) >= num_nodes || static_cast<index_type>(col) >= num_nodes) {
114108
std::cerr << "Error: Index exceeds vertex type limit.\n";
115109
return false;
116110
}
117111

118-
row_hyperedges[static_cast<unsigned>(row)].push_back(static_cast<unsigned>(entries_read));
119-
column_hyperedges[static_cast<unsigned>(col)].push_back(static_cast<unsigned>(entries_read));
112+
row_hyperedges[static_cast<index_type>(row)].push_back(static_cast<index_type>(entries_read));
113+
column_hyperedges[static_cast<index_type>(col)].push_back(static_cast<index_type>(entries_read));
120114

121115
++entries_read;
122116
}
@@ -133,18 +127,19 @@ bool readHypergraphMartixMarketFormat(std::ifstream& infile, Hypergraph& hgraph)
133127
}
134128
}
135129

136-
for(unsigned row = 0; row < static_cast<unsigned>(M_row); ++row)
130+
for(index_type row = 0; row < static_cast<index_type>(M_row); ++row)
137131
if(!row_hyperedges[row].empty())
138132
hgraph.add_hyperedge(row_hyperedges[row]);
139133

140-
for(unsigned col = 0; col < static_cast<unsigned>(M_col); ++col)
134+
for(index_type col = 0; col < static_cast<index_type>(M_col); ++col)
141135
if(!column_hyperedges[col].empty())
142136
hgraph.add_hyperedge(column_hyperedges[col]);
143137

144138
return true;
145139
}
146140

147-
bool readHypergraphMartixMarketFormat(const std::string& filename, Hypergraph& hgraph) {
141+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
142+
bool readHypergraphMartixMarketFormat(const std::string& filename, Hypergraph<index_type, workw_type, memw_type, commw_type>& hgraph) {
148143
// Ensure the file is .mtx format
149144
if (std::filesystem::path(filename).extension() != ".mtx") {
150145
std::cerr << "Error: Only .mtx files are accepted.\n";

include/osp/auxiliary/io/partitioning_file_writer.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,27 @@ limitations under the License.
2525

2626
namespace osp { namespace file_writer {
2727

28-
void write_txt(std::ostream &os, const Partitioning &partition) {
28+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
29+
void write_txt(std::ostream &os, const Partitioning<index_type, workw_type, memw_type, commw_type> &partition) {
2930

3031
os << "\%\% Partitioning for " << partition.getInstance().getNumberOfPartitions() << " parts." << std::endl;
3132

32-
for(unsigned node = 0; node < partition.getInstance().getHypergraph().num_vertices(); ++node)
33+
for(index_type node = 0; node < partition.getInstance().getHypergraph().num_vertices(); ++node)
3334
os << node << " " << partition.assignedPartition(node) << std::endl;
3435
}
3536

36-
void write_txt(const std::string &filename, const Partitioning &partition) {
37+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
38+
void write_txt(const std::string &filename, const Partitioning<index_type, workw_type, memw_type, commw_type> &partition) {
3739
std::ofstream os(filename);
3840
write_txt(os, partition);
3941
}
4042

41-
void write_txt(std::ostream &os, const PartitioningWithReplication &partition) {
43+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
44+
void write_txt(std::ostream &os, const PartitioningWithReplication<index_type, workw_type, memw_type, commw_type> &partition) {
4245

4346
os << "\%\% Partitioning for " << partition.getInstance().getNumberOfPartitions() << " parts with replication." << std::endl;
4447

45-
for(unsigned node = 0; node < partition.getInstance().getHypergraph().num_vertices(); ++node)
48+
for(index_type node = 0; node < partition.getInstance().getHypergraph().num_vertices(); ++node)
4649
{
4750
os << node;
4851
for(unsigned part : partition.assignedPartitions(node))
@@ -51,7 +54,8 @@ void write_txt(std::ostream &os, const PartitioningWithReplication &partition) {
5154
}
5255
}
5356

54-
void write_txt(const std::string &filename, const PartitioningWithReplication &partition) {
57+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
58+
void write_txt(const std::string &filename, const PartitioningWithReplication<index_type, workw_type, memw_type, commw_type> &partition) {
5559
std::ofstream os(filename);
5660
write_txt(os, partition);
5761
}

include/osp/partitioning/model/hypergraph.hpp

Lines changed: 84 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,48 +20,50 @@ limitations under the License.
2020
#include <vector>
2121
#include <stdexcept>
2222
#include "osp/concepts/computational_dag_concept.hpp"
23+
#include "osp/graph_algorithms/directed_graph_edge_desc_util.hpp"
2324

2425
namespace osp {
2526

27+
template<typename index_type = size_t, typename workw_type = int, typename memw_type = int, typename commw_type = int>
2628
class Hypergraph {
2729

2830
public:
2931

3032
Hypergraph() = default;
3133

32-
Hypergraph(unsigned num_vertices_, unsigned num_hyperedges_)
34+
Hypergraph(index_type num_vertices_, index_type num_hyperedges_)
3335
: Num_vertices(num_vertices_), Num_hyperedges(num_hyperedges_), vertex_work_weights(num_vertices_, 1),
3436
vertex_memory_weights(num_vertices_, 1), hyperedge_weights(num_hyperedges_, 1),
3537
incident_hyperedges_to_vertex(num_vertices_), vertices_in_hyperedge(num_hyperedges_){}
3638

37-
Hypergraph(const Hypergraph &other) = default;
38-
Hypergraph &operator=(const Hypergraph &other) = default;
39+
Hypergraph(const Hypergraph<index_type, workw_type, memw_type, commw_type> &other) = default;
40+
Hypergraph &operator=(const Hypergraph<index_type, workw_type, memw_type, commw_type> &other) = default;
3941

4042
virtual ~Hypergraph() = default;
4143

42-
inline unsigned num_vertices() const { return Num_vertices; }
43-
inline unsigned num_hyperedges() const { return Num_hyperedges; }
44-
inline unsigned num_pins() const { return Num_pins; }
45-
inline int get_vertex_work_weight(unsigned node) const { return vertex_work_weights[node]; }
46-
inline int get_vertex_memory_weight(unsigned node) const { return vertex_memory_weights[node]; }
47-
inline int get_hyperedge_weight(unsigned hyperedge) const { return hyperedge_weights[hyperedge]; }
44+
inline index_type num_vertices() const { return Num_vertices; }
45+
inline index_type num_hyperedges() const { return Num_hyperedges; }
46+
inline index_type num_pins() const { return Num_pins; }
47+
inline workw_type get_vertex_work_weight(index_type node) const { return vertex_work_weights[node]; }
48+
inline memw_type get_vertex_memory_weight(index_type node) const { return vertex_memory_weights[node]; }
49+
inline commw_type get_hyperedge_weight(index_type hyperedge) const { return hyperedge_weights[hyperedge]; }
4850

49-
void add_pin(unsigned vertex_idx, unsigned hyperedge_idx);
50-
void add_vertex(int work_weight = 1, int memory_weight = 1);
51-
void add_empty_hyperedge(int weight = 1);
52-
void add_hyperedge(const std::vector<unsigned>& pins, int weight = 1);
53-
void set_vertex_work_weight(unsigned vertex_idx, int weight);
54-
void set_vertex_memory_weight(unsigned vertex_idx, int weight);
55-
void set_hyperedge_weight(unsigned hyperedge_idx, int weight);
51+
void add_pin(index_type vertex_idx, index_type hyperedge_idx);
52+
void add_vertex(workw_type work_weight = 1, memw_type memory_weight = 1);
53+
void add_empty_hyperedge(commw_type weight = 1);
54+
void add_hyperedge(const std::vector<index_type>& pins, commw_type weight = 1);
55+
void set_vertex_work_weight(index_type vertex_idx, workw_type weight);
56+
void set_vertex_memory_weight(index_type vertex_idx, memw_type weight);
57+
void set_hyperedge_weight(index_type hyperedge_idx, commw_type weight);
5658

57-
int compute_total_vertex_work_weight() const;
58-
int compute_total_vertex_memory_weight() const;
59+
workw_type compute_total_vertex_work_weight() const;
60+
memw_type compute_total_vertex_memory_weight() const;
5961

6062
void clear();
61-
void reset(unsigned num_vertices_, unsigned num_hyperedges_);
63+
void reset(index_type num_vertices_, index_type num_hyperedges_);
6264

63-
inline const std::vector<unsigned> &get_incident_hyperedges(unsigned vertex) const { return incident_hyperedges_to_vertex[vertex]; }
64-
inline const std::vector<unsigned> &get_vertices_in_hyperedge(unsigned hyperedge) const { return vertices_in_hyperedge[hyperedge]; }
65+
inline const std::vector<index_type> &get_incident_hyperedges(index_type vertex) const { return incident_hyperedges_to_vertex[vertex]; }
66+
inline const std::vector<index_type> &get_vertices_in_hyperedge(index_type hyperedge) const { return vertices_in_hyperedge[hyperedge]; }
6567

6668
template<typename Graph_t>
6769
void convert_from_cdag_as_dag(const Graph_t& dag);
@@ -70,17 +72,18 @@ class Hypergraph {
7072
void convert_from_cdag_as_hyperdag(const Graph_t& dag);
7173

7274
private:
73-
unsigned Num_vertices = 0, Num_hyperedges = 0, Num_pins = 0;
75+
index_type Num_vertices = 0, Num_hyperedges = 0, Num_pins = 0;
7476

75-
std::vector<int> vertex_work_weights;
76-
std::vector<int> vertex_memory_weights;
77-
std::vector<int> hyperedge_weights;
77+
std::vector<workw_type> vertex_work_weights;
78+
std::vector<memw_type> vertex_memory_weights;
79+
std::vector<commw_type> hyperedge_weights;
7880

79-
std::vector<std::vector<unsigned>> incident_hyperedges_to_vertex;
80-
std::vector<std::vector<unsigned>> vertices_in_hyperedge;
81+
std::vector<std::vector<index_type>> incident_hyperedges_to_vertex;
82+
std::vector<std::vector<index_type>> vertices_in_hyperedge;
8183
};
8284

83-
void Hypergraph::add_pin(unsigned vertex_idx, unsigned hyperedge_idx)
85+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
86+
void Hypergraph<index_type, workw_type, memw_type, commw_type>::add_pin(index_type vertex_idx, index_type hyperedge_idx)
8487
{
8588
if(vertex_idx >= Num_vertices)
8689
{
@@ -97,72 +100,81 @@ void Hypergraph::add_pin(unsigned vertex_idx, unsigned hyperedge_idx)
97100
}
98101
}
99102

100-
void Hypergraph::add_vertex(int work_weight, int memory_weight)
103+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
104+
void Hypergraph<index_type, workw_type, memw_type, commw_type>::add_vertex(workw_type work_weight, memw_type memory_weight)
101105
{
102106
vertex_work_weights.push_back(work_weight);
103107
vertex_memory_weights.push_back(memory_weight);
104108
incident_hyperedges_to_vertex.emplace_back();
105109
++Num_vertices;
106110
}
107111

108-
void Hypergraph::add_empty_hyperedge(int weight)
112+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
113+
void Hypergraph<index_type, workw_type, memw_type, commw_type>::add_empty_hyperedge(commw_type weight)
109114
{
110115
vertices_in_hyperedge.emplace_back();
111116
hyperedge_weights.push_back(weight);
112117
++Num_hyperedges;
113118
}
114119

115-
void Hypergraph::add_hyperedge(const std::vector<unsigned>& pins, int weight)
120+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
121+
void Hypergraph<index_type, workw_type, memw_type, commw_type>::add_hyperedge(const std::vector<index_type>& pins, commw_type weight)
116122
{
117123
vertices_in_hyperedge.emplace_back(pins);
118124
hyperedge_weights.push_back(weight);
119-
for(unsigned vertex : pins)
125+
for(index_type vertex : pins)
120126
incident_hyperedges_to_vertex[vertex].push_back(Num_hyperedges);
121127
++Num_hyperedges;
122-
Num_pins += static_cast<unsigned>(pins.size());
128+
Num_pins += static_cast<index_type>(pins.size());
123129
}
124130

125-
void Hypergraph::set_vertex_work_weight(unsigned vertex_idx, int weight)
131+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
132+
void Hypergraph<index_type, workw_type, memw_type, commw_type>::set_vertex_work_weight(index_type vertex_idx, workw_type weight)
126133
{
127134
if(vertex_idx >= Num_vertices)
128135
throw std::invalid_argument("Invalid Argument while setting vertex weight: vertex index out of range.");
129136
else
130137
vertex_work_weights[vertex_idx] = weight;
131138
}
132139

133-
void Hypergraph::set_vertex_memory_weight(unsigned vertex_idx, int weight)
140+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
141+
void Hypergraph<index_type, workw_type, memw_type, commw_type>::set_vertex_memory_weight(index_type vertex_idx, memw_type weight)
134142
{
135143
if(vertex_idx >= Num_vertices)
136144
throw std::invalid_argument("Invalid Argument while setting vertex weight: vertex index out of range.");
137145
else
138146
vertex_memory_weights[vertex_idx] = weight;
139147
}
140148

141-
void Hypergraph::set_hyperedge_weight(unsigned hyperedge_idx, int weight)
149+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
150+
void Hypergraph<index_type, workw_type, memw_type, commw_type>::set_hyperedge_weight(index_type hyperedge_idx, commw_type weight)
142151
{
143152
if(hyperedge_idx >= Num_hyperedges)
144153
throw std::invalid_argument("Invalid Argument while setting hyperedge weight: hyepredge index out of range.");
145154
else
146155
hyperedge_weights[hyperedge_idx] = weight;
147156
}
148157

149-
int Hypergraph::compute_total_vertex_work_weight() const
158+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
159+
workw_type Hypergraph<index_type, workw_type, memw_type, commw_type>::compute_total_vertex_work_weight() const
150160
{
151-
int total = 0;
152-
for(unsigned node = 0; node < Num_vertices; ++node)
161+
workw_type total = 0;
162+
for(index_type node = 0; node < Num_vertices; ++node)
153163
total += vertex_work_weights[node];
154164
return total;
155165
}
156166

157-
int Hypergraph::compute_total_vertex_memory_weight() const
167+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
168+
memw_type Hypergraph<index_type, workw_type, memw_type, commw_type>::compute_total_vertex_memory_weight() const
158169
{
159-
int total = 0;
160-
for(unsigned node = 0; node < Num_vertices; ++node)
170+
memw_type total = 0;
171+
for(index_type node = 0; node < Num_vertices; ++node)
161172
total += vertex_memory_weights[node];
162173
return total;
163174
}
164175

165-
void Hypergraph::clear()
176+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
177+
void Hypergraph<index_type, workw_type, memw_type, commw_type>::clear()
166178
{
167179
Num_vertices = 0;
168180
Num_hyperedges = 0;
@@ -175,7 +187,8 @@ void Hypergraph::clear()
175187
vertices_in_hyperedge.clear();
176188
}
177189

178-
void Hypergraph::reset(unsigned num_vertices_, unsigned num_hyperedges_)
190+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
191+
void Hypergraph<index_type, workw_type, memw_type, commw_type>::reset(index_type num_vertices_, index_type num_hyperedges_)
179192
{
180193
clear();
181194

@@ -189,33 +202,48 @@ void Hypergraph::reset(unsigned num_vertices_, unsigned num_hyperedges_)
189202
vertices_in_hyperedge.resize(num_hyperedges_);
190203
}
191204

205+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
192206
template<typename Graph_t>
193-
void Hypergraph::convert_from_cdag_as_dag(const Graph_t& dag)
207+
void Hypergraph<index_type, workw_type, memw_type, commw_type>::convert_from_cdag_as_dag(const Graph_t& dag)
194208
{
195-
reset(static_cast<unsigned>(dag.num_vertices()), 0);
209+
static_assert(std::is_same_v<vertex_idx_t<Graph_t>, index_type>, "Index type mismatch, cannot convert DAG to hypergraph.");
210+
static_assert(std::is_same_v<v_workw_t<Graph_t>, workw_type>, "Work weight type mismatch, cannot convert DAG to hypergraph.");
211+
static_assert(std::is_same_v<v_memw_t<Graph_t>, memw_type>, "Memory weight type mismatch, cannot convert DAG to hypergraph.");
212+
static_assert(!has_edge_weights_v<Graph_t> || std::is_same_v<e_commw_t<Graph_t>, commw_type>, "Communication weight type mismatch, cannot convert DAG to hypergraph.");
213+
214+
reset(dag.num_vertices(), 0);
196215
for(const auto &node : dag.vertices())
197216
{
198-
set_vertex_work_weight(static_cast<unsigned>(node), static_cast<int>(dag.vertex_work_weight(node)));
199-
set_vertex_memory_weight(static_cast<unsigned>(node), static_cast<int>(dag.vertex_mem_weight(node)));
217+
set_vertex_work_weight(node, dag.vertex_work_weight(node));
218+
set_vertex_memory_weight(node, dag.vertex_mem_weight(node));
200219
for (const auto &child : dag.children(node))
201-
add_hyperedge({static_cast<unsigned>(node), static_cast<unsigned>(child)}); // TODO add edge weights if present
220+
if constexpr(has_edge_weights_v<Graph_t>)
221+
add_hyperedge({node, child}, dag.edge_comm_weight(edge_desc(node, child, dag).first));
222+
else
223+
add_hyperedge({node, child});
202224
}
203225
}
204226

227+
template<typename index_type, typename workw_type, typename memw_type, typename commw_type>
205228
template<typename Graph_t>
206-
void Hypergraph::convert_from_cdag_as_hyperdag(const Graph_t& dag)
229+
void Hypergraph<index_type, workw_type, memw_type, commw_type>::convert_from_cdag_as_hyperdag(const Graph_t& dag)
207230
{
208-
reset(static_cast<unsigned>(dag.num_vertices()), 0);
231+
static_assert(std::is_same_v<vertex_idx_t<Graph_t>, index_type>, "Index type mismatch, cannot convert DAG to hypergraph.");
232+
static_assert(std::is_same_v<v_workw_t<Graph_t>, workw_type>, "Work weight type mismatch, cannot convert DAG to hypergraph.");
233+
static_assert(std::is_same_v<v_memw_t<Graph_t>, memw_type>, "Memory weight type mismatch, cannot convert DAG to hypergraph.");
234+
static_assert(std::is_same_v<v_commw_t<Graph_t>, commw_type>, "Communication weight type mismatch, cannot convert DAG to hypergraph.");
235+
236+
reset(dag.num_vertices(), 0);
209237
for(const auto &node : dag.vertices())
210238
{
211-
set_vertex_work_weight(static_cast<unsigned>(node), static_cast<int>(dag.vertex_work_weight(node)));
212-
set_vertex_memory_weight(static_cast<unsigned>(node), static_cast<int>(dag.vertex_mem_weight(node)));
239+
set_vertex_work_weight(node, dag.vertex_work_weight(node));
240+
set_vertex_memory_weight(node, dag.vertex_mem_weight(node));
213241
if(dag.out_degree(node) == 0)
214242
continue;
215-
std::vector<unsigned> new_hyperedge({static_cast<unsigned>(node)});
243+
std::vector<index_type> new_hyperedge({node});
216244
for (const auto &child : dag.children(node))
217-
new_hyperedge.push_back(static_cast<unsigned>(child));
218-
add_hyperedge(new_hyperedge, static_cast<int>(dag.vertex_comm_weight(node)));
245+
new_hyperedge.push_back(child);
246+
add_hyperedge(new_hyperedge, dag.vertex_comm_weight(node));
219247
}
220248
}
221249

0 commit comments

Comments
 (0)