Skip to content

Commit ccece2d

Browse files
committed
Added parallel forward solve
1 parent ad3eec3 commit ccece2d

File tree

12 files changed

+313
-160
lines changed

12 files changed

+313
-160
lines changed

cmake/sources.cmake

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ set(factor_highs_sources
214214
ipm/hipo/factorhighs/FormatHandler.cpp
215215
ipm/hipo/factorhighs/HybridHybridFormatHandler.cpp
216216
ipm/hipo/factorhighs/HybridSolveHandler.cpp
217-
ipm/hipo/factorhighs/KrylovMethodsIpm.cpp
218217
ipm/hipo/factorhighs/Numeric.cpp
219218
ipm/hipo/factorhighs/SolveHandler.cpp
220219
ipm/hipo/factorhighs/Swaps.cpp
@@ -233,7 +232,6 @@ set(factor_highs_headers
233232
ipm/hipo/factorhighs/FormatHandler.h
234233
ipm/hipo/factorhighs/HybridHybridFormatHandler.h
235234
ipm/hipo/factorhighs/HybridSolveHandler.h
236-
ipm/hipo/factorhighs/KrylovMethodsIpm.h
237235
ipm/hipo/factorhighs/Numeric.h
238236
ipm/hipo/factorhighs/ReturnValues.h
239237
ipm/hipo/factorhighs/SolveHandler.h

highs/ipm/hipo/auxiliary/Auxiliary.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,4 +248,24 @@ double Clock::stop() const {
248248
return d.count();
249249
}
250250

251+
// Destructor: cancels the tasks of this group, then calls taskWait() until it
// returns without throwing HighsTask::Interrupt.
//
// Using TaskGroup may throw an exception when tasks are cancelled. Not sure
// exactly why this happens, but for now this fix seems to work.
//
// No virtual destructor in TaskGroup. Do not call this class via pointer to
// the base!
TaskGroupSpecial::~TaskGroupSpecial() {
  cancel();

  // Re-call taskWait if it throws, until it succeeds. The exception is caught
  // by const reference, so the exception object is never copied.
  for (;;) {
    try {
      taskWait();
      return;
    } catch (const HighsTask::Interrupt&) {
      // interrupted while waiting: try again
    }
  }
}
270+
251271
} // namespace hipo

highs/ipm/hipo/auxiliary/Auxiliary.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <vector>
99

1010
#include "ipm/hipo/auxiliary/IntConfig.h"
11+
#include "parallel/HighsParallel.h"
1112

1213
namespace hipo {
1314

@@ -66,6 +67,11 @@ class Clock {
6667
double stop() const;
6768
};
6869

70+
// Task group derived from highs::parallel::TaskGroup that only adds a custom
// destructor. The destructor (defined in Auxiliary.cpp) cancels the group and
// repeats taskWait() for as long as it throws HighsTask::Interrupt.
// TaskGroup has no virtual destructor: never destroy an object of this class
// through a pointer to the base.
class TaskGroupSpecial : public highs::parallel::TaskGroup {
71+
public:
72+
// Cancels the tasks of the group and waits for them, retrying on interrupt.
~TaskGroupSpecial();
73+
};
74+
6975
} // namespace hipo
7076

7177
#endif

highs/ipm/hipo/factorhighs/FactorHiGHS.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ Int FHsolver::factorise(const Symbolic& S, const std::vector<Int>& rows,
4848
return fact_obj.run(N_);
4949
}
5050

51-
Int FHsolver::solve(std::vector<double>& x) { return N_.solve(x); }
51+
// Solve using the numeric object N_. N_.setup() is invoked before every
// solve; x is then handed to N_.solve() and the status that it returns is
// forwarded unchanged to the caller.
Int FHsolver::solve(std::vector<double>& x) {
  N_.setup();
  const Int status = N_.solve(x);
  return status;
}
5255

5356
void FHsolver::getRegularisation(std::vector<double>& reg) { N_.getReg(reg); }
5457

highs/ipm/hipo/factorhighs/Factorise.cpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -156,29 +156,6 @@ void Factorise::permute(const std::vector<Int>& iperm) {
156156
valA_ = std::move(new_val);
157157
}
158158

159-
class TaskGroupSpecial : public highs::parallel::TaskGroup {
160-
// Using TaskGroup may throw an exception when tasks are cancelled. Not sure
161-
// exactly why this happens, but for now this fix seems to work.
162-
163-
public:
164-
~TaskGroupSpecial() {
165-
// No virtual destructor in TaskGroup. Do not call this class via pointer to
166-
// the base!
167-
168-
cancel();
169-
170-
// re-call taskWait if it throws, until it succeeds
171-
while (true) {
172-
try {
173-
taskWait();
174-
break;
175-
} catch (HighsTask::Interrupt) {
176-
continue;
177-
}
178-
}
179-
}
180-
};
181-
182159
void Factorise::processSupernode(Int sn) {
183160
// Assemble frontal matrix for supernode sn, perform partial factorisation and
184161
// store the result.

highs/ipm/hipo/factorhighs/HybridSolveHandler.cpp

Lines changed: 220 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,18 @@ namespace hipo {
1212
// Constructor. S and sn_columns are forwarded to the SolveHandler base; every
// other argument is bound to the reference member of the corresponding name
// (swaps -> swaps_, pivot_2x2 -> pivot_2x2_, fc -> first_child_,
// nc -> next_child_, fcr -> first_child_reverse_, ncr -> next_child_reverse_,
// local -> local_). Nothing is copied, so the containers passed here must
// outlive the handler. local is taken by non-const reference because the
// parallel forward solve writes into it.
HybridSolveHandler::HybridSolveHandler(
1313
const Symbolic& S, const std::vector<std::vector<double>>& sn_columns,
1414
const std::vector<std::vector<Int>>& swaps,
15-
const std::vector<std::vector<double>>& pivot_2x2)
16-
: SolveHandler(S, sn_columns), swaps_{swaps}, pivot_2x2_{pivot_2x2} {}
15+
const std::vector<std::vector<double>>& pivot_2x2,
16+
const std::vector<Int>& fc, const std::vector<Int>& nc,
17+
const std::vector<Int>& fcr, const std::vector<Int>& ncr,
18+
std::vector<std::vector<double>>& local)
19+
: SolveHandler(S, sn_columns),
20+
swaps_{swaps},
21+
pivot_2x2_{pivot_2x2},
22+
first_child_{fc},
23+
next_child_{nc},
24+
first_child_reverse_{fcr},
25+
next_child_reverse_{ncr},
26+
local_{local} {}
1727

1828
void HybridSolveHandler::forwardSolve(std::vector<double>& x) const {
1929
// Forward solve.
@@ -314,4 +324,212 @@ void HybridSolveHandler::diagSolve(std::vector<double>& x) const {
314324
}
315325
}
316326

327+
void HybridSolveHandler::parForwardSolve(std::vector<double>& x) {
328+
if (S_.parTree()) {
329+
TaskGroupSpecial tg;
330+
331+
for (Int sn = 0; sn < S_.sn(); ++sn) {
332+
if (S_.snParent(sn) == -1) spawnNode(sn, x, tg);
333+
}
334+
335+
tg.taskWait();
336+
337+
} else {
338+
for (Int sn = 0; sn < S_.sn(); ++sn) {
339+
processSupernode(sn, x, false);
340+
}
341+
}
342+
343+
for (int sn = 0; sn < S_.sn(); ++sn) {
344+
const int sn_size = S_.snStart(sn + 1) - S_.snStart(sn);
345+
std::memcpy(&x[S_.snStart(sn)], local_[sn].data(),
346+
sn_size * sizeof(double));
347+
}
348+
}
349+
350+
// Launch the work attached to supernode sn.
//
// Nothing happens if S_.nodeDataPtr(sn) is null. Otherwise the node is either
// - a single node: processSupernode(sn, x, true) is run, or
// - the head of the first subtree in a group of small subtrees: every subtree
//   of the group is processed serially, supernode by supernode, from its
//   first descendant up to its head.
//
// If do_spawn is true the work is spawned as a task of tg; otherwise it is
// executed immediately, which avoids the overhead of spawning a task if a
// supernode has a single child.
void HybridSolveHandler::spawnNode(Int sn, const std::vector<double>& x,
                                   const TaskGroupSpecial& tg, bool do_spawn) {
  const NodeData* node = S_.nodeDataPtr(sn);
  if (node == nullptr) return;

  const bool is_single = node->type == NodeType::single;

  if (is_single) {
    // only this supernode; its children are handled inside processSupernode
    auto run_single = [this, &x, sn]() { processSupernode(sn, x, true); };

    if (do_spawn) {
      tg.spawn(std::move(run_single));
    } else {
      run_single();
    }
    return;
  }

  // group of small subtrees: one task goes through all of them
  auto run_group = [this, &x, node]() {
    for (Int k = 0; k < node->group.size(); ++k) {
      const Int first = node->firstdesc[k];
      const Int head = node->group[k];
      for (Int s = first; s <= head; ++s) {
        processSupernode(s, x, false);
      }
    }
  };

  if (do_spawn) {
    tg.spawn(std::move(run_group));
  } else {
    run_group();
  }
}
390+
391+
// Counterpart of spawnNode(sn, x, tg): if a task was created for sn, sync it.
// A task can only have been created when sn is found in the tree splitting
// data structure, i.e. when S_.nodeDataPtr(sn) is not null.
void HybridSolveHandler::syncNode(Int sn, const TaskGroupSpecial& tg) {
  const bool has_task = S_.nodeDataPtr(sn) != nullptr;
  if (has_task) {
    tg.sync();
  }
}
397+
398+
// Forward solve for a single supernode, used by the parallel forward solve.
// Blas calls: dtrsv, dgemv
// Supernode columns are stored in format FH.
//
// The work vector local_[sn] is zeroed, the entries of x belonging to the
// supernode are copied into its leading part, the trailing entries of each
// child's work vector are added at the positions given by S_.relindClique,
// and finally the dense solve is applied block of columns by block of
// columns.
//
// sn:          supernode to process.
// x:           right-hand side; it is only read here.
// parallelise: if true, the children of sn are launched through spawnNode
//              before sn is processed and synchronised one by one when their
//              contribution is needed; if false, local_[child] must already
//              contain the result of each child.
void HybridSolveHandler::processSupernode(Int sn, const std::vector<double>& x,
                                          bool parallelise) {
#if HIPO_TIMING_LEVEL >= 2
  Clock clock;
#endif

  TaskGroupSpecial tg;

  if (parallelise) {
    // if there is only one child, do not parallelise
    if (first_child_[sn] != -1 && next_child_[first_child_[sn]] == -1) {
      // run the only child immediately; nothing to sync afterwards
      spawnNode(first_child_[sn], x, tg, false);
      parallelise = false;
    } else {
      // spawn children of this supernode in reverse order
      Int child_to_spawn = first_child_reverse_[sn];
      while (child_to_spawn != -1) {
        spawnNode(child_to_spawn, x, tg);
        child_to_spawn = next_child_reverse_[child_to_spawn];
      }
    }
  }

  const Int nb = S_.blockSize();

  // leading size of supernode
  const Int ldSn = S_.ptr(sn + 1) - S_.ptr(sn);

  // first column of the supernode
  const Int sn_start = S_.snStart(sn);

  // number of columns in the supernode
  const Int sn_size = S_.snStart(sn + 1) - sn_start;

  // number of blocks of columns
  const Int n_blocks = (sn_size - 1) / nb + 1;

  // index to access sn_columns_[sn]
  Int SnCol_ind{};

  // initialize local storage for this supernode
  double* local = local_[sn].data();
  std::memset(local, 0, local_[sn].size() * sizeof(double));

  // contribution from original vector
  std::memcpy(local, &x[sn_start], sn_size * sizeof(double));

  // contributions from children
  Int child = first_child_[sn];
  while (child != -1) {
    if (parallelise) {
      // wait for child to be ready
      syncNode(child, tg);
    }

    const std::vector<double>& child_x = local_[child];
    const Int child_size = S_.snStart(child + 1) - S_.snStart(child);

    // number of entries to assemble into local: everything that follows the
    // child's own columns in its work vector
    const Int nc = static_cast<Int>(child_x.size()) - child_size;

#if HIPO_TIMING_LEVEL >= 2
    clock.start();
#endif
    // assemble each contribution of this child
    for (Int i = 0; i < nc; ++i) {
      const Int j = S_.relindClique(child, i);
      local[j] += child_x[child_size + i];
    }
#if HIPO_TIMING_LEVEL >= 2
    if (data_) data_->sumTime(kTimeSolveSolve_sparse, clock.stop());
#endif

    child = next_child_[child];
  }

  // NOTE(review): data_ is dereferenced unconditionally in the BLAS wrappers
  // below, while the timing code checks it against null - confirm that data_
  // can never be null here.

  // go through blocks of columns for this supernode
  for (Int j = 0; j < n_blocks; ++j) {
    // number of columns in the block
    const Int jb = std::min(nb, sn_size - nb * j);

    // number of entries in diagonal part
    const Int diag_entries = jb * jb;

    // portion of the local vector that corresponds to this block
    double* rhs_block = &local[nb * j];

#ifdef HIPO_PIVOTING
#if HIPO_TIMING_LEVEL >= 2
    clock.start();
#endif
    // apply swaps to portion of rhs that is affected; swaps_ stores Int, so
    // the pointer must be const Int* (const int* only compiles if Int is int)
    const Int* current_swaps = &swaps_[sn][nb * j];
    permuteWithSwaps(rhs_block, current_swaps, jb);
#if HIPO_TIMING_LEVEL >= 2
    if (data_) data_->sumTime(kTimeSolveSolve_swap, clock.stop());
#endif
#endif

#if HIPO_TIMING_LEVEL >= 2
    clock.start();
#endif
    // triangular solve with the diagonal part of the block
    callAndTime_dtrsv('U', 'T', 'U', jb, &sn_columns_[sn][SnCol_ind], jb,
                      rhs_block, 1, *data_);

    SnCol_ind += diag_entries;

    // number of rows below the diagonal part of the block
    const Int gemv_size = ldSn - nb * j - jb;

    // update the entries of local that follow the block
    callAndTime_dgemv('T', jb, gemv_size, -1.0, &sn_columns_[sn][SnCol_ind], jb,
                      rhs_block, 1, 1.0, rhs_block + jb, 1, *data_);
    SnCol_ind += jb * gemv_size;
#if HIPO_TIMING_LEVEL >= 2
    if (data_) data_->sumTime(kTimeSolveSolve_dense, clock.stop());
#endif

#ifdef HIPO_PIVOTING
#if HIPO_TIMING_LEVEL >= 2
    clock.start();
#endif
    // apply inverse swaps
    permuteWithSwaps(rhs_block, current_swaps, jb, true);
#if HIPO_TIMING_LEVEL >= 2
    if (data_) data_->sumTime(kTimeSolveSolve_swap, clock.stop());
#endif
#endif
  }
}
534+
317535
} // namespace hipo

highs/ipm/hipo/factorhighs/HybridSolveHandler.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,40 @@
22
#define FACTORHIGHS_HYBRID_SOLVE_HANDLER_H
33

44
#include "SolveHandler.h"
5+
#include "ipm/hipo/auxiliary/Auxiliary.h"
56

67
namespace hipo {
78

89
// Solve handler for the hybrid format. Besides the serial forward, backward
// and diagonal solves, it provides a forward solve that walks the supernodal
// tree and can run the children of a supernode as parallel tasks.
// All data members are references to containers owned elsewhere; they must
// stay valid for as long as the handler is used.
class HybridSolveHandler : public SolveHandler {
910
// swaps of each supernode, applied to the rhs when HIPO_PIVOTING is defined
const std::vector<std::vector<Int>>& swaps_;
1011
const std::vector<std::vector<double>>& pivot_2x2_;
1112

13+
// Children of each supernode as linked lists: first_child_[sn] is the first
// child of sn and next_child_[c] is the child that follows c; -1 ends a list.
const std::vector<Int>& first_child_;
14+
const std::vector<Int>& next_child_;
15+
// Same lists, used to spawn the children of a supernode in reverse order.
const std::vector<Int>& first_child_reverse_;
16+
const std::vector<Int>& next_child_reverse_;
17+
18+
// Work vector of each supernode for the parallel forward solve. The leading
// entries of local_[sn] correspond to the columns of sn; the entries that
// follow are assembled into the work vector of the parent.
std::vector<std::vector<double>>& local_;
19+
1220
void forwardSolve(std::vector<double>& x) const override;
1321
void backwardSolve(std::vector<double>& x) const override;
1422
void diagSolve(std::vector<double>& x) const override;
1523

24+
// Forward solve through the supernodal tree; parallel if S_.parTree().
// Not const, because it writes into local_ through processSupernode.
void parForwardSolve(std::vector<double>& x) override;
25+
26+
// Forward solve for supernode sn; if parallelise is true, its children are
// spawned first and synchronised when their contribution is needed.
void processSupernode(Int sn, const std::vector<double>& x, bool parallelise);
27+
// Run the work attached to sn, as a task of tg if do_spawn is true and
// immediately otherwise.
void spawnNode(Int sn, const std::vector<double>& x,
28+
const TaskGroupSpecial& tg, bool do_spawn = true);
29+
// Sync the task created by spawnNode for sn, if there is one.
void syncNode(Int sn, const TaskGroupSpecial& tg);
30+
1631
public:
1732
// Stores references to all the containers passed in; nothing is copied.
HybridSolveHandler(const Symbolic& S,
1833
const std::vector<std::vector<double>>& sn_columns,
1934
const std::vector<std::vector<Int>>& swaps,
20-
const std::vector<std::vector<double>>& pivot_2x2);
35+
const std::vector<std::vector<double>>& pivot_2x2,
36+
const std::vector<Int>& fc, const std::vector<Int>& nc,
37+
const std::vector<Int>& fcr, const std::vector<Int>& ncr,
38+
std::vector<std::vector<double>>& local);
2139
};
2240

2341
} // namespace hipo

0 commit comments

Comments
 (0)