Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 48 additions & 22 deletions ext/OceananigansNCDatasetsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,48 +57,78 @@ const BuoyancyBoussinesqEOSModel = BuoyancyForce{<:BoussinesqSeawaterBuoyancy, g
##### Extend defVar to be able to write fields to NetCDF directly
#####


function squeeze_reduced_dimensions(data, reduced_dims)
# Fill missing indices from the right with 1s
indices = Any[:, :, :]
for i in 1:3
if i ∈ reduced_dims
indices[i] = 1
end
end
return getindex(data, indices...)
end

defVar(ds, name, op::AbstractOperation; kwargs...) = defVar(ds, name, Field(op); kwargs...)
defVar(ds, name, op::Reduction; kwargs...) = defVar(ds, name, Field(op); kwargs...)

function defVar(ds, name, field::AbstractField;
time_dependent=false,
with_halos=false,
dimension_name_generator = trilocation_dim_name,
kwargs...)
field_cpu = on_architecture(CPU(), field) # Need to bring field to CPU in order to write it to NetCDF
field_data = Array{eltype(field)}(field_cpu)
if with_halos
field_data = Array{eltype(field)}(parent(field_cpu))
else
field_data = Array{eltype(field)}(interior(field_cpu))
end
dims = field_dimensions(field, dimension_name_generator)
all_dims = time_dependent ? (dims..., "time") : dims

# Validate that all dimensions exist and match the field
create_field_dimensions!(ds, field, all_dims, dimension_name_generator)
defVar(ds, name, field_data, all_dims; kwargs...)
create_field_dimensions!(ds, field, all_dims, dimension_name_generator; with_halos)

squeezed_field_data = squeeze_reduced_dimensions(field_data, effective_reduced_dimensions(field))
squeezed_reshaped_field_data = time_dependent ? reshape(squeezed_field_data, size(squeezed_field_data)..., 1) : squeezed_field_data
defVar(ds, name, squeezed_reshaped_field_data, all_dims; kwargs...)
end

#####
##### Dimension validation
#####

"""
create_field_dimensions!(ds, field::AbstractField, all_dims, dimension_name_generator)
create_field_dimensions!(ds, field::AbstractField, dim_names, dimension_name_generator)

Creates all dimensions for the given `field` in the NetCDF dataset `ds`. If the dimensions
already exist, they are validated to match the expected dimensions for the given `field`.

Arguments:
- `ds`: NetCDF dataset
- `field`: AbstractField being written
- `all_dims`: Tuple of dimension names to create/validate
- `dim_names`: Tuple of dimension names to create/validate
- `dimension_name_generator`: Function to generate dimension names
"""
function create_field_dimensions!(ds, field::AbstractField, all_dims, dimension_name_generator)
function create_field_dimensions!(ds, field::AbstractField, dim_names, dimension_name_generator; with_halos=false)
dimension_attributes = default_dimension_attributes(field.grid, dimension_name_generator)
spatial_dims = all_dims[1:end-(("time" in all_dims) ? 1 : 0)]
spatial_dim_names = dim_names[1:end-(("time" in dim_names) ? 1 : 0)]

# Main.@infiltrate
# Get spatial dimensions excluding reduced dimensions (i.e. dimensions where `loc isa Nothing``)
reduced_dims = effective_reduced_dimensions(field)

spatial_dims_dict = Dict(dim_name => dim_data for (dim_name, dim_data) in zip(spatial_dims, nodes(field)))
create_spatial_dimensions!(ds, spatial_dims_dict, dimension_attributes; array_type=Array{eltype(field)})
# At the moment, this returns the full nodes even when the field is sliced.
# https://github.com/CliMA/Oceananigans.jl/pull/4814 will fix this in the future.
node_data = nodes(field; with_halos)
Comment on lines +121 to +123
Copy link
Collaborator Author

@tomchor tomchor Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to future self: revisit this after #4814 is merged. It should make this line work out of the box and tests pass

spatial_dim_data = [data for (i, data) in enumerate(node_data) if i ∉ reduced_dims]

# Create dictionary of spatial dimensions and their data
spatial_dim_names_dict = Dict(dim_name => dim_data for (dim_name, dim_data) in zip(spatial_dim_names, spatial_dim_data))
create_spatial_dimensions!(ds, spatial_dim_names_dict, dimension_attributes; array_type=Array{eltype(field)})

# Create time dimension if needed
if "time" in all_dims && "time" ∉ keys(ds.dim)
if "time" in dim_names && "time" ∉ keys(ds.dim)
create_time_dimension!(ds)
end

Expand Down Expand Up @@ -151,7 +181,7 @@ function create_spatial_dimensions!(dataset, dims, attributes_dict; array_type=A
defVar(dataset, dim_name, array_type(dim_array), (dim_name,), attrib=attributes_dict[dim_name]; kwargs...)
else
# Validate existing dimension
if dataset[dim_name] != dim_array
if collect(dataset[dim_name]) != collect(dim_array)
throw(ArgumentError("Dimension '$dim_name' already exists in dataset but is different from expected.\n" *
" Actual: $(dataset[dim_name]) (length=$(length(dataset[dim_name])))\n" *
" Expected: $(dim_array) (length=$(length(dim_array)))"))
Expand Down Expand Up @@ -1236,7 +1266,8 @@ function initialize_nc_file(model,
dimensions,
filepath, # for better error messages
dimension_name_generator,
false) # time_dependent = false
false, # time_dependent = false
with_halos)

save_output!(dataset, output, model, name, array_type)
end
Expand All @@ -1255,7 +1286,8 @@ function initialize_nc_file(model,
dimensions,
filepath, # for better error messages
dimension_name_generator,
true) # time_dependent = true)
true, # time_dependent = true)
with_halos)
end

sync(dataset)
Expand Down Expand Up @@ -1293,7 +1325,7 @@ materialize_output(output::WindowedTimeAverage{<:AbstractField}, model) = output
""" Defines empty variables for 'custom' user-supplied `output`. """
function define_output_variable!(dataset, output, name, array_type,
deflatelevel, attrib, dimensions, filepath,
dimension_name_generator, time_dependent)
dimension_name_generator, time_dependent, with_halos)

if name ∉ keys(dimensions)
msg = string("dimensions[$name] for output $name=$(typeof(output)) into $filepath" *
Expand All @@ -1311,15 +1343,9 @@ end
""" Defines empty field variable. """
function define_output_variable!(dataset, output::AbstractField, name, array_type,
deflatelevel, attrib, dimensions, filepath,
dimension_name_generator, time_dependent)

dims = field_dimensions(output, dimension_name_generator)
FT = eltype(array_type)

all_dims = time_dependent ? (dims..., "time") : dims

defVar(dataset, name, FT, all_dims; deflatelevel, attrib)
dimension_name_generator, time_dependent, with_halos)

defVar(dataset, name, output; time_dependent, with_halos, deflatelevel, attrib)
return nothing
end

Expand Down