@@ -744,29 +744,6 @@ function _reset_interpolated_values!(remapper::Remapper)
744
744
fill! (remapper. _interpolated_values, 0 )
745
745
end
746
746
747
- """
748
- _collect_and_return_interpolated_values!(remapper::Remapper,
749
- num_fields::Int)
750
-
751
- Perform an MPI call to aggregate the interpolated points from all the MPI processes and save
752
- the result in the local state of the `remapper`. Only the root process will return the
753
- interpolated data.
754
-
755
- `_collect_and_return_interpolated_values!` is type-unstable and allocates new return arrays.
756
-
757
- `num_fields` is the number of fields that have been interpolated in this batch.
758
- """
759
- function _collect_and_return_interpolated_values! (
760
- remapper:: Remapper ,
761
- num_fields:: Int ,
762
- )
763
- return ClimaComms. reduce (
764
- remapper. comms_ctx,
765
- remapper. _interpolated_values[remapper. colons... , 1 : num_fields],
766
- + ,
767
- )
768
- end
769
-
770
747
function _collect_interpolated_values! (
771
748
dest,
772
749
remapper:: Remapper ,
@@ -777,38 +754,26 @@ function _collect_interpolated_values!(
777
754
if only_one_field
778
755
ClimaComms. reduce! (
779
756
remapper. comms_ctx,
780
- remapper. _interpolated_values[ remapper. colons... , begin ] ,
757
+ view ( remapper. _interpolated_values, remapper. colons... , 1 ) ,
781
758
dest,
782
759
+ ,
783
760
)
784
- return nothing
761
+ else
762
+ num_fields = 1 + index_field_end - index_field_begin
763
+ ClimaComms. reduce! (
764
+ remapper. comms_ctx,
765
+ view (
766
+ remapper. _interpolated_values,
767
+ remapper. colons... ,
768
+ 1 : num_fields,
769
+ ),
770
+ view (dest, remapper. colons... , index_field_begin: index_field_end),
771
+ + ,
772
+ )
785
773
end
786
-
787
- num_fields = 1 + index_field_end - index_field_begin
788
-
789
- ClimaComms. reduce! (
790
- remapper. comms_ctx,
791
- view (remapper. _interpolated_values, remapper. colons... , 1 : num_fields),
792
- view (dest, remapper. colons... , index_field_begin: index_field_end),
793
- + ,
794
- )
795
-
796
774
return nothing
797
775
end
798
776
799
- """
800
- batched_ranges(num_fields, buffer_length)
801
-
802
- Partition the indices from 1 to num_fields in such a way that no range is larger than
803
- buffer_length.
804
- """
805
- function batched_ranges (num_fields, buffer_length)
806
- return [
807
- (i * buffer_length + 1 ): (min ((i + 1 ) * buffer_length, num_fields)) for
808
- i in 0 : (div ((num_fields - 1 ), buffer_length))
809
- ]
810
- end
811
-
812
777
"""
813
778
interpolate(remapper::Remapper, fields)
814
779
interpolate!(dest, remapper::Remapper, fields)
@@ -860,58 +825,21 @@ int12 = interpolate(remapper, [field1, field2])
860
825
```
861
826
"""
862
827
function interpolate (remapper:: Remapper , fields)
863
-
828
+ ArrayType = ClimaComms. array_type (remapper. space)
829
+ FT = Spaces. undertype (remapper. space)
864
830
only_one_field = fields isa Fields. Field
865
- if only_one_field
866
- fields = [fields]
867
- end
868
831
869
- for field in fields
870
- axes (field) == remapper. space ||
871
- error (" Field is defined on a different space than remapper" )
872
- end
832
+ interpolated_values_dim... , _buffer_length =
833
+ size (remapper. _interpolated_values)
873
834
874
- isa_vertical_space = remapper. space isa Spaces. FiniteDifferenceSpace
875
-
876
- index_field_begin, index_field_end =
877
- 1 , min (length (fields), remapper. buffer_length)
878
-
879
- # Partition the indices in such a way that nothing is larger than
880
- # buffer_length
881
- index_ranges = batched_ranges (length (fields), remapper. buffer_length)
835
+ allocate_extra = only_one_field ? () : (length (fields),)
836
+ dest = ArrayType (zeros (FT, interpolated_values_dim... , allocate_extra... ))
882
837
883
- cat_fn = (l... ) -> cat (l... , dims = length (remapper. colons) + 1 )
884
-
885
- interpolated_values = mapreduce (cat_fn, index_ranges) do range
886
- num_fields = length (range)
887
-
888
- # Reset interpolated_values. This is needed because we collect distributed results
889
- # with a + reduction.
890
- _reset_interpolated_values! (remapper)
891
- # Perform the interpolations (horizontal and vertical)
892
- _set_interpolated_values! (
893
- remapper,
894
- view (fields, index_field_begin: index_field_end),
895
- )
896
-
897
- if ! isa_vertical_space
898
- # For spaces with an horizontal component, reshape the output so that it is a nice grid.
899
- _apply_mpi_bitmask! (remapper, num_fields)
900
- else
901
- # For purely vertical spaces, just move to _interpolated_values
902
- remapper. _interpolated_values .= remapper. _local_interpolated_values
903
- end
904
-
905
- # Finally, we have to send all the _interpolated_values to root and sum them up to
906
- # obtain the final answer. Only the root will contain something useful.
907
- return _collect_and_return_interpolated_values! (remapper, num_fields)
908
- end
909
-
910
- # Non-root processes
911
- isnothing (interpolated_values) && return nothing
912
-
913
- return only_one_field ? interpolated_values[remapper. colons... , begin ] :
914
- interpolated_values
838
+ # interpolate! has an MPI call, so it is important to return after it is
839
+ # called, not before!
840
+ interpolate! (dest, remapper, fields)
841
+ ClimaComms. iamroot (remapper. comms_ctx) || return nothing
842
+ return dest
915
843
end
916
844
917
845
# dest has to be allowed to be nothing because interpolation happens only on the root
@@ -927,6 +855,11 @@ function interpolate!(
927
855
end
928
856
isa_vertical_space = remapper. space isa Spaces. FiniteDifferenceSpace
929
857
858
+ for field in fields
859
+ axes (field) == remapper. space ||
860
+ error (" Field is defined on a different space than remapper" )
861
+ end
862
+
930
863
if ! isnothing (dest)
931
864
# !isnothing(dest) means that this is the root process, in this case, the size have
932
865
# to match (ignoring the buffer_length)
0 commit comments