1111#include " SymScaling.h"
1212#include " ipm/hipo/auxiliary/Auxiliary.h"
1313#include " ipm/hipo/auxiliary/Log.h"
14- #include " parallel/HighsParallel.h"
1514
1615namespace hipo {
1716
@@ -168,42 +167,40 @@ void Factorise::permute(const std::vector<Int>& iperm) {
168167 valA_ = std::move (new_val);
169168}
170169
171- class TaskGroupSpecial : public highs ::parallel::TaskGroup {
170+ TaskGroupSpecial::~TaskGroupSpecial () {
172171 // Using TaskGroup may throw an exception when tasks are cancelled. Not sure
173172 // exactly why this happens, but for now this fix seems to work.
174173
175- public:
176- ~TaskGroupSpecial () {
177- // No virtual destructor in TaskGroup. Do not call this class via pointer to
178- // the base!
174+ // No virtual destructor in TaskGroup. Do not call this class via pointer to
175+ // the base!
179176
180- cancel ();
177+ cancel ();
181178
182- // re-call taskWait if it throws, until it succeeds
183- while (true ) {
184- try {
185- taskWait ();
186- break ;
187- } catch (HighsTask::Interrupt) {
188- continue ;
189- }
179+ // re-call taskWait if it throws, until it succeeds
180+ while (true ) {
181+ try {
182+ taskWait ();
183+ break ;
184+ } catch (HighsTask::Interrupt) {
185+ continue ;
190186 }
191187 }
192- };
188+ }
193189
194- void Factorise::processSupernode (Int sn) {
190+ void Factorise::processSupernode (Int sn, bool parallelise ) {
195191 // Assemble frontal matrix for supernode sn, perform partial factorisation and
196192 // store the result.
197193
198194 TaskGroupSpecial tg;
199195
200196 if (flag_stop_) return ;
201197
202- if (S_. parTree () ) {
198+ if (parallelise ) {
203199 // spawn children of this supernode in reverse order
204200 Int child_to_spawn = first_child_reverse_[sn];
205201 while (child_to_spawn != -1 ) {
206- tg.spawn ([=]() { processSupernode (child_to_spawn); });
202+ if (spawnNode (child_to_spawn, tg)) return ;
203+
207204 child_to_spawn = next_child_reverse_[child_to_spawn];
208205 }
209206
@@ -263,7 +260,7 @@ void Factorise::processSupernode(Int sn) {
263260 // Schur contribution of the current child
264261 std::vector<double >& child_clique = schur_contribution_[child_sn];
265262
266- if (S_. parTree () ) {
263+ if (parallelise ) {
267264 // sync with spawned child, apart from the first one
268265 if (child_sn != first_child_[sn]) tg.sync ();
269266
@@ -376,6 +373,35 @@ void Factorise::processSupernode(Int sn) {
376373#endif
377374}
378375
376+ void Factorise::processSubtree (Int start, Int end) {
377+ for (Int sn = start; sn < end; ++sn) {
378+ processSupernode (sn, false );
379+ }
380+ }
381+
382+ bool Factorise::spawnNode (Int sn, const TaskGroupSpecial& tg) {
383+ auto it = S_.treeSplitting ().find (sn);
384+
385+ if (it == S_.treeSplitting ().end ()) {
386+ log_->printDevInfo (" Missing supernode from tree splitting\n " );
387+ flag_stop_ = true ;
388+ return true ;
389+ }
390+
391+ if (it->second .type == NodeType::single) {
392+ // sn is single node, spawn only that
393+ tg.spawn ([=]() { processSupernode (sn, true ); });
394+
395+ } else {
396+ // sn is subtree, spawn the whole subtree
397+ Int start = it->second .first ;
398+ Int end = sn + 1 ;
399+ tg.spawn ([=]() { processSubtree (start, end); });
400+ }
401+
402+ return false ;
403+ }
404+
379405bool Factorise::run (Numeric& num) {
380406#if HIPO_TIMING_LEVEL >= 1
381407 Clock clock;
@@ -395,12 +421,10 @@ bool Factorise::run(Numeric& num) {
395421 sn_columns_.resize (S_.sn ());
396422
397423 if (S_.parTree ()) {
398- Int spawned_roots{};
399424 // spawn tasks for root supernodes
400425 for (Int sn = 0 ; sn < S_.sn (); ++sn) {
401426 if (S_.snParent (sn) == -1 ) {
402- tg.spawn ([=]() { processSupernode (sn); });
403- ++spawned_roots;
427+ if (spawnNode (sn, tg)) return true ;
404428 }
405429 }
406430
@@ -409,7 +433,7 @@ bool Factorise::run(Numeric& num) {
409433 } else {
410434 // go through each supernode serially
411435 for (Int sn = 0 ; sn < S_.sn (); ++sn) {
412- processSupernode (sn);
436+ processSupernode (sn, false );
413437 }
414438 }
415439
0 commit comments