Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion ext/DataHandlingExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,6 @@ function DataHandling.regridded_snapshot(
end

regridder_type = nameof(typeof(data_handler.regridder))
regrid_args = ()

# Check if the regridded field at this date is already in the cache
return get!(data_handler._cached_regridded_fields, date) do
Expand Down
39 changes: 33 additions & 6 deletions ext/InterpolationsRegridderExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ import ClimaCore.Fields: Adapt
import ClimaCore.Fields: ClimaComms

import ClimaUtilities.Regridders
import ClimaUtilities.Utils: unwrap

struct InterpolationsRegridder{
SPACE <: ClimaCore.Spaces.AbstractSpace,
FIELD <: ClimaCore.Fields.Field,
BC,
DT <: Tuple,
DI <: Tuple,
N,
} <: Regridders.AbstractRegridder

"""ClimaCore.Space where the output Field will be defined"""
Expand All @@ -27,6 +30,12 @@ struct InterpolationsRegridder{
"""Tuple of booleans signifying if the dimension is monotonically increasing. True for
dimensions that are monotonically increasing, false for dimensions that are monotonically decreasing."""
dim_increasing::DT

"""Tuple of integers indicating which dimensions to reverse in data"""
decreasing_indices::DI

"Number of dimensions of the target space"
num_space_dims::N
Comment on lines +37 to +38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's eliminate this, and instead write a function that computes this on the fly:

num_space_dims(regridder::InterpolationsRegridder) =
    num_space_dims(regridder.target_space)
num_space_dims(space::Spaces.AbstractSpace) =
    num_space_dims(eltype(Fields.coordinate_field(space)))
num_space_dims(::Type{S}) where {S <: Geometry.LatLongPoint} = 2
num_space_dims(::Type{S}) where {S <: Geometry.LatLongZPoint} = 3
num_space_dims(::Type{S}) where {S <: Geometry.XYZPoint} = 3

It'd be nice if we could do the same for the decreasing_indices, but I don't think it's possible since it depends on values of the input data.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also allow us to eliminate unwrap and its unit tests. We could add unit tests for this, of course, instead.

end

# Note, we swap Lat and Long! This is because according to the CF conventions longitude
Expand Down Expand Up @@ -74,23 +83,32 @@ function Regridders.InterpolationsRegridder(
isnothing(extrapolation_bc) &&
(extrapolation_bc = (Intp.Periodic(), Intp.Flat()))
isnothing(dim_increasing) && (dim_increasing = (true, true))
num_space_dims = Val(2)
elseif eltype(coordinates) <: ClimaCore.Geometry.LatLongZPoint
isnothing(extrapolation_bc) &&
(extrapolation_bc = (Intp.Periodic(), Intp.Flat(), Intp.Throw()))
isnothing(dim_increasing) && (dim_increasing = (true, true, true))
num_space_dims = Val(3)
elseif eltype(coordinates) <: ClimaCore.Geometry.XYZPoint
isnothing(extrapolation_bc) &&
(extrapolation_bc = (Intp.Flat(), Intp.Flat(), Intp.Throw()))
isnothing(dim_increasing) && (dim_increasing = (true, true, true))
num_space_dims = Val(3)
else
error("Only lat-long, lat-long-z, and x-y-z spaces are supported")
end

decreasing_indices =
!all(dim_increasing) ?
Tuple([i for (i, d) in enumerate(dim_increasing) if !d]) : ()

return InterpolationsRegridder(
target_space,
coordinates,
extrapolation_bc,
dim_increasing,
decreasing_indices,
num_space_dims,
)
end

Expand All @@ -102,17 +120,26 @@ Regrid the given data as defined on the given dimensions to the `target_space` i
This function is allocating.
"""
function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions)
num_data_dims = ndims(data)
num_dims = length(dimensions)
num_space_dims = unwrap(regridder.num_space_dims)
((num_space_dims != num_data_dims) || (num_space_dims != num_dims)) &&
error(
"Number of dimensions of data ($num_data_dims) does not match the dimension of the space ($num_space_dims) or the number of dimensions passed in ($num_dims)",
)

FT = ClimaCore.Spaces.undertype(regridder.target_space)
dimensions_FT = map(dimensions, regridder.dim_increasing) do dim, increasing
!increasing ? reverse(FT.(dim)) : FT.(dim)
end
dimensions_FT = ntuple(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add a check to check that the number of dimensions match between the given data and the space?

i ->
!regridder.dim_increasing[i] ? reverse(FT.(dimensions[i])) :
FT.(dimensions[i]),
regridder.num_space_dims,
)
Comment on lines +133 to +137
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a more elegant way that we can write this?


data_transformed = data
# Reverse the data if needed. This allocates, so ideally it should be done in preprocessing
if !all(regridder.dim_increasing)
decreasing_indices =
Tuple([i for (i, d) in enumerate(regridder.dim_increasing) if !d])
data_transformed = reverse(data, dims = decreasing_indices)
data_transformed = reverse(data, dims = regridder.decreasing_indices)
end
# Make a linear spline
itp = Intp.extrapolate(
Expand Down
7 changes: 7 additions & 0 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,4 +412,11 @@ function sort_by_creation_time(files)
return sort(files, by = x -> stat(x).ctime)
end

"""
unwrap(x::Val{N})

Unwrap the value in `Val{N}` to get `N`.
"""
unwrap(x::Val{N}) where {N} = N

end
7 changes: 2 additions & 5 deletions test/data_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,10 @@ ClimaComms.init(context)
target_space;
regridder_type = :InterpolationsRegridder,
file_reader_kwargs = (; preprocess_func = (data) -> 0.0 * data),
regridder_kwargs = (;
extrapolation_bc = (Intp.Flat(), Intp.Flat(), Intp.Flat())
),
regridder_kwargs = (; extrapolation_bc = (Intp.Flat(), Intp.Flat())),
)

@test data_handler.regridder.extrapolation_bc ==
(Intp.Flat(), Intp.Flat(), Intp.Flat())
@test data_handler.regridder.extrapolation_bc == (Intp.Flat(), Intp.Flat())
field = DataHandling.regridded_snapshot(data_handler)
@test extrema(field) == (0.0, 0.0)
end
Expand Down
15 changes: 15 additions & 0 deletions test/regridders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,21 @@ end
@test regridded_lat_reversed == regridded_lat
@test regridded_lon_reversed == regridded_lon
@test regridded_z_reversed == regridded_z

# Error handling
data = [1.0, 2.0, 3.0, 4.0]
@test_throws ErrorException Regridders.regrid(
reg_hv_reversed,
data,
dimensions3D_reversed,
)

dimensions = ([1.0, 2.0, 3.0],)
@test_throws ErrorException Regridders.regrid(
reg_hv_reversed,
data_z3D_reversed,
dimensions,
)
end

@testset "InterpolationsRegridder" begin
Expand Down
11 changes: 10 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import ClimaUtilities.Utils:
bounding_dates,
period_to_seconds_float,
unique_periods,
sort_by_creation_time
sort_by_creation_time,
unwrap

@testset "searchsortednearest" begin
A = 10 * collect(range(1, 10))
Expand Down Expand Up @@ -170,4 +171,12 @@ end
sorted_files = sort_by_creation_time(files)
@test sorted_files == [files[2], files[3], files[1]]
end

@testset "unwrap" begin
x = Val(2)
y = Val(3)
@test unwrap(x) == 2
@test unwrap(y) == 3
end

end
Loading