@@ -12,8 +12,18 @@ namespace hipo {
1212HybridSolveHandler::HybridSolveHandler (
1313 const Symbolic& S, const std::vector<std::vector<double >>& sn_columns,
1414 const std::vector<std::vector<Int>>& swaps,
15- const std::vector<std::vector<double >>& pivot_2x2)
16- : SolveHandler(S, sn_columns), swaps_{swaps}, pivot_2x2_{pivot_2x2} {}
15+ const std::vector<std::vector<double >>& pivot_2x2,
16+ const std::vector<Int>& fc, const std::vector<Int>& nc,
17+ const std::vector<Int>& fcr, const std::vector<Int>& ncr,
18+ std::vector<std::vector<double >>& local)
19+ : SolveHandler(S, sn_columns),
20+ swaps_{swaps},
21+ pivot_2x2_{pivot_2x2},
22+ first_child_{fc},
23+ next_child_{nc},
24+ first_child_reverse_{fcr},
25+ next_child_reverse_{ncr},
26+ local_{local} {}
1727
1828void HybridSolveHandler::forwardSolve (std::vector<double >& x) const {
1929 // Forward solve.
@@ -314,4 +324,212 @@ void HybridSolveHandler::diagSolve(std::vector<double>& x) const {
314324 }
315325}
316326
327+ void HybridSolveHandler::parForwardSolve (std::vector<double >& x) {
328+ if (S_.parTree ()) {
329+ TaskGroupSpecial tg;
330+
331+ for (Int sn = 0 ; sn < S_.sn (); ++sn) {
332+ if (S_.snParent (sn) == -1 ) spawnNode (sn, x, tg);
333+ }
334+
335+ tg.taskWait ();
336+
337+ } else {
338+ for (Int sn = 0 ; sn < S_.sn (); ++sn) {
339+ processSupernode (sn, x, false );
340+ }
341+ }
342+
343+ for (int sn = 0 ; sn < S_.sn (); ++sn) {
344+ const int sn_size = S_.snStart (sn + 1 ) - S_.snStart (sn);
345+ std::memcpy (&x[S_.snStart (sn)], local_[sn].data (),
346+ sn_size * sizeof (double ));
347+ }
348+ }
349+
350+ void HybridSolveHandler::spawnNode (Int sn, const std::vector<double >& x,
351+ const TaskGroupSpecial& tg, bool do_spawn) {
352+ // if do_spawn is true, a task is actually spawned, otherwise, it is executed
353+ // immediately. This avoids the overhead of spawning a task if a supernode has
354+ // a single child.
355+
356+ const NodeData* ptr = S_.nodeDataPtr (sn);
357+ if (!ptr) return ;
358+
359+ if (ptr->type == NodeType::single) {
360+ // sn is single node; spawn only that
361+
362+ auto f = [this , &x, sn]() { processSupernode (sn, x, true ); };
363+
364+ if (do_spawn)
365+ tg.spawn (std::move (f));
366+ else
367+ f ();
368+
369+ } else {
370+ // sn is head of the first subtree in a group of small subtrees; spawn all
371+ // of them
372+
373+ auto f = [this , &x, ptr]() {
374+ for (Int i = 0 ; i < ptr->group .size (); ++i) {
375+ Int st_head = ptr->group [i];
376+ Int start = ptr->firstdesc [i];
377+ Int end = st_head + 1 ;
378+ for (Int sn = start; sn < end; ++sn) {
379+ processSupernode (sn, x, false );
380+ }
381+ }
382+ };
383+
384+ if (do_spawn)
385+ tg.spawn (std::move (f));
386+ else
387+ f ();
388+ }
389+ }
390+
391+ void HybridSolveHandler::syncNode (Int sn, const TaskGroupSpecial& tg) {
392+ // If spawnNode(sn,tg) created a task, then sync it.
393+ // This happens only if sn is found in the treeSplitting data structure.
394+
395+ if (S_.nodeDataPtr (sn)) tg.sync ();
396+ }
397+
398+ void HybridSolveHandler::processSupernode (Int sn, const std::vector<double >& x,
399+ bool parallelise) {
400+ // Parallel forward solve.
401+ // Blas calls: dtrsv, dgemv
402+
403+ // supernode columns in format FH
404+
405+ #if HIPO_TIMING_LEVEL >= 2
406+ Clock clock;
407+ #endif
408+
409+ TaskGroupSpecial tg;
410+
411+ if (parallelise) {
412+ // if there is only one child, do not parallelise
413+ if (first_child_[sn] != -1 && next_child_[first_child_[sn]] == -1 ) {
414+ spawnNode (first_child_[sn], x, tg, false );
415+ parallelise = false ;
416+ } else {
417+ // spawn children of this supernode in reverse order
418+ int child_to_spawn = first_child_reverse_[sn];
419+ while (child_to_spawn != -1 ) {
420+ spawnNode (child_to_spawn, x, tg);
421+ child_to_spawn = next_child_reverse_[child_to_spawn];
422+ }
423+ }
424+ }
425+
426+ const Int nb = S_.blockSize ();
427+
428+ // leading size of supernode
429+ const Int ldSn = S_.ptr (sn + 1 ) - S_.ptr (sn);
430+
431+ // number of columns in the supernode
432+ const Int sn_size = S_.snStart (sn + 1 ) - S_.snStart (sn);
433+
434+ // first colums of the supernode
435+ const Int sn_start = S_.snStart (sn);
436+
437+ // index to access S->rows for this supernode
438+ const Int start_row = S_.ptr (sn);
439+
440+ // number of blocks of columns
441+ const Int n_blocks = (sn_size - 1 ) / nb + 1 ;
442+
443+ // index to access snColumns[sn]
444+ Int SnCol_ind{};
445+
446+ // initialize local storage for this supernode
447+ double * local = local_[sn].data ();
448+ std::memset (local, 0 , local_[sn].size () * sizeof (double ));
449+
450+ // contribution from original vector
451+ std::memcpy (local, &x[sn_start], sn_size * sizeof (double ));
452+
453+ // contributions from children
454+ int child = first_child_[sn];
455+ while (child != -1 ) {
456+ if (parallelise) {
457+ // wait for child to be ready
458+ syncNode (child, tg);
459+ }
460+
461+ std::vector<double >& child_x = local_[child];
462+ const int child_size = S_.snStart (child + 1 ) - S_.snStart (child);
463+
464+ // number of entries to assemble into local
465+ const int nc = child_x.size () - child_size;
466+
467+ #if HIPO_TIMING_LEVEL >= 2
468+ clock.start ();
469+ #endif
470+ // assemble each contribution of this child
471+ for (int i = 0 ; i < nc; ++i) {
472+ const int j = S_.relindClique (child, i);
473+ local[j] += child_x[child_size + i];
474+ }
475+ #if HIPO_TIMING_LEVEL >= 2
476+ if (data_) data_->sumTime (kTimeSolveSolve_sparse , clock.stop ());
477+ #endif
478+
479+ child = next_child_[child];
480+ }
481+
482+ // go through blocks of columns for this supernode
483+ for (int j = 0 ; j < n_blocks; ++j) {
484+ // number of columns in the block
485+ const int jb = std::min (nb, sn_size - nb * j);
486+
487+ // number of entries in diagonal part
488+ const int diag_entries = jb * jb;
489+
490+ // index to access vector x
491+ const int x_start = sn_start + nb * j;
492+
493+ #ifdef HIPO_PIVOTING
494+ #if HIPO_TIMING_LEVEL >= 2
495+ clock.start ();
496+ #endif
497+ // apply swaps to portion of rhs that is affected
498+ const int * current_swaps = &swaps_[sn][nb * j];
499+ permuteWithSwaps (&local[nb * j], current_swaps, jb);
500+ #if HIPO_TIMING_LEVEL >= 2
501+ if (data_) data_->sumTime (kTimeSolveSolve_swap , clock.stop ());
502+ #endif
503+ #endif
504+
505+ #if HIPO_TIMING_LEVEL >= 2
506+ clock.start ();
507+ #endif
508+ callAndTime_dtrsv (' U' , ' T' , ' U' , jb, &sn_columns_[sn][SnCol_ind], jb,
509+ &local[nb * j], 1 , *data_);
510+
511+ SnCol_ind += diag_entries;
512+
513+ const int gemv_size = ldSn - nb * j - jb;
514+
515+ callAndTime_dgemv (' T' , jb, gemv_size, -1.0 , &sn_columns_[sn][SnCol_ind], jb,
516+ &local[nb * j], 1 , 1.0 , &local[nb * j + jb], 1 , *data_);
517+ SnCol_ind += jb * gemv_size;
518+ #if HIPO_TIMING_LEVEL >= 2
519+ if (data_) data_->sumTime (kTimeSolveSolve_dense , clock.stop ());
520+ #endif
521+
522+ #ifdef HIPO_PIVOTING
523+ #if HIPO_TIMING_LEVEL >= 2
524+ clock.start ();
525+ #endif
526+ // apply inverse swaps
527+ permuteWithSwaps (&local[nb * j], current_swaps, jb, true );
528+ #if HIPO_TIMING_LEVEL >= 2
529+ if (data_) data_->sumTime (kTimeSolveSolve_swap , clock.stop ());
530+ #endif
531+ #endif
532+ }
533+ }
534+
317535} // namespace hipo
0 commit comments