Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ Finally, note that the nodes of the staggered mesh coincide with the cell interf
znodes(grid, Center())

# output
4-element view(::Vector{Float64}, 2:5) with eltype Float64:
4-element Vector{Float64}:
0.05
0.2
0.44999999999999996
Expand Down
3 changes: 3 additions & 0 deletions ext/OceananigansReactantExt/OceananigansReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,9 @@ end
return c
end

Base.getindex(array::OffsetVector{T, <:Reactant.AbstractConcreteArray{T, 1}}, ::Colon) where T = array


# These are additional modules that may need to be Reactantified in the future:
#
# include("Utils.jl")
Expand Down
6 changes: 6 additions & 0 deletions src/Fields/field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -859,3 +859,9 @@ function fill_halo_regions!(field::Field, positional_args...; kwargs...)

return nothing
end

#####
##### nodes
#####

nodes(f::Field; kwargs...) = nodes(f.grid, instantiated_location(f)...; indices=indices(f), kwargs...)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
nodes(f::Field; kwargs...) = nodes(f.grid, instantiated_location(f)...; indices=indices(f), kwargs...)
nodes(f::Field; kwargs...) = nodes(f.grid, instantiated_location(f)...; indices=indices(f), kwargs...)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm missing something; What's different here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the extra line at the bottom of the file

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

he added whitespace below

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, okay. Are we trying to always leave an extra line at the bottom of files?

28 changes: 18 additions & 10 deletions src/Grids/latitude_longitude_grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -602,10 +602,10 @@ rname(::LLG) = :z
@inline xnode(i, j, k, grid::LLG, ℓx, ℓy, ℓz) = xnode(i, j, grid, ℓx, ℓy)
@inline ynode(i, j, k, grid::LLG, ℓx, ℓy, ℓz) = ynode(j, grid, ℓy)

function nodes(grid::LLG, ℓx, ℓy, ℓz; reshape=false, with_halos=false)
λ = λnodes(grid, ℓx, ℓy, ℓz; with_halos)
φ = φnodes(grid, ℓx, ℓy, ℓz; with_halos)
z = znodes(grid, ℓx, ℓy, ℓz; with_halos)
function nodes(grid::LLG, ℓx, ℓy, ℓz; reshape=false, with_halos=false, indices=(Colon(), Colon(), Colon()))
λ = λnodes(grid, ℓx, ℓy, ℓz; with_halos, indices = indices[1])
φ = φnodes(grid, ℓx, ℓy, ℓz; with_halos, indices = indices[2])
z = znodes(grid, ℓx, ℓy, ℓz; with_halos, indices = indices[3])

if reshape
# Here we have to deal with the fact that Flat directions may have
Expand Down Expand Up @@ -647,15 +647,23 @@ end
end

# Convenience
@inline λnodes(grid::LLG, ℓx, ℓy, ℓz; with_halos=false) = λnodes(grid, ℓx; with_halos)
@inline φnodes(grid::LLG, ℓx, ℓy, ℓz; with_halos=false) = φnodes(grid, ℓy; with_halos)
@inline λnodes(grid::LLG, ℓx, ℓy, ℓz; with_halos=false, indices=Colon()) = λnodes(grid, ℓx; with_halos, indices)
@inline φnodes(grid::LLG, ℓx, ℓy, ℓz; with_halos=false, indices=Colon()) = φnodes(grid, ℓy; with_halos, indices)
@inline xnodes(grid::LLG, ℓx, ℓy, ℓz; with_halos=false) = xnodes(grid, ℓx, ℓy; with_halos)
@inline ynodes(grid::LLG, ℓx, ℓy, ℓz; with_halos=false) = ynodes(grid, ℓy; with_halos)

@inline λnodes(grid::LLG, ℓx::F; with_halos=false) = _property(grid.λᶠᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos)
@inline λnodes(grid::LLG, ℓx::C; with_halos=false) = _property(grid.λᶜᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos)
@inline φnodes(grid::LLG, ℓy::F; with_halos=false) = _property(grid.φᵃᶠᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos)
@inline φnodes(grid::LLG, ℓy::C; with_halos=false) = _property(grid.φᵃᶜᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos)
@inline λnodes(grid::LLG, ℓx::F; with_halos=false, indices=Colon()) = getindex(_property(grid.λᶠᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos), indices)
Copy link
Collaborator

@simone-silvestri simone-silvestri Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this will not work on GPU right? Probably it will return a scalar indexing issue

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think that's why GPU tests are failing. I'll try to come up with an alternative soon

@inline λnodes(grid::LLG, ℓx::C; with_halos=false, indices=Colon()) = getindex(_property(grid.λᶜᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos), indices)
@inline φnodes(grid::LLG, ℓy::F; with_halos=false, indices=Colon()) = getindex(_property(grid.φᵃᶠᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos), indices)
@inline φnodes(grid::LLG, ℓy::C; with_halos=false, indices=Colon()) = getindex(_property(grid.φᵃᶜᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos), indices)

# Flat topologies
XFlatLLG = LatitudeLongitudeGrid{<:Any, Flat}
YFlatLLG = LatitudeLongitudeGrid{<:Any, <:Any, Flat}
@inline λnodes(grid::XFlatLLG, ℓx::F; with_halos=false, indices=Colon()) = _property(grid.λᶜᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos)
@inline λnodes(grid::XFlatLLG, ℓx::C; with_halos=false, indices=Colon()) = _property(grid.λᶜᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos)
@inline φnodes(grid::YFlatLLG, ℓy::F; with_halos=false, indices=Colon()) = _property(grid.φᵃᶠᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos)
@inline φnodes(grid::YFlatLLG, ℓy::C; with_halos=false, indices=Colon()) = _property(grid.φᵃᶜᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos)

# Generalized coordinates
@inline ξnodes(grid::LLG, ℓx; kwargs...) = λnodes(grid, ℓx; kwargs...)
Expand Down
66 changes: 33 additions & 33 deletions src/Grids/orthogonal_spherical_shell_grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,12 @@ end

"""
OrthogonalSphericalShellGrid(arch = CPU(), FT = Oceananigans.defaults.FloatType;
size,
z,
radius = R_Earth,
conformal_mapping = nothing,
halo = (3, 3, 3),
topology = (Bounded, Bounded, Bounded))
size,
z,
radius = R_Earth,
conformal_mapping = nothing,
halo = (3, 3, 3),
topology = (Bounded, Bounded, Bounded))

Return an OrthogonalSphericalShellGrid with empty horizontal metrics.
"""
Expand Down Expand Up @@ -528,10 +528,10 @@ function Base.show(io::IO, grid::OrthogonalSphericalShellGrid, withsummary=true)
"└── ", z_summary)
end

function nodes(grid::OSSG, ℓx, ℓy, ℓz; reshape=false, with_halos=false)
λ = λnodes(grid, ℓx, ℓy, ℓz; with_halos)
φ = φnodes(grid, ℓx, ℓy, ℓz; with_halos)
z = znodes(grid, ℓx, ℓy, ℓz; with_halos)
function nodes(grid::OSSG, ℓx, ℓy, ℓz; reshape=false, with_halos=false, indices=(Colon(), Colon(), Colon()))
λ = λnodes(grid, ℓx, ℓy, ℓz; with_halos, indices = indices[1:2])
φ = φnodes(grid, ℓx, ℓy, ℓz; with_halos, indices = indices[1:2])
z = znodes(grid, ℓx, ℓy, ℓz; with_halos, indices = indices[3])

if reshape
# λ and φ are 2D arrays
Expand All @@ -544,40 +544,40 @@ function nodes(grid::OSSG, ℓx, ℓy, ℓz; reshape=false, with_halos=false)
return (λ, φ, z)
end

@inline λnodes(grid::OSSG, ℓx::F, ℓy::F; with_halos=false) =
_property(grid.λᶠᶠᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos)
@inline λnodes(grid::OSSG, ℓx::F, ℓy::F; with_halos=false, indices=(Colon(), Colon())) =
getindex(_property(grid.λᶠᶠᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos), indices...)

@inline λnodes(grid::OSSG, ℓx::F, ℓy::C; with_halos=false) =
_property(grid.λᶠᶜᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos)
@inline λnodes(grid::OSSG, ℓx::F, ℓy::C; with_halos=false, indices=(Colon(), Colon())) =
getindex(_property(grid.λᶠᶜᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos), indices...)

@inline λnodes(grid::OSSG, ℓx::C, ℓy::F; with_halos=false) =
_property(grid.λᶜᶠᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos)
@inline λnodes(grid::OSSG, ℓx::C, ℓy::F; with_halos=false, indices=(Colon(), Colon())) =
getindex(_property(grid.λᶜᶠᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos), indices...)

@inline λnodes(grid::OSSG, ℓx::C, ℓy::C; with_halos=false) =
_property(grid.λᶜᶜᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos)
@inline λnodes(grid::OSSG, ℓx::C, ℓy::C; with_halos=false, indices=(Colon(), Colon())) =
getindex(_property(grid.λᶜᶜᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos), indices...)

@inline φnodes(grid::OSSG, ℓx::F, ℓy::F; with_halos=false) =
_property(grid.φᶠᶠᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos)
@inline φnodes(grid::OSSG, ℓx::F, ℓy::F; with_halos=false, indices=(Colon(), Colon())) =
getindex(_property(grid.φᶠᶠᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos), indices...)

@inline φnodes(grid::OSSG, ℓx::F, ℓy::C; with_halos=false) =
_property(grid.φᶠᶜᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos)
@inline φnodes(grid::OSSG, ℓx::F, ℓy::C; with_halos=false, indices=(Colon(), Colon())) =
getindex(_property(grid.φᶠᶜᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos), indices...)

@inline φnodes(grid::OSSG, ℓx::C, ℓy::F, ; with_halos=false) =
_property(grid.φᶠᶜᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos)
@inline φnodes(grid::OSSG, ℓx::C, ℓy::F; with_halos=false, indices=(Colon(), Colon())) =
getindex(_property(grid.φᶜᶠᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos), indices...)

@inline φnodes(grid::OSSG, ℓx::C, ℓy::C; with_halos=false) =
_property(grid.φᶜᶜᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos)
@inline φnodes(grid::OSSG, ℓx::C, ℓy::C; with_halos=false, indices=(Colon(), Colon())) =
getindex(_property(grid.φᶜᶜᵃ, ℓx, ℓy, topology(grid, 1), topology(grid, 2), grid.Nx, grid.Ny, grid.Hx, grid.Hy, with_halos), indices...)

@inline xnodes(grid::OSSG, ℓx, ℓy; with_halos=false) =
grid.radius * deg2rad.(λnodes(grid, ℓx, ℓy; with_halos)) .* hack_cosd.(φnodes(grid, ℓx, ℓy; with_halos))
@inline xnodes(grid::OSSG, ℓx, ℓy; with_halos=false, indices=(Colon(), Colon())) =
grid.radius * deg2rad.(λnodes(grid, ℓx, ℓy; with_halos, indices)) .* hack_cosd.(φnodes(grid, ℓx, ℓy; with_halos, indices))

@inline ynodes(grid::OSSG, ℓx, ℓy; with_halos=false) = grid.radius * deg2rad.(φnodes(grid, ℓx, ℓy; with_halos))
@inline ynodes(grid::OSSG, ℓx, ℓy; with_halos=false, indices=(Colon(), Colon())) = grid.radius * deg2rad.(φnodes(grid, ℓx, ℓy; with_halos, indices))

# convenience
@inline λnodes(grid::OSSG, ℓx, ℓy, ℓz; with_halos=false) = λnodes(grid, ℓx, ℓy; with_halos)
@inline φnodes(grid::OSSG, ℓx, ℓy, ℓz; with_halos=false) = φnodes(grid, ℓx, ℓy; with_halos)
@inline xnodes(grid::OSSG, ℓx, ℓy, ℓz; with_halos=false) = xnodes(grid, ℓx, ℓy; with_halos)
@inline ynodes(grid::OSSG, ℓx, ℓy, ℓz; with_halos=false) = ynodes(grid, ℓx, ℓy; with_halos)
@inline λnodes(grid::OSSG, ℓx, ℓy, ℓz; with_halos=false, indices=(Colon(), Colon())) = λnodes(grid, ℓx, ℓy; with_halos, indices)
@inline φnodes(grid::OSSG, ℓx, ℓy, ℓz; with_halos=false, indices=(Colon(), Colon())) = φnodes(grid, ℓx, ℓy; with_halos, indices)
@inline xnodes(grid::OSSG, ℓx, ℓy, ℓz; with_halos=false, indices=(Colon(), Colon())) = xnodes(grid, ℓx, ℓy; with_halos, indices)
@inline ynodes(grid::OSSG, ℓx, ℓy, ℓz; with_halos=false, indices=(Colon(), Colon())) = ynodes(grid, ℓx, ℓy; with_halos, indices)

@inline λnode(i, j, grid::OSSG, ::Center, ::Center) = @inbounds grid.λᶜᶜᵃ[i, j]
@inline λnode(i, j, grid::OSSG, ::Face , ::Center) = @inbounds grid.λᶠᶜᵃ[i, j]
Expand Down
29 changes: 19 additions & 10 deletions src/Grids/rectilinear_grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,10 @@ rname(::RG) = :z
@inline xnode(i, j, k, grid::RG, ℓx, ℓy, ℓz) = xnode(i, grid, ℓx)
@inline ynode(i, j, k, grid::RG, ℓx, ℓy, ℓz) = ynode(j, grid, ℓy)

function nodes(grid::RectilinearGrid, ℓx, ℓy, ℓz; reshape=false, with_halos=false)
x = xnodes(grid, ℓx, ℓy, ℓz; with_halos)
y = ynodes(grid, ℓx, ℓy, ℓz; with_halos)
z = znodes(grid, ℓx, ℓy, ℓz; with_halos)
function nodes(grid::RectilinearGrid, ℓx, ℓy, ℓz; reshape=false, with_halos=false, indices=(Colon(), Colon(), Colon()))
x = xnodes(grid, ℓx, ℓy, ℓz; with_halos, indices = indices[1])
y = ynodes(grid, ℓx, ℓy, ℓz; with_halos, indices = indices[2])
z = znodes(grid, ℓx, ℓy, ℓz; with_halos, indices = indices[3])

if reshape
# Here we have to deal with the fact that Flat directions may have
Expand All @@ -511,14 +511,23 @@ function nodes(grid::RectilinearGrid, ℓx, ℓy, ℓz; reshape=false, with_halo
return (x, y, z)
end

@inline xnodes(grid::RG, ℓx::F; with_halos=false) = _property(grid.xᶠᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos)
@inline xnodes(grid::RG, ℓx::C; with_halos=false) = _property(grid.xᶜᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos)
@inline ynodes(grid::RG, ℓy::F; with_halos=false) = _property(grid.yᵃᶠᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos)
@inline ynodes(grid::RG, ℓy::C; with_halos=false) = _property(grid.yᵃᶜᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos)
@inline xnodes(grid::RG, ℓx::F; with_halos=false, indices=Colon()) = getindex(_property(grid.xᶠᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos), indices)
@inline xnodes(grid::RG, ℓx::C; with_halos=false, indices=Colon()) = getindex(_property(grid.xᶜᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos), indices)
@inline ynodes(grid::RG, ℓy::F; with_halos=false, indices=Colon()) = getindex(_property(grid.yᵃᶠᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos), indices)
@inline ynodes(grid::RG, ℓy::C; with_halos=false, indices=Colon()) = getindex(_property(grid.yᵃᶜᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos), indices)

# convenience
@inline xnodes(grid::RG, ℓx, ℓy, ℓz; with_halos=false) = xnodes(grid, ℓx; with_halos)
@inline ynodes(grid::RG, ℓx, ℓy, ℓz; with_halos=false) = ynodes(grid, ℓy; with_halos)
@inline xnodes(grid::RG, ℓx, ℓy, ℓz; with_halos=false, indices=Colon()) = xnodes(grid, ℓx; with_halos, indices)
@inline ynodes(grid::RG, ℓx, ℓy, ℓz; with_halos=false, indices=Colon()) = ynodes(grid, ℓy; with_halos, indices)

# Flat topologies
XFlatRG = RectilinearGrid{<:Any, Flat}
YFlatRG = RectilinearGrid{<:Any, <:Any, Flat}
ZFlatRG = RectilinearGrid{<:Any, <:Any, <:Any, Flat}
@inline xnodes(grid::XFlatRG, ℓx::F; with_halos=false, indices=Colon()) = _property(grid.xᶠᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos)
@inline xnodes(grid::XFlatRG, ℓx::C; with_halos=false, indices=Colon()) = _property(grid.xᶜᵃᵃ, ℓx, topology(grid, 1), grid.Nx, grid.Hx, with_halos)
@inline ynodes(grid::YFlatRG, ℓy::F; with_halos=false, indices=Colon()) = _property(grid.yᵃᶠᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos)
@inline ynodes(grid::YFlatRG, ℓy::C; with_halos=false, indices=Colon()) = _property(grid.yᵃᶜᵃ, ℓy, topology(grid, 2), grid.Ny, grid.Hy, with_halos)

# Generalized coordinates
@inline ξnodes(grid::RG, ℓx; kwargs...) = xnodes(grid, ℓx; kwargs...)
Expand Down
14 changes: 9 additions & 5 deletions src/Grids/vertical_discretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,16 @@ end
@inline znode(k, grid, ℓz) = rnode(k, grid, ℓz)
@inline znode(i, j, k, grid, ℓx, ℓy, ℓz) = rnode(i, j, k, grid, ℓx, ℓy, ℓz)

@inline rnodes(grid::AUG, ℓz::F; with_halos=false) = _property(grid.z.cᵃᵃᶠ, ℓz, topology(grid, 3), grid.Nz, grid.Hz, with_halos)
@inline rnodes(grid::AUG, ℓz::C; with_halos=false) = _property(grid.z.cᵃᵃᶜ, ℓz, topology(grid, 3), grid.Nz, grid.Hz, with_halos)
@inline rnodes(grid::AUG, ℓx, ℓy, ℓz; with_halos=false) = rnodes(grid, ℓz; with_halos)
@inline rnodes(grid::AUG, ℓz::F; with_halos=false, indices=Colon()) = getindex(_property(grid.z.cᵃᵃᶠ, ℓz, topology(grid, 3), grid.Nz, grid.Hz, with_halos), indices)
@inline rnodes(grid::AUG, ℓz::C; with_halos=false, indices=Colon()) = getindex(_property(grid.z.cᵃᵃᶜ, ℓz, topology(grid, 3), grid.Nz, grid.Hz, with_halos), indices)
@inline rnodes(grid::AUG, ℓx, ℓy, ℓz; with_halos=false, indices=Colon()) = rnodes(grid, ℓz; with_halos, indices)

rnodes(grid::AUG, ::Nothing; kwargs...) = 1:1
znodes(grid::AUG, ::Nothing; kwargs...) = 1:1
@inline rnodes(grid::AUG, ::Nothing; kwargs...) = 1:1
@inline znodes(grid::AUG, ::Nothing; kwargs...) = 1:1

ZFlatAUG = AbstractUnderlyingGrid{<:Any, <:Any, <:Any, Flat}
@inline rnodes(grid::ZFlatAUG, ℓz::F; with_halos=false, indices=Colon()) = _property(grid.z.cᵃᵃᶠ, ℓz, topology(grid, 3), grid.Nz, grid.Hz, with_halos)
@inline rnodes(grid::ZFlatAUG, ℓz::C; with_halos=false, indices=Colon()) = _property(grid.z.cᵃᵃᶜ, ℓz, topology(grid, 3), grid.Nz, grid.Hz, with_halos)

# TODO: extend in the Operators module
@inline znodes(grid::AUG, ℓz; kwargs...) = rnodes(grid, ℓz; kwargs...)
Expand Down
76 changes: 76 additions & 0 deletions test/test_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,54 @@ function run_field_interpolation_tests(grid)
return nothing
end

function nodes_of_field_views_are_consistent(grid)
# Test with different field types
test_fields = [CenterField(grid), XFaceField(grid), YFaceField(grid), ZFaceField(grid)]

for field in test_fields
loc = instantiated_location(field)

# Test various view patterns
test_indices = [
(2:6, :, :), # x slice
(:, 2:4, :), # y slice
(:, :, 2:3), # z slice
(3:5, 2:4, :), # xy slice
(2:6, :, 2:3), # xz slice
(:, 2:4, 2:3), # yz slice
(3:5, 2:4, 2:3), # xyz slice
]

for test_idx in test_indices
# Create field view with these indices
field_view = view(field, test_idx...)

# Get nodes from the view
view_nodes = nodes(field_view)

# Get nodes from the original field with the same indices
# This is what should be equivalent to the view_nodes
full_nodes = nodes(field.grid, loc...; indices=test_idx)

# Test that they are equal
@test view_nodes == full_nodes

# Also test that the view's indices match what we expect
@test indices(field_view) == test_idx

# Test that view nodes have sizes consistent with the view indices
for (i, coord_nodes) in enumerate(view_nodes)
if coord_nodes !== nothing && full_nodes[i] !== nothing
@test coord_nodes == full_nodes[i]
end
end
end
end

return nothing
end


#####
#####
#####
Expand Down Expand Up @@ -655,4 +703,32 @@ end
@test_throws BoundsError cvvv[:, :, k_top-2:k_top]
end
end

@testset "Field nodes and view consistency" begin
@info " Testing that nodes() returns indices consistent with view()..."
for arch in archs, FT in float_types
# Test RectilinearGrid
rectilinear_grid = RectilinearGrid(arch, FT, size=(8, 6, 4), extent=(2, 3, 1))
nodes_of_field_views_are_consistent(rectilinear_grid)

# Test LatitudeLongitudeGrid
latlon_grid = LatitudeLongitudeGrid(arch, FT, size=(8, 6, 4), longitude = (-180, 180), latitude = (-85, 85), z = (-100, 0))
nodes_of_field_views_are_consistent(latlon_grid)

# Test OrthogonalSphericalShellGrid (TripolarGrid)
tripolar_grid = TripolarGrid(arch, FT, size=(8, 6, 4))
nodes_of_field_views_are_consistent(tripolar_grid)

# Test Flat topology behavior for RectilinearGrid
flat_rlgrid = RectilinearGrid(arch, FT, size=(), extent=(), topology=(Flat, Flat, Flat))
c_flat = CenterField(flat_rlgrid)
@test nodes(c_flat) == (nothing, nothing, nothing)

# Test Flat topology behavior for LatitudeLongitudeGrid
flat_llgrid = LatitudeLongitudeGrid(arch, FT, size=(), topology=(Flat, Flat, Flat))
c_flat = CenterField(flat_llgrid)
@test nodes(c_flat) == (nothing, nothing, nothing)
end
end
end