Skip to content

Commit 833a631

Browse files
committed
Added check to deal with poorly performing CubedSphereIterator.
1 parent 5d0f545 commit 833a631

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

src/atlas/grid/detail/partitioner/MatchingMeshPartitionerCubedSphere.cc

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@
66
*/
77

88
#include "atlas/grid/detail/partitioner/MatchingMeshPartitionerCubedSphere.h"
9+
10+
#include <algorithm>
11+
#include <iterator>
12+
#include <vector>
13+
914
#include "atlas/grid/CubedSphereGrid.h"
1015
#include "atlas/grid/Iterator.h"
1116
#include "atlas/interpolation/method/cubedsphere/CellFinder.h"
1217
#include "atlas/parallel/mpi/mpi.h"
1318
#include "atlas/parallel/omp/omp.h"
19+
#include "atlas/util/Point.h"
1420

1521
namespace atlas {
1622
namespace grid {
@@ -30,13 +36,24 @@ void MatchingMeshPartitionerCubedSphere::partition(const Grid& grid, int partiti
3036
const auto edgeEpsilon = epsilon;
3137
const size_t listSize = 8;
3238

33-
// Loop over grid and set partioning[].
34-
const auto lonlatIt = grid.lonlat().begin();
35-
atlas_omp_parallel_for(gidx_t i = 0; i < grid.size(); ++i) {
36-
const auto lonLat = *(lonlatIt + i);
37-
// This is probably more expensive than it needs to be, as it performs
38-
// a dry run of the cubedsphere interpolation method.
39-
partitioning[i] = finder.getCell(lonLat, listSize, edgeEpsilon, epsilon).isect ? mpi_rank : -1;
39+
40+
const auto setPartitioning = [&](const auto& lonLatIt) {
41+
atlas_omp_parallel_for(gidx_t i = 0; i < grid.size(); ++i) {
42+
const auto lonLat = *(lonLatIt + i);
43+
// This is probably more expensive than it needs to be, as it performs
44+
// a dry run of the cubedsphere interpolation method.
45+
partitioning[i] = finder.getCell(lonLat, listSize, edgeEpsilon, epsilon).isect ? mpi_rank : -1;
46+
}
47+
};
48+
49+
// CubedSphereIterator::operator+=() is not implemented properly.
50+
if (CubedSphereGrid(grid)) {
51+
auto lonLats = std::vector<PointLonLat>{};
52+
lonLats.reserve(grid.size());
53+
std::copy(grid.lonlat().begin(), grid.lonlat().end(), std::back_inserter(lonLats));
54+
setPartitioning(lonLats.begin());
55+
} else {
56+
setPartitioning(grid.lonlat().begin());
4057
}
4158

4259
// AllReduce to get full partitioning array.

0 commit comments

Comments
 (0)