@@ -81,6 +81,7 @@ static std::string get_mpi_comm(const eckit::Configuration& config) {
8181
8282template <>
8383PointCloud::PointCloud (const std::vector<PointXY>& points, const eckit::Configuration& config) {
84+ ATLAS_TRACE (" PointCloud(std::vector<PointXY>, eckit::Configuration)" );
8485 mpi_comm_ = get_mpi_comm (config);
8586 lonlat_ = Field (" lonlat" , array::make_datatype<double >(), array::make_shape (points.size (), 2 ));
8687 auto lonlat = array::make_view<double , 2 >(lonlat_);
@@ -92,6 +93,7 @@ PointCloud::PointCloud(const std::vector<PointXY>& points, const eckit::Configur
9293
9394template <>
9495PointCloud::PointCloud (const std::vector<PointXYZ>& points, const eckit::Configuration& config) {
96+ ATLAS_TRACE (" PointCloud(std::vector<PointXYZ>, eckit::Configuration)" );
9597 mpi_comm_ = get_mpi_comm (config);
9698 lonlat_ = Field (" lonlat" , array::make_datatype<double >(), array::make_shape (points.size (), 2 ));
9799 vertical_ = Field (" vertical" , array::make_datatype<double >(), array::make_shape (points.size ()));
@@ -105,16 +107,18 @@ PointCloud::PointCloud(const std::vector<PointXYZ>& points, const eckit::Configu
105107}
106108
107109PointCloud::PointCloud (const Field& lonlat, const eckit::Configuration& config): lonlat_(lonlat) {
110+ ATLAS_TRACE (" PointCloud(Field lonlat, eckit::Configuration)" );
108111 mpi_comm_ = get_mpi_comm (config);
109112}
110113
111114PointCloud::PointCloud (const Field& lonlat, const Field& ghost, const eckit::Configuration& config): lonlat_(lonlat), ghost_(ghost) {
115+ ATLAS_TRACE (" PointCloud(Field lonlat, Field ghost, eckit::Configuration)" );
112116 mpi_comm_ = get_mpi_comm (config);
113- setupHaloExchange ();
114- setupGatherScatter ();
117+ setupParallel ();
115118}
116119
117120PointCloud::PointCloud (const FieldSet& flds, const eckit::Configuration& config): lonlat_(flds[" lonlat" ]) {
121+ ATLAS_TRACE (" PointCloud(Fieldset, eckit::Configuration)" );
118122 mpi_comm_ = get_mpi_comm (config);
119123 if (flds.has (" ghost" )) {
120124 ghost_ = flds[" ghost" ];
@@ -129,8 +133,7 @@ PointCloud::PointCloud(const FieldSet& flds, const eckit::Configuration& config)
129133 global_index_ = flds[" global_index" ];
130134 }
131135 if ( ghost_ && remote_index_ && partition_ ) {
132- setupHaloExchange ();
133- setupGatherScatter ();
136+ setupParallel ();
134137 }
135138}
136139
@@ -165,6 +168,8 @@ PointCloud::PointCloud(const Grid& grid, const grid::Distribution& distribution,
165168 auto size_owned = distribution.nb_pts ()[part_];
166169 size_owned_ = size_owned;
167170
171+ size_global_ = grid.size ();
172+
168173 if (halo_radius == 0 . || nb_partitions_ == 1 ) {
169174 idx_t size_halo = size_owned;
170175 ATLAS_ASSERT (size_owned > 0 );
@@ -180,10 +185,13 @@ PointCloud::PointCloud(const Grid& grid, const grid::Distribution& distribution,
180185 array::make_view<int ,1 >(ghost_).assign (0 );
181186 array::make_view<int ,1 >(partition_).assign (part_);
182187
188+ ATLAS_ASSERT (grid.size () == distribution.size ());
189+
183190 idx_t j{0 };
184191 gidx_t g{0 };
185192 for (auto p : grid.lonlat ()) {
186193 if ( distribution.partition (g) == part_ ) {
194+ ATLAS_ASSERT (j < size_halo);
187195 gidx (j) = g+1 ;
188196 ridx (j) = j;
189197 lonlat (j, 0 ) = p.lon ();
@@ -262,13 +270,16 @@ PointCloud::PointCloud(const Grid& grid, const grid::Distribution& distribution,
262270 }
263271
264272 }
265-
266- setupHaloExchange ();
267- setupGatherScatter ();
273+ setupParallel ();
268274}
269275
270276PointCloud::PointCloud (const Grid& grid, const grid::Partitioner& _partitioner, const eckit::Configuration& config):
271- PointCloud (grid, ((_partitioner) ? _partitioner : grid::Partitioner(" equal_regions" , util::Config(" mpi_comm" ,get_mpi_comm(config)))).partition(grid), config) {
277+ PointCloud (
278+ grid,
279+ grid::Distribution{grid, (_partitioner) ? _partitioner :
280+ grid::Partitioner{grid.partitioner () | util::Config (" mpi_comm" ,get_mpi_comm (config))}
281+ },
282+ config) {
272283 ATLAS_TRACE (" PointCloud(grid,partitioner,config)" );
273284}
274285
@@ -286,18 +297,18 @@ const Grid& PointCloud::grid() const {
286297 }
287298
288299 std::vector<PointXY> points;
289- points.reserve (size_global_ );
300+ points.reserve (size_global () );
290301 if (nb_partitions_ == 1 ) {
291302 for (const auto & point : iterate().xy ()) {
292303 points.push_back (point);
293304 }
294305 }
295306 else {
296307 std::vector<int > gidx;
297- gidx.reserve (size_global_ );
308+ gidx.reserve (size_global () );
298309 std::vector<double > x, y;
299- x.reserve (size_global_ );
300- y.reserve (size_global_ );
310+ x.reserve (size_global () );
311+ y.reserve (size_global () );
301312 const auto gidxView = array::make_view<gidx_t , 1 >(global_index_);
302313 const auto ghostView = array::make_view<int , 1 >(ghost_);
303314 int i = 0 ;
@@ -333,7 +344,7 @@ array::ArrayShape PointCloud::config_shape(const eckit::Configuration& config) c
333344 idx_t owner (0 );
334345 config.get (" owner" , owner);
335346 idx_t rank = mpi::comm (mpi_comm ()).rank ();
336- _size = (rank == owner ? size_global_ : 0 );
347+ _size = (rank == owner ? size_global () : 0 );
337348 }
338349 }
339350
@@ -381,6 +392,11 @@ std::string PointCloud::config_name(const eckit::Configuration& config) const {
381392}
382393
383394const parallel::HaloExchange& PointCloud::halo_exchange () const {
395+ if (halo_exchange_) {
396+ return *halo_exchange_;
397+ }
398+ const_cast <PointCloud&>(*this ).setupHaloExchange ();
399+ ATLAS_ASSERT (halo_exchange_);
384400 return *halo_exchange_;
385401}
386402
@@ -428,10 +444,18 @@ void PointCloud::gather(const Field& local, Field& global) const {
428444 gather (local_fields, global_fields);
429445}
430446const parallel::GatherScatter& PointCloud::gather () const {
447+ if (gather_scatter_) {
448+ return *gather_scatter_;
449+ }
450+ const_cast <PointCloud&>(*this ).setupGatherScatter ();
431451 ATLAS_ASSERT (gather_scatter_);
432452 return *gather_scatter_;
433453}
434454const parallel::GatherScatter& PointCloud::scatter () const {
455+ if (gather_scatter_) {
456+ return *gather_scatter_;
457+ }
458+ const_cast <PointCloud&>(*this ).setupGatherScatter ();
435459 ATLAS_ASSERT (gather_scatter_);
436460 return *gather_scatter_;
437461}
@@ -512,6 +536,10 @@ void PointCloud::set_field_metadata(const eckit::Configuration& config, Field& f
512536 if (config.has (" type" )) {
513537 field.metadata ().set (" type" , config.getString (" type" ));
514538 }
539+
540+ if (config.has (" vector_component" )) {
541+ field.metadata ().set (" vector_component" , config.getSubConfiguration (" vector_component" ));
542+ }
515543}
516544
517545Field PointCloud::createField (const eckit::Configuration& options) const {
@@ -618,7 +646,7 @@ void dispatch_adjointHaloExchange(Field& field, const parallel::HaloExchange& ha
618646} // namespace
619647
620648void PointCloud::haloExchange (const FieldSet& fieldset, bool on_device) const {
621- if (halo_exchange_ ) {
649+ if (parallel_ ) {
622650 for (idx_t f = 0 ; f < fieldset.size (); ++f) {
623651 Field& field = const_cast <FieldSet&>(fieldset)[f];
624652 switch (field.rank ()) {
@@ -891,7 +919,7 @@ void PointCloud::create_remote_index() const {
891919 }
892920}
893921
894- void PointCloud::setupHaloExchange () {
922+ void PointCloud::setupParallel () {
895923 ATLAS_TRACE ();
896924 if (ghost_ and partition_ and global_index_ and not remote_index_) {
897925 create_remote_index ();
@@ -1025,16 +1053,23 @@ void PointCloud::setupHaloExchange() {
10251053 ATLAS_ASSERT (remote_index_);
10261054 ATLAS_ASSERT (ghost_.size () == remote_index_.size ());
10271055 ATLAS_ASSERT (ghost_.size () == partition_.size ());
1028-
1029- halo_exchange_.reset (new parallel::HaloExchange ());
1030- halo_exchange_->setup (mpi_comm_,
1031- array::make_view<int , 1 >(partition_).data (),
1032- array::make_view<idx_t , 1 >(remote_index_).data (),
1033- REMOTE_IDX_BASE,
1034- ghost_.size ());
1056+ parallel_ = true ;
1057+ }
1058+
1059+ void PointCloud::setupHaloExchange () {
1060+ ATLAS_TRACE ();
1061+ if (ghost_ and partition_ and remote_index_) {
1062+ halo_exchange_.reset (new parallel::HaloExchange ());
1063+ halo_exchange_->setup (mpi_comm_,
1064+ array::make_view<int , 1 >(partition_).data (),
1065+ array::make_view<idx_t , 1 >(remote_index_).data (),
1066+ REMOTE_IDX_BASE,
1067+ ghost_.size ());
1068+ }
10351069}
10361070
10371071void PointCloud::setupGatherScatter () {
1072+ ATLAS_TRACE ();
10381073 if (ghost_ and partition_ and remote_index_ and global_index_) {
10391074 gather_scatter_.reset (new parallel::GatherScatter ());
10401075 gather_scatter_->setup (mpi_comm_,
@@ -1044,13 +1079,29 @@ void PointCloud::setupGatherScatter() {
10441079 array::make_view<gidx_t , 1 >(global_index_).data (),
10451080 array::make_view<int , 1 >(ghost_).data (),
10461081 ghost_.size ());
1047- size_global_ = gather_scatter_->glb_dof ();
10481082 }
10491083}
10501084
1085+ idx_t PointCloud::size_global () const {
1086+ if (size_global_ == -1 ) {
1087+ if (not parallel_) {
1088+ size_global_ = lonlat ().size ();
1089+ }
1090+ else {
1091+ if ( !gather_scatter_ ) {
1092+ const_cast <PointCloud&>(*this ).setupGatherScatter ();
1093+ }
1094+ if (gather_scatter_) {
1095+ size_global_ = gather_scatter_->glb_dof ();
1096+ }
1097+ }
1098+ }
1099+ return size_global_;
1100+ }
1101+
10511102
10521103void PointCloud::adjointHaloExchange (const FieldSet& fieldset, bool on_device) const {
1053- if (halo_exchange_ ) {
1104+ if (parallel_ ) {
10541105 for (idx_t f = 0 ; f < fieldset.size (); ++f) {
10551106 Field& field = const_cast <FieldSet&>(fieldset)[f];
10561107 switch (field.rank ()) {
0 commit comments