diff --git a/ext/DataHandlingExt.jl b/ext/DataHandlingExt.jl index b05b3d0b..e0869cb0 100644 --- a/ext/DataHandlingExt.jl +++ b/ext/DataHandlingExt.jl @@ -581,7 +581,6 @@ function DataHandling.regridded_snapshot( end regridder_type = nameof(typeof(data_handler.regridder)) - regrid_args = () # Check if the regridded field at this date is already in the cache return get!(data_handler._cached_regridded_fields, date) do diff --git a/ext/InterpolationsRegridderExt.jl b/ext/InterpolationsRegridderExt.jl index c3e4bd59..74cd5aee 100644 --- a/ext/InterpolationsRegridderExt.jl +++ b/ext/InterpolationsRegridderExt.jl @@ -7,12 +7,15 @@ import ClimaCore.Fields: Adapt import ClimaCore.Fields: ClimaComms import ClimaUtilities.Regridders +import ClimaUtilities.Utils: unwrap struct InterpolationsRegridder{ SPACE <: ClimaCore.Spaces.AbstractSpace, FIELD <: ClimaCore.Fields.Field, BC, DT <: Tuple, + DI <: Tuple, + N, } <: Regridders.AbstractRegridder """ClimaCore.Space where the output Field will be defined""" @@ -27,6 +30,12 @@ struct InterpolationsRegridder{ """Tuple of booleans signifying if the dimension is monotonically increasing. True for dimensions that are monotonically increasing, false for dimensions that are monotonically decreasing.""" dim_increasing::DT + + """Tuple of integers indicating which dimensions to reverse in data""" + decreasing_indices::DI + + "Number of dimensions of the target space" + num_space_dims::N end # Note, we swap Lat and Long! This is because according to the CF conventions longitude @@ -74,23 +83,32 @@ function Regridders.InterpolationsRegridder( isnothing(extrapolation_bc) && (extrapolation_bc = (Intp.Periodic(), Intp.Flat())) isnothing(dim_increasing) && (dim_increasing = (true, true)) + num_space_dims = Val(2) elseif eltype(coordinates) <: ClimaCore.Geometry.LatLongZPoint isnothing(extrapolation_bc) && (extrapolation_bc = (Intp.Periodic(), Intp.Flat(), Intp.Throw())) isnothing(dim_increasing) && (dim_increasing = (true, true, true)) + num_space_dims = Val(3) elseif eltype(coordinates) <: ClimaCore.Geometry.XYZPoint isnothing(extrapolation_bc) && (extrapolation_bc = (Intp.Flat(), Intp.Flat(), Intp.Throw())) isnothing(dim_increasing) && (dim_increasing = (true, true, true)) + num_space_dims = Val(3) else error("Only lat-long, lat-long-z, and x-y-z spaces are supported") end + decreasing_indices = + !all(dim_increasing) ? + Tuple([i for (i, d) in enumerate(dim_increasing) if !d]) : () + return InterpolationsRegridder( target_space, coordinates, extrapolation_bc, dim_increasing, + decreasing_indices, + num_space_dims, ) end @@ -102,17 +120,26 @@ Regrid the given data as defined on the given dimensions to the `target_space` i This function is allocating. """ function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions) + num_data_dims = ndims(data) + num_dims = length(dimensions) + num_space_dims = unwrap(regridder.num_space_dims) + ((num_space_dims != num_data_dims) || (num_space_dims != num_dims)) && + error( + "Number of dimensions of data ($num_data_dims) does not match the dimension of the space ($num_space_dims) or the number of dimensions passed in ($num_dims)", + ) + FT = ClimaCore.Spaces.undertype(regridder.target_space) - dimensions_FT = map(dimensions, regridder.dim_increasing) do dim, increasing - !increasing ? reverse(FT.(dim)) : FT.(dim) - end + dimensions_FT = ntuple( + i -> + !regridder.dim_increasing[i] ? reverse(FT.(dimensions[i])) : + FT.(dimensions[i]), + regridder.num_space_dims, + ) data_transformed = data # Reverse the data if needed. This allocates, so ideally it should be done in preprocessing if !all(regridder.dim_increasing) - decreasing_indices = - Tuple([i for (i, d) in enumerate(regridder.dim_increasing) if !d]) - data_transformed = reverse(data, dims = decreasing_indices) + data_transformed = reverse(data, dims = regridder.decreasing_indices) end # Make a linear spline itp = Intp.extrapolate( diff --git a/src/Utils.jl b/src/Utils.jl index 85ad67a1..fd06a6fc 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -412,4 +412,11 @@ function sort_by_creation_time(files) return sort(files, by = x -> stat(x).ctime) end +""" + unwrap(x::Val{N}) + +Unwrap the value in `Val{N}` to get `N`. +""" +unwrap(x::Val{N}) where {N} = N + end diff --git a/test/data_handling.jl b/test/data_handling.jl index ad288d1d..bcb98155 100644 --- a/test/data_handling.jl +++ b/test/data_handling.jl @@ -144,13 +144,10 @@ ClimaComms.init(context) target_space; regridder_type = :InterpolationsRegridder, file_reader_kwargs = (; preprocess_func = (data) -> 0.0 * data), - regridder_kwargs = (; - extrapolation_bc = (Intp.Flat(), Intp.Flat(), Intp.Flat()) - ), + regridder_kwargs = (; extrapolation_bc = (Intp.Flat(), Intp.Flat())), ) - @test data_handler.regridder.extrapolation_bc == - (Intp.Flat(), Intp.Flat(), Intp.Flat()) + @test data_handler.regridder.extrapolation_bc == (Intp.Flat(), Intp.Flat()) field = DataHandling.regridded_snapshot(data_handler) @test extrema(field) == (0.0, 0.0) end diff --git a/test/regridders.jl b/test/regridders.jl index e6bd42cc..55ee3c5b 100644 --- a/test/regridders.jl +++ b/test/regridders.jl @@ -129,6 +129,21 @@ end @test regridded_lat_reversed == regridded_lat @test regridded_lon_reversed == regridded_lon @test regridded_z_reversed == regridded_z + + # Error handling + data = [1.0, 2.0, 3.0, 4.0] + @test_throws ErrorException Regridders.regrid( + reg_hv_reversed, + data, + dimensions3D_reversed, + ) + + dimensions = ([1.0, 2.0, 3.0],) + @test_throws ErrorException Regridders.regrid( + reg_hv_reversed, + data_z3D_reversed, + dimensions, + ) end @testset "InterpolationsRegridder" begin diff --git a/test/utils.jl b/test/utils.jl index 26aab8d9..abc44d54 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -12,7 +12,8 @@ import ClimaUtilities.Utils: bounding_dates, period_to_seconds_float, unique_periods, - sort_by_creation_time + sort_by_creation_time, + unwrap @testset "searchsortednearest" begin A = 10 * collect(range(1, 10)) @@ -170,4 +171,12 @@ end sorted_files = sort_by_creation_time(files) @test sorted_files == [files[2], files[3], files[1]] end + + @testset "unwrap" begin + x = Val(2) + y = Val(3) + @test unwrap(x) == 2 + @test unwrap(y) == 3 + end + end