@@ -331,24 +331,19 @@ class BspScheduleCS : public BspSchedule<Graph_t> {
331331 node_to_proc_been_sent[node][BspSchedule<Graph_t>::node_to_processor_assignment[node]] = true ;
332332 }
333333
334- // processor, ordered list of (cost, node, to_processor)
335- std::vector<std::set<std::vector<vertex_idx_t <Graph_t>>, std::greater<>>> require_sending (
336- BspSchedule<Graph_t>::instance->numberOfProcessors ());
337- // TODO the datastructure seems to be wrong. the vectors added to the set have elements of different types.
338- // it should really be std::vector<std::set<std::tuple<v_commw_t<Graph_t>, vertex_idx_t<Graph_t>,
339- // vertex_idx_t<Graph_t>>>> added many static_cast below as tmp fix
334+ // The data structure stores for each processor a set of tuples representing required sends.
335+ // Each tuple is (communication_cost, source_node, destination_processor).
336+ std::vector<std::set<std::tuple<v_commw_t <Graph_t>, vertex_idx_t <Graph_t>, unsigned >>> require_sending (BspSchedule<Graph_t>::instance->numberOfProcessors ());
340337
341338 for (unsigned proc = 0 ; proc < BspSchedule<Graph_t>::instance->numberOfProcessors (); proc++) {
342339 for (const auto &node : step_proc_node_list[0 ][proc]) {
343340
344341 for (const auto &target : BspSchedule<Graph_t>::instance->getComputationalDag ().children (node)) {
345342 if (proc != BspSchedule<Graph_t>::assignedProcessor (target)) {
346343 require_sending[proc].insert (
347- {static_cast <vertex_idx_t <Graph_t>>(
348- BspSchedule<Graph_t>::instance->getComputationalDag ().vertex_comm_weight (node) *
349- BspSchedule<Graph_t>::instance->getArchitecture ().sendCosts (
350- proc, BspSchedule<Graph_t>::node_to_processor_assignment[target])),
351- node, BspSchedule<Graph_t>::node_to_processor_assignment[target]});
344+ {BspSchedule<Graph_t>::instance->getComputationalDag ().vertex_comm_weight (node) * BspSchedule<Graph_t>::instance->getArchitecture ().sendCosts (proc, BspSchedule<Graph_t>::node_to_processor_assignment[target]),
345+ node,
346+ BspSchedule<Graph_t>::node_to_processor_assignment[target]});
352347 }
353348 }
354349 }
@@ -374,8 +369,8 @@ class BspScheduleCS : public BspSchedule<Graph_t> {
374369 BspSchedule<Graph_t>::instance->getComputationalDag ().vertex_comm_weight (source) *
375370 BspSchedule<Graph_t>::instance->getArchitecture ().sendCosts (
376371 BspSchedule<Graph_t>::node_to_processor_assignment[source], proc);
377- require_sending[BspSchedule<Graph_t>::assignedProcessor ( source) ].erase (
378- {static_cast < vertex_idx_t <Graph_t>>( comm_cost) , source, proc});
372+ require_sending[BspSchedule<Graph_t>::node_to_processor_assignment[ source] ].erase (
373+ {comm_cost, source, proc});
379374 send_cost[BspSchedule<Graph_t>::node_to_processor_assignment[source]] += comm_cost;
380375 receive_cost[proc] += comm_cost;
381376 }
@@ -394,22 +389,23 @@ class BspScheduleCS : public BspSchedule<Graph_t> {
394389 // TODO: permute the order of processors
395390 for (size_t proc = 0 ; proc < BspSchedule<Graph_t>::instance->numberOfProcessors (); proc++) {
396391 if (require_sending[proc].empty () ||
397- static_cast < v_commw_t <Graph_t>>((*( require_sending[proc].rbegin ()))[ 0 ] ) + send_cost[proc] >
392+ std::get< 0 >(* require_sending[proc].rbegin ()) + send_cost[proc] >
398393 max_comm_cost)
399394 continue ;
400395 auto iter = require_sending[proc].begin ();
401396 while (iter != require_sending[proc].cend ()) {
402- if (static_cast <v_commw_t <Graph_t>>((*iter)[0 ]) + send_cost[proc] > max_comm_cost ||
403- static_cast <v_commw_t <Graph_t>>((*iter)[0 ]) + receive_cost[(*iter)[2 ]] > max_comm_cost) {
397+ const auto & [comm_cost, node_to_send, dest_proc] = *iter;
398+ if (comm_cost + send_cost[proc] > max_comm_cost ||
399+ comm_cost + receive_cost[dest_proc] > max_comm_cost) {
404400 iter++;
405401 } else {
406- commSchedule.emplace (std::make_tuple ((*iter)[ 1 ] , proc, (*iter)[ 2 ] ), step - 1 );
407- node_to_proc_been_sent[(*iter)[ 1 ]][(*iter)[ 2 ] ] = true ;
408- send_cost[proc] += static_cast < v_commw_t <Graph_t>>((*iter)[ 0 ]) ;
409- receive_cost[(*iter)[ 2 ]] += static_cast < v_commw_t <Graph_t>>((*iter)[ 0 ]) ;
402+ commSchedule.emplace (std::make_tuple (node_to_send , proc, dest_proc ), step - 1 );
403+ node_to_proc_been_sent[node_to_send][dest_proc ] = true ;
404+ send_cost[proc] += comm_cost ;
405+ receive_cost[dest_proc] += comm_cost ;
410406 iter = require_sending[proc].erase (iter);
411407 if (require_sending[proc].empty () ||
412- static_cast < v_commw_t <Graph_t>>((*( require_sending[proc].rbegin ()))[ 0 ] ) + send_cost[proc] >
408+ std::get< 0 >(* require_sending[proc].rbegin ()) + send_cost[proc] >
413409 max_comm_cost)
414410 break ;
415411 }
@@ -423,10 +419,9 @@ class BspScheduleCS : public BspSchedule<Graph_t> {
423419 for (const auto &target : BspSchedule<Graph_t>::instance->getComputationalDag ().children (node))
424420 if (proc != BspSchedule<Graph_t>::assignedProcessor (target)) {
425421 require_sending[proc].insert (
426- {static_cast <vertex_idx_t <Graph_t>>(
427- BspSchedule<Graph_t>::instance->getComputationalDag ().vertex_comm_weight (node) *
422+ {BspSchedule<Graph_t>::instance->getComputationalDag ().vertex_comm_weight (node) *
428423 BspSchedule<Graph_t>::instance->getArchitecture ().sendCosts (
429- proc, BspSchedule<Graph_t>::node_to_processor_assignment[target])) ,
424+ proc, BspSchedule<Graph_t>::node_to_processor_assignment[target]),
430425 node, BspSchedule<Graph_t>::node_to_processor_assignment[target]});
431426 }
432427 }
0 commit comments