diff --git a/ext/OceananigansNCDatasetsExt.jl b/ext/OceananigansNCDatasetsExt.jl index 156c3b17cb..3847febea3 100644 --- a/ext/OceananigansNCDatasetsExt.jl +++ b/ext/OceananigansNCDatasetsExt.jl @@ -57,21 +57,41 @@ const BuoyancyBoussinesqEOSModel = BuoyancyForce{<:BoussinesqSeawaterBuoyancy, g ##### Extend defVar to be able to write fields to NetCDF directly ##### + +function squeeze_reduced_dimensions(data, reduced_dims) + # Fill missing indices from the right with 1s + indices = Any[:, :, :] + for i in 1:3 + if i ∈ reduced_dims + indices[i] = 1 + end + end + return getindex(data, indices...) +end + defVar(ds, name, op::AbstractOperation; kwargs...) = defVar(ds, name, Field(op); kwargs...) defVar(ds, name, op::Reduction; kwargs...) = defVar(ds, name, Field(op); kwargs...) function defVar(ds, name, field::AbstractField; time_dependent=false, + with_halos=false, dimension_name_generator = trilocation_dim_name, kwargs...) field_cpu = on_architecture(CPU(), field) # Need to bring field to CPU in order to write it to NetCDF - field_data = Array{eltype(field)}(field_cpu) + if with_halos + field_data = Array{eltype(field)}(parent(field_cpu)) + else + field_data = Array{eltype(field)}(interior(field_cpu)) + end dims = field_dimensions(field, dimension_name_generator) all_dims = time_dependent ? (dims..., "time") : dims # Validate that all dimensions exist and match the field - create_field_dimensions!(ds, field, all_dims, dimension_name_generator) - defVar(ds, name, field_data, all_dims; kwargs...) + create_field_dimensions!(ds, field, all_dims, dimension_name_generator; with_halos) + + squeezed_field_data = squeeze_reduced_dimensions(field_data, effective_reduced_dimensions(field)) + squeezed_reshaped_field_data = time_dependent ? reshape(squeezed_field_data, size(squeezed_field_data)..., 1) : squeezed_field_data + defVar(ds, name, squeezed_reshaped_field_data, all_dims; kwargs...) end ##### @@ -79,7 +99,7 @@ end ##### """ - create_field_dimensions!(ds, field::AbstractField, all_dims, dimension_name_generator) + create_field_dimensions!(ds, field::AbstractField, dim_names, dimension_name_generator) Creates all dimensions for the given `field` in the NetCDF dataset `ds`. If the dimensions already exist, they are validated to match the expected dimensions for the given `field`. @@ -87,18 +107,28 @@ already exist, they are validated to match the expected dimensions for the given Arguments: - `ds`: NetCDF dataset - `field`: AbstractField being written -- `all_dims`: Tuple of dimension names to create/validate +- `dim_names`: Tuple of dimension names to create/validate - `dimension_name_generator`: Function to generate dimension names """ -function create_field_dimensions!(ds, field::AbstractField, all_dims, dimension_name_generator) +function create_field_dimensions!(ds, field::AbstractField, dim_names, dimension_name_generator; with_halos=false) dimension_attributes = default_dimension_attributes(field.grid, dimension_name_generator) - spatial_dims = all_dims[1:end-(("time" in all_dims) ? 1 : 0)] + spatial_dim_names = dim_names[1:end-(("time" in dim_names) ? 1 : 0)] + + # Main.@infiltrate + # Get spatial dimensions excluding reduced dimensions (i.e. dimensions where `loc isa Nothing``) + reduced_dims = effective_reduced_dimensions(field) - spatial_dims_dict = Dict(dim_name => dim_data for (dim_name, dim_data) in zip(spatial_dims, nodes(field))) - create_spatial_dimensions!(ds, spatial_dims_dict, dimension_attributes; array_type=Array{eltype(field)}) + # At the moment, this returns the full nodes even when the field is sliced. + # https://github.com/CliMA/Oceananigans.jl/pull/4814 will fix this in the future. + node_data = nodes(field; with_halos) + spatial_dim_data = [data for (i, data) in enumerate(node_data) if i ∉ reduced_dims] + + # Create dictionary of spatial dimensions and their data + spatial_dim_names_dict = Dict(dim_name => dim_data for (dim_name, dim_data) in zip(spatial_dim_names, spatial_dim_data)) + create_spatial_dimensions!(ds, spatial_dim_names_dict, dimension_attributes; array_type=Array{eltype(field)}) # Create time dimension if needed - if "time" in all_dims && "time" ∉ keys(ds.dim) + if "time" in dim_names && "time" ∉ keys(ds.dim) create_time_dimension!(ds) end @@ -151,7 +181,7 @@ function create_spatial_dimensions!(dataset, dims, attributes_dict; array_type=A defVar(dataset, dim_name, array_type(dim_array), (dim_name,), attrib=attributes_dict[dim_name]; kwargs...) else # Validate existing dimension - if dataset[dim_name] != dim_array + if collect(dataset[dim_name]) != collect(dim_array) throw(ArgumentError("Dimension '$dim_name' already exists in dataset but is different from expected.\n" * " Actual: $(dataset[dim_name]) (length=$(length(dataset[dim_name])))\n" * " Expected: $(dim_array) (length=$(length(dim_array)))")) @@ -1236,7 +1266,8 @@ function initialize_nc_file(model, dimensions, filepath, # for better error messages dimension_name_generator, - false) # time_dependent = false + false, # time_dependent = false + with_halos) save_output!(dataset, output, model, name, array_type) end @@ -1255,7 +1286,8 @@ function initialize_nc_file(model, dimensions, filepath, # for better error messages dimension_name_generator, - true) # time_dependent = true) + true, # time_dependent = true) + with_halos) end sync(dataset) @@ -1293,7 +1325,7 @@ materialize_output(output::WindowedTimeAverage{<:AbstractField}, model) = output """ Defines empty variables for 'custom' user-supplied `output`. """ function define_output_variable!(dataset, output, name, array_type, deflatelevel, attrib, dimensions, filepath, - dimension_name_generator, time_dependent) + dimension_name_generator, time_dependent, with_halos) if name ∉ keys(dimensions) msg = string("dimensions[$name] for output $name=$(typeof(output)) into $filepath" * @@ -1311,15 +1343,9 @@ end """ Defines empty field variable. """ function define_output_variable!(dataset, output::AbstractField, name, array_type, deflatelevel, attrib, dimensions, filepath, - dimension_name_generator, time_dependent) - - dims = field_dimensions(output, dimension_name_generator) - FT = eltype(array_type) - - all_dims = time_dependent ? (dims..., "time") : dims - - defVar(dataset, name, FT, all_dims; deflatelevel, attrib) + dimension_name_generator, time_dependent, with_halos) + defVar(dataset, name, output; time_dependent, with_halos, deflatelevel, attrib) return nothing end