Skip to content

Commit 1cd0112

Browse files
Split up distribute method into smaller functions
1 parent 9a1b579 commit 1cd0112

File tree

2 files changed

+165
-149
lines changed

2 files changed

+165
-149
lines changed

Code/Source/solver/BoundaryCondition.cpp

Lines changed: 158 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ int BoundaryCondition::get_local_index(int global_node_id) const
249249

250250
void BoundaryCondition::distribute(const ComMod& com_mod, const CmMod& cm_mod, const cmType& cm, const faceType& face)
251251
{
252-
#define n_debug_bc_distribute
253-
#ifdef debug_bc_distribute
252+
#define n_debug_distribute
253+
#ifdef debug_distribute
254254
DebugMsg dmsg(__func__, cm.idcm());
255255
dmsg << "Distributing BC data" << std::endl;
256256
#endif
@@ -265,9 +265,24 @@ void BoundaryCondition::distribute(const ComMod& com_mod, const CmMod& cm_mod, c
265265

266266
// Number of nodes on the face on this processor
267267
local_num_nodes_ = face_->nNo;
268+
const bool is_slave = cm.slv(cm_mod);
269+
distribute_metadata(cm_mod, cm, is_slave);
270+
if (spatially_variable) {
271+
distribute_spatially_variable(com_mod, cm_mod, cm, is_slave);
272+
} else {
273+
distribute_uniform(cm_mod, cm, is_slave);
274+
}
275+
distribute_flags(cm_mod, cm, is_slave);
276+
defined_ = true;
268277

269-
bool is_slave = cm.slv(cm_mod);
278+
#ifdef debug_distribute
279+
dmsg << "Finished distributing BC data" << std::endl;
280+
dmsg << "Number of face nodes on this processor: " << local_num_nodes_ << std::endl;
281+
#endif
282+
}
270283

284+
void BoundaryCondition::distribute_metadata(const CmMod& cm_mod, const cmType& cm, bool is_slave)
285+
{
271286
cm.bcast(cm_mod, &spatially_variable);
272287

273288
// Not necessary, but we do it for consistency
@@ -294,172 +309,166 @@ void BoundaryCondition::distribute(const ComMod& com_mod, const CmMod& cm_mod, c
294309
array_names_[i] = array_name;
295310
}
296311
}
312+
}
297313

298-
// Communicate array values needed by each process
299-
if (spatially_variable) {
300-
// Setup
301-
if (face_ == nullptr) {
302-
throw std::runtime_error("face_ is nullptr during distribute");
303-
}
304-
305-
// Each processor collects the global node IDs and nodal positions of its
306-
// associated face portion
307-
Vector<int> local_global_ids = face_->gN;
308-
Array<double> local_positions(3, local_num_nodes_);
309-
for (int i = 0; i < local_num_nodes_; i++) {
310-
local_positions.set_col(i, com_mod.x.col(face_->gN(i)));
311-
}
312-
313-
#ifdef debug_bc_distribute
314-
dmsg << "Number of face nodes on this processor: " << local_num_nodes_ << std::endl;
315-
dmsg << "Local global IDs: " << local_global_ids << std::endl;
316-
dmsg << "Local positions: " << local_positions << std::endl;
317-
#endif
318-
319-
// Gather number of face nodes from each processor to master
320-
Vector<int> proc_num_nodes(cm.np());
321-
cm.gather(cm_mod, &local_num_nodes_, 1, proc_num_nodes.data(), 1, 0);
322-
323-
// Calculate displacements for gatherv/scatterv and compute total number of nodes
324-
// total_num_nodes is the total number of face nodes across all processors.
325-
Vector<int> displs(cm.np());
326-
int total_num_nodes = 0;
327-
for (int i = 0; i < cm.np(); i++) {
328-
displs(i) = total_num_nodes;
329-
total_num_nodes += proc_num_nodes(i);
330-
}
314+
void BoundaryCondition::distribute_spatially_variable(const ComMod& com_mod, const CmMod& cm_mod, const cmType& cm, bool is_slave)
315+
{
316+
#define n_debug_distribute_spatially_variable
317+
#ifdef debug_distribute_spatially_variable
318+
DebugMsg dmsg(__func__, 0);
319+
#endif
331320

332-
// Master process: gather the nodal positions of face nodes from all processors,
333-
// get the corresponding array values by matching the positions to the VTP points,
334-
// and scatter the data back to all processors.
335-
Array<double> all_positions;
336-
std::map<std::string, Vector<double>> all_values;
337-
if (!is_slave) {
338-
// Resize receive buffers based on total number of nodes
339-
all_positions.resize(3, total_num_nodes);
340-
341-
// Gather all positions to master using gatherv
342-
for (int d = 0; d < 3; d++) {
343-
Vector<double> local_pos_d(local_num_nodes_);
344-
Vector<double> all_pos_d(total_num_nodes);
345-
for (int i = 0; i < local_num_nodes_; i++) {
346-
local_pos_d(i) = local_positions(d,i);
347-
}
348-
cm.gatherv(cm_mod, local_pos_d, all_pos_d, proc_num_nodes, displs, 0);
349-
for (int i = 0; i < total_num_nodes; i++) {
350-
all_positions(d,i) = all_pos_d(i);
351-
}
352-
}
321+
if (face_ == nullptr) {
322+
throw std::runtime_error("face_ is nullptr during distribute");
323+
}
324+
// Each processor collects the global node IDs and nodal positions of its
325+
// associated face portion
326+
Vector<int> local_global_ids = face_->gN;
327+
Array<double> local_positions(3, local_num_nodes_);
328+
for (int i = 0; i < local_num_nodes_; i++) {
329+
local_positions.set_col(i, com_mod.x.col(face_->gN(i)));
330+
}
353331

354-
// Get VTP points for position matching
355-
Array<double> vtp_points = vtp_data_->get_points();
356-
357-
// Look up data for all nodes using point matching
358-
for (const auto& array_name : array_names_) {
359-
all_values[array_name].resize(total_num_nodes);
360-
for (int i = 0; i < total_num_nodes; i++) {
361-
int vtp_idx = find_vtp_point_index(all_positions(0,i), all_positions(1,i), all_positions(2,i), vtp_points);
362-
all_values[array_name](i) = global_data_[array_name](vtp_idx, 0);
363-
}
332+
#ifdef debug_distribute_spatially_variable
333+
dmsg << "Number of face nodes on this processor: " << local_num_nodes_ << std::endl;
334+
dmsg << "Local global IDs: " << local_global_ids << std::endl;
335+
dmsg << "Local positions: " << local_positions << std::endl;
336+
#endif
337+
// Gather number of face nodes from each processor to master
338+
Vector<int> proc_num_nodes(cm.np());
339+
cm.gather(cm_mod, &local_num_nodes_, 1, proc_num_nodes.data(), 1, 0);
340+
341+
// Calculate displacements for gatherv/scatterv and compute total number of nodes
342+
// total_num_nodes is the total number of face nodes across all processors.
343+
Vector<int> displs(cm.np());
344+
int total_num_nodes = 0;
345+
for (int i = 0; i < cm.np(); i++) {
346+
displs(i) = total_num_nodes;
347+
total_num_nodes += proc_num_nodes(i);
348+
}
349+
350+
// Master process: gather the nodal positions of face nodes from all processors,
351+
// get the corresponding array values by matching the positions to the VTP points,
352+
// and scatter the data back to all processors.
353+
Array<double> all_positions;
354+
std::map<std::string, Vector<double>> all_values;
355+
if (!is_slave) {
356+
// Resize receive buffers based on total number of nodes
357+
all_positions.resize(3, total_num_nodes);
358+
359+
// Gather all positions to master using gatherv
360+
for (int d = 0; d < 3; d++) {
361+
Vector<double> local_pos_d(local_num_nodes_);
362+
Vector<double> all_pos_d(total_num_nodes);
363+
for (int i = 0; i < local_num_nodes_; i++) {
364+
local_pos_d(i) = local_positions(d,i);
364365
}
365-
366-
// Clear global data to save memory
367-
global_data_.clear();
368-
369-
} else {
370-
// Slave processes: send positions to master
371-
for (int d = 0; d < 3; d++) {
372-
Vector<double> local_pos_d(local_num_nodes_);
373-
for (int i = 0; i < local_num_nodes_; i++) {
374-
local_pos_d(i) = local_positions(d,i);
375-
}
376-
Vector<double> dummy_recv(total_num_nodes);
377-
cm.gatherv(cm_mod, local_pos_d, dummy_recv, proc_num_nodes, displs, 0);
366+
cm.gatherv(cm_mod, local_pos_d, all_pos_d, proc_num_nodes, displs, 0);
367+
for (int i = 0; i < total_num_nodes; i++) {
368+
all_positions(d,i) = all_pos_d(i);
378369
}
379370
}
380-
381-
// Scatter data back to all processes using scatterv
382-
local_data_.clear();
371+
372+
// Get VTP points for position matching
373+
Array<double> vtp_points = vtp_data_->get_points();
374+
375+
// Look up data for all nodes using point matching
383376
for (const auto& array_name : array_names_) {
384-
Vector<double> local_values(local_num_nodes_);
385-
cm.scatterv(cm_mod, all_values[array_name], proc_num_nodes, displs, local_values, 0);
386-
local_data_[array_name] = Array<double>(local_num_nodes_, 1);
387-
local_data_[array_name].set_col(0, local_values);
388-
}
389-
390-
// Build mapping from face global node IDs to local array indices so we can
391-
// get data from a global node ID
392-
global_node_map_.clear();
393-
for (int i = 0; i < local_num_nodes_; i++) {
394-
global_node_map_[local_global_ids(i)] = i;
395-
}
396-
397-
#ifdef debug_bc_distribute
398-
dmsg << "Checking if local arrays and node positions are consistent" << std::endl;
399-
for (int i = 0; i < local_num_nodes_; i++) {
400-
dmsg << "Local global ID: " << local_global_ids(i) << std::endl;
401-
dmsg << "Local index: " << get_local_index(local_global_ids(i)) << std::endl;
402-
dmsg << "Local position: " << com_mod.x.col(local_global_ids(i)) << std::endl;
403-
for (const auto& array_name : array_names_) {
404-
dmsg << "Local " << array_name << ": " << local_data_[array_name](i, 0) << std::endl;
377+
all_values[array_name].resize(total_num_nodes);
378+
for (int i = 0; i < total_num_nodes; i++) {
379+
int vtp_idx = find_vtp_point_index(all_positions(0,i), all_positions(1,i), all_positions(2,i), vtp_points);
380+
all_values[array_name](i) = global_data_[array_name](vtp_idx, 0);
405381
}
406382
}
407-
#endif
408-
383+
384+
// Clear global data to save memory
385+
global_data_.clear();
409386
} else {
410-
// For uniform values, just broadcast the single values
411-
if (!is_slave) {
412-
for (const auto& array_name : array_names_) {
413-
double uniform_value = local_data_[array_name](0, 0);
414-
cm.bcast(cm_mod, &uniform_value);
415-
}
416-
} else {
417-
local_data_.clear();
418-
for (const auto& array_name : array_names_) {
419-
double uniform_value;
420-
cm.bcast(cm_mod, &uniform_value);
421-
local_data_[array_name] = Array<double>(1, 1);
422-
local_data_[array_name](0, 0) = uniform_value;
387+
// Slave processes: send node positions to master
388+
for (int d = 0; d < 3; d++) {
389+
Vector<double> local_pos_d(local_num_nodes_);
390+
for (int i = 0; i < local_num_nodes_; i++) {
391+
local_pos_d(i) = local_positions(d,i);
423392
}
393+
Vector<double> dummy_recv(total_num_nodes);
394+
cm.gatherv(cm_mod, local_pos_d, dummy_recv, proc_num_nodes, displs, 0);
424395
}
425396
}
397+
398+
// Scatter data back to all processes using scatterv
399+
local_data_.clear();
400+
for (const auto& array_name : array_names_) {
401+
Vector<double> local_values(local_num_nodes_);
402+
cm.scatterv(cm_mod, all_values[array_name], proc_num_nodes, displs, local_values, 0);
403+
local_data_[array_name] = Array<double>(local_num_nodes_, 1);
404+
local_data_[array_name].set_col(0, local_values);
405+
}
406+
407+
// Build mapping from face global node IDs to local array indices so we can
408+
// get data from a global node ID
409+
global_node_map_.clear();
410+
for (int i = 0; i < local_num_nodes_; i++) {
411+
global_node_map_[local_global_ids(i)] = i;
412+
}
426413

427-
// Broadcast boolean flags map
428-
if (!cm.seq()) {
429-
int num_flags = 0;
430-
if (!is_slave) {
431-
num_flags = static_cast<int>(flags_.size());
414+
#ifdef debug_distribute_spatially_variable
415+
dmsg << "Checking if local arrays and node positions are consistent" << std::endl;
416+
for (int i = 0; i < local_num_nodes_; i++) {
417+
dmsg << "Local global ID: " << local_global_ids(i) << std::endl;
418+
dmsg << "Local index: " << get_local_index(local_global_ids(i)) << std::endl;
419+
dmsg << "Local position: " << com_mod.x.col(local_global_ids(i)) << std::endl;
420+
for (const auto& array_name : array_names_) {
421+
dmsg << "Local " << array_name << ": " << local_data_[array_name](i, 0) << std::endl;
432422
}
433-
cm.bcast(cm_mod, &num_flags);
434-
if (is_slave) {
435-
flags_.clear();
423+
}
424+
#endif
425+
}
426+
427+
void BoundaryCondition::distribute_uniform(const CmMod& cm_mod, const cmType& cm, bool is_slave)
428+
{
429+
if (!is_slave) {
430+
for (const auto& array_name : array_names_) {
431+
double uniform_value = local_data_[array_name](0, 0);
432+
cm.bcast(cm_mod, &uniform_value);
436433
}
437-
for (int i = 0; i < num_flags; i++) {
438-
std::string key;
439-
bool val = false;
440-
if (!is_slave) {
441-
auto it = std::next(flags_.begin(), i);
442-
key = it->first;
443-
val = it->second;
444-
cm.bcast(cm_mod, key);
445-
cm.bcast(cm_mod, &val);
446-
} else {
447-
cm.bcast(cm_mod, key);
448-
cm.bcast(cm_mod, &val);
449-
flags_[key] = val;
450-
}
434+
} else {
435+
local_data_.clear();
436+
for (const auto& array_name : array_names_) {
437+
double uniform_value;
438+
cm.bcast(cm_mod, &uniform_value);
439+
local_data_[array_name] = Array<double>(1, 1);
440+
local_data_[array_name](0, 0) = uniform_value;
451441
}
452442
}
453-
454-
// Mark as defined
455-
defined_ = true;
456-
457-
#ifdef debug_bc_distribute
458-
dmsg << "Finished distributing BC data" << std::endl;
459-
dmsg << "Number of face nodes on this processor: " << local_num_nodes_ << std::endl;
460-
#endif
461443
}
462444

445+
void BoundaryCondition::distribute_flags(const CmMod& cm_mod, const cmType& cm, bool is_slave)
446+
{
447+
if (cm.seq()) return;
448+
int num_flags = 0;
449+
if (!is_slave) {
450+
num_flags = static_cast<int>(flags_.size());
451+
}
452+
cm.bcast(cm_mod, &num_flags);
453+
if (is_slave) {
454+
flags_.clear();
455+
}
456+
for (int i = 0; i < num_flags; i++) {
457+
std::string key;
458+
bool val = false;
459+
if (!is_slave) {
460+
auto it = std::next(flags_.begin(), i);
461+
key = it->first;
462+
val = it->second;
463+
cm.bcast(cm_mod, key);
464+
cm.bcast(cm_mod, &val);
465+
} else {
466+
cm.bcast(cm_mod, key);
467+
cm.bcast(cm_mod, &val);
468+
flags_[key] = val;
469+
}
470+
}
471+
}
463472

464473
int BoundaryCondition::find_vtp_point_index(double x, double y, double z,
465474
const Array<double>& vtp_points) const

Code/Source/solver/BoundaryCondition.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,13 @@ class BoundaryCondition {
176176
/// @param value Value to validate
177177
/// @throws std::runtime_error if validation fails
178178
virtual void validate_array_value(const std::string& array_name, double value) const {}
179+
180+
// ---- distribute helpers ----
181+
void distribute_metadata(const CmMod& cm_mod, const cmType& cm, bool is_slave);
182+
void distribute_spatially_variable(const ComMod& com_mod, const CmMod& cm_mod, const cmType& cm, bool is_slave);
183+
void distribute_uniform(const CmMod& cm_mod, const cmType& cm, bool is_slave);
184+
void distribute_flags(const CmMod& cm_mod, const cmType& cm, bool is_slave);
185+
179186
};
180187

181188
/// @brief Base exception class for BC errors

0 commit comments

Comments
 (0)