Skip to content

Commit ed90f64

Browse files
authored
Generalize assumed_field_location, plus allow FourierTridiagonalPoissonSolver on regular grid (#4535)
* Generalize assumed_field_location * updates to support anelastic models * generalize constructor to accommodate different boundary conditions / formulations * update * fix for stretched cases * update interface to pressure corrections * fixes * try to fix distributed solver * fix * fixes * clean up
1 parent 267cb45 commit ed90f64

15 files changed

+236
-168
lines changed

ext/OceananigansReactantExt/TimeSteppers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using Oceananigans.Utils: @apply_regionally, apply_regionally!
1212
using Oceananigans.TimeSteppers:
1313
update_state!,
1414
tick!,
15-
calculate_pressure_correction!,
15+
compute_pressure_correction!,
1616
correct_velocities_and_cache_previous_tendencies!,
1717
step_lagrangian_particles!,
1818
QuasiAdamsBashforth2TimeStepper
@@ -103,7 +103,7 @@ function time_step!(model::ReactantModel{<:QuasiAdamsBashforth2TimeStepper{FT}},
103103
model.clock.last_stage_Δt = Δt
104104
end
105105

106-
calculate_pressure_correction!(model, Δt)
106+
compute_pressure_correction!(model, Δt)
107107
correct_velocities_and_cache_previous_tendencies!(model, Δt)
108108

109109
update_state!(model, callbacks; compute_tendencies=true)

src/DistributedComputations/distributed_fft_tridiagonal_solver.jl

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@ using Oceananigans.Operators: Δxᶠᵃᵃ, Δyᵃᶠᵃ, Δzᵃᵃᶠ
55

66
using Oceananigans.Solvers: BatchedTridiagonalSolver,
77
stretched_direction,
8+
tridiagonal_direction,
9+
dimension,
10+
HomogeneousNeumannFormulation,
811
ZTridiagonalSolver,
912
YTridiagonalSolver,
1013
XTridiagonalSolver,
11-
compute_main_diagonal!
14+
compute_main_diagonal!,
15+
compute_lower_diagonal!
1216

1317
struct DistributedFourierTridiagonalPoissonSolver{G, L, B, P, R, S, β}
1418
plan :: P
@@ -146,30 +150,34 @@ Restrictions
146150
- Same as for two-dimensional decompositions with `Rx` (or `Ry`) equal to one
147151
148152
"""
149-
function DistributedFourierTridiagonalPoissonSolver(global_grid, local_grid, planner_flag=FFTW.PATIENT; tridiagonal_direction = nothing)
153+
function DistributedFourierTridiagonalPoissonSolver(global_grid, local_grid, planner_flag=FFTW.PATIENT; tridiagonal_formulation=nothing)
150154

151155
validate_poisson_solver_distributed_grid(global_grid)
152156
validate_poisson_solver_configuration(global_grid, local_grid)
153157

154-
if isnothing(tridiagonal_direction)
155-
tridiagonal_dim = stretched_dimensions(local_grid)[1]
156-
tridiagonal_direction = stretched_direction(local_grid)
158+
# Try to guess what direction should be tridiagonal
159+
if isnothing(tridiagonal_formulation)
160+
tridiagonal_dir = global_grid isa XYZRegularRG ? ZDirection() : stretched_direction(global_grid)
161+
tridiagonal_formulation = HomogeneousNeumannFormulation(tridiagonal_dir)
157162
else
158-
tridiagonal_dim = tridiagonal_direction == XDirection() ? 1 :
159-
tridiagonal_direction == YDirection() ? 2 : 3
163+
tridiagonal_dir = tridiagonal_direction(tridiagonal_formulation)
160164
end
161165

166+
tridiagonal_dim = dimension(tridiagonal_dir)
167+
162168
topology(global_grid, tridiagonal_dim) != Bounded &&
163169
error("`DistributedFourierTridiagonalPoissonSolver` requires that the stretched direction (dimension $tridiagonal_dim) is `Bounded`.")
164170

165-
FT = Complex{eltype(local_grid)}
171+
T = Complex{eltype(local_grid)}
166172
child_arch = child_architecture(local_grid)
167-
storage = TransposableField(CenterField(local_grid), FT)
173+
storage_field = CenterField(local_grid)
174+
storage = TransposableField(storage_field, T)
168175

169-
topo = (TX, TY, TZ) = topology(global_grid)
170-
λx = dropdims(poisson_eigenvalues(global_grid.Nx, global_grid.Lx, 1, TX()), dims=(2, 3))
171-
λy = dropdims(poisson_eigenvalues(global_grid.Ny, global_grid.Ly, 2, TY()), dims=(1, 3))
172-
λz = dropdims(poisson_eigenvalues(global_grid.Nz, global_grid.Lz, 3, TZ()), dims=(1, 2))
176+
TX, TY, TZ = topology(global_grid)
177+
tx, ty, tz = TX(), TY(), TZ()
178+
λx = dropdims(poisson_eigenvalues(global_grid.Nx, global_grid.Lx, 1, tx), dims=(2, 3))
179+
λy = dropdims(poisson_eigenvalues(global_grid.Ny, global_grid.Ly, 2, ty), dims=(1, 3))
180+
λz = dropdims(poisson_eigenvalues(global_grid.Nz, global_grid.Lz, 3, tz), dims=(1, 2))
173181

174182
if tridiagonal_dim == 1
175183
arch = architecture(storage.xfield.grid)
@@ -191,39 +199,31 @@ function DistributedFourierTridiagonalPoissonSolver(global_grid, local_grid, pla
191199
λ1 = on_architecture(child_arch, λ1)
192200
λ2 = on_architecture(child_arch, λ2)
193201

194-
plan = plan_distributed_transforms(global_grid, storage, planner_flag)
195-
196202
# Lower and upper diagonals are the same
197-
lower_diagonal = @allowscalar [ 1 / Δξᶠ(q, grid, Val(tridiagonal_dim)) for q in 2:size(grid, tridiagonal_dim) ]
198-
lower_diagonal = on_architecture(child_arch, lower_diagonal)
199-
upper_diagonal = lower_diagonal
203+
main_diagonal = zeros(grid, size(grid)...)
200204

201-
# Compute diagonal coefficients for each grid point
202-
diagonal = zeros(eltype(grid), size(grid)...)
203-
diagonal = on_architecture(arch, diagonal)
204-
launch_config = if tridiagonal_dim == 1
205-
:yz
206-
elseif tridiagonal_dim == 2
207-
:xz
208-
elseif tridiagonal_dim == 3
209-
:xy
210-
end
205+
Nd = size(grid, tridiagonal_dim) - 1
206+
lower_diagonal = zeros(grid, Nd)
207+
upper_diagonal = lower_diagonal
211208

212-
launch!(arch, grid, launch_config, compute_main_diagonal!, diagonal, grid, λ1, λ2, tridiagonal_direction)
209+
compute_main_diagonal!(main_diagonal, tridiagonal_formulation, grid, λ1, λ2)
210+
Nd > 0 && compute_lower_diagonal!(lower_diagonal, tridiagonal_formulation, grid)
213211

214212
# Set up batched tridiagonal solver
215-
btsolver = BatchedTridiagonalSolver(grid; lower_diagonal, diagonal, upper_diagonal, tridiagonal_direction)
213+
btsolver = BatchedTridiagonalSolver(grid; lower_diagonal, upper_diagonal,
214+
diagonal = main_diagonal,
215+
tridiagonal_direction = tridiagonal_dir)
216216

217-
# We need to permute indices to apply bounded transforms on the GPU (r2r of r2c with twiddling)
217+
# We need to permute indices to apply bounded transforms on the GPU (r2r or r2c with twiddling)
218218
x_buffer_needed = child_arch isa GPU && TX == Bounded
219219
z_buffer_needed = child_arch isa GPU && TZ == Bounded
220220

221221
# We cannot really batch anything, so on GPUs we always have to permute indices in the y direction
222222
y_buffer_needed = child_arch isa GPU
223223

224-
buffer_x = x_buffer_needed ? on_architecture(child_arch, zeros(FT, size(storage.xfield)...)) : nothing
225-
buffer_y = y_buffer_needed ? on_architecture(child_arch, zeros(FT, size(storage.yfield)...)) : nothing
226-
buffer_z = z_buffer_needed ? on_architecture(child_arch, zeros(FT, size(storage.zfield)...)) : nothing
224+
buffer_x = x_buffer_needed ? on_architecture(child_arch, zeros(T, size(storage.xfield)...)) : nothing
225+
buffer_y = y_buffer_needed ? on_architecture(child_arch, zeros(T, size(storage.yfield)...)) : nothing
226+
buffer_z = z_buffer_needed ? on_architecture(child_arch, zeros(T, size(storage.zfield)...)) : nothing
227227

228228
buffer = if tridiagonal_dim == 1
229229
(; y = buffer_y, z = buffer_z)
@@ -233,6 +233,8 @@ function DistributedFourierTridiagonalPoissonSolver(global_grid, local_grid, pla
233233
(; x = buffer_x, y = buffer_y)
234234
end
235235

236+
plan = plan_distributed_transforms(global_grid, storage, planner_flag)
237+
236238
if tridiagonal_dim == 1
237239
forward = (y! = plan.forward.y!, z! = plan.forward.z!)
238240
backward = (y! = plan.backward.y!, z! = plan.backward.z!)

src/Models/HydrostaticFreeSurfaceModels/barotropic_pressure_correction.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
using .SplitExplicitFreeSurfaces: barotropic_split_explicit_corrector!
2-
import Oceananigans.TimeSteppers: calculate_pressure_correction!, pressure_correct_velocities!
2+
import Oceananigans.TimeSteppers: compute_pressure_correction!, make_pressure_correction!
33

4-
calculate_pressure_correction!(::HydrostaticFreeSurfaceModel, Δt) = nothing
4+
compute_pressure_correction!(::HydrostaticFreeSurfaceModel, Δt) = nothing
55

66
#####
77
##### Barotropic pressure correction for models with a free surface
88
#####
99

10-
pressure_correct_velocities!(model::HydrostaticFreeSurfaceModel, Δt; kwargs...) =
11-
pressure_correct_velocities!(model, model.free_surface, Δt; kwargs...)
10+
make_pressure_correction!(model::HydrostaticFreeSurfaceModel, Δt; kwargs...) =
11+
make_pressure_correction!(model, model.free_surface, Δt; kwargs...)
1212

1313
# Fallback
14-
pressure_correct_velocities!(model, free_surface, Δt; kwargs...) = nothing
14+
make_pressure_correction!(model, free_surface, Δt; kwargs...) = nothing
1515

1616
#####
1717
##### Barotropic pressure correction for models with an Implicit free surface
1818
#####
1919

20-
function pressure_correct_velocities!(model, ::ImplicitFreeSurface, Δt)
20+
function make_pressure_correction!(model, ::ImplicitFreeSurface, Δt)
2121

2222
launch!(model.architecture, model.grid, :xyz,
2323
_barotropic_pressure_correction!,
@@ -30,7 +30,7 @@ function pressure_correct_velocities!(model, ::ImplicitFreeSurface, Δt)
3030
return nothing
3131
end
3232

33-
function pressure_correct_velocities!(model, ::SplitExplicitFreeSurface, Δt)
33+
function make_pressure_correction!(model, ::SplitExplicitFreeSurface, Δt)
3434
u, v, _ = model.velocities
3535
grid = model.grid
3636
barotropic_split_explicit_corrector!(u, v, model.free_surface, grid)

src/Models/NonhydrostaticModels/compute_nonhydrostatic_tendencies.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ end
160160
@inbounds Gw[i, j, k] = w_velocity_tendency(i, j, k, grid, args...)
161161
end
162162

163-
164163
#####
165164
##### Tracer(s)
166165
#####

src/Models/NonhydrostaticModels/nonhydrostatic_model.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,6 @@ function NonhydrostaticModel(; grid,
222222
model_fields = merge(velocities, tracers, auxiliary_fields)
223223
prognostic_fields = merge(velocities, tracers)
224224

225-
226225
# Instantiate timestepper if not already instantiated
227226
implicit_solver = implicit_diffusion_solver(time_discretization(closure), grid)
228227
timestepper = TimeStepper(timestepper, grid, prognostic_fields; implicit_solver=implicit_solver)

src/Models/NonhydrostaticModels/pressure_correction.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1-
import Oceananigans.TimeSteppers: calculate_pressure_correction!, pressure_correct_velocities!
1+
import Oceananigans.TimeSteppers: compute_pressure_correction!, make_pressure_correction!
22

33
"""
4-
calculate_pressure_correction!(model::NonhydrostaticModel, Δt)
4+
compute_pressure_correction!(model::NonhydrostaticModel, Δt)
55
66
Calculate the (nonhydrostatic) pressure correction associated with `tendencies`, `velocities`, and step size `Δt`.
77
"""
8-
function calculate_pressure_correction!(model::NonhydrostaticModel, Δt)
8+
function compute_pressure_correction!(model::NonhydrostaticModel, Δt)
99

1010
# Mask immersed velocities
1111
foreach(mask_immersed_field!, model.velocities)
12-
1312
fill_halo_regions!(model.velocities, model.clock, fields(model))
14-
1513
solve_for_pressure!(model.pressures.pNHS, model.pressure_solver, Δt, model.velocities)
16-
1714
fill_halo_regions!(model.pressures.pNHS)
1815

1916
return nothing
@@ -28,7 +25,7 @@ Update the predictor velocities u, v, and w with the non-hydrostatic pressure mu
2825
2926
`u^{n+1} = u^n - δₓp_{NH} * Δt / Δx`
3027
"""
31-
@kernel function _pressure_correct_velocities!(U, grid, pNHSΔt)
28+
@kernel function _make_pressure_correction!(U, grid, pNHSΔt)
3229
i, j, k = @index(Global, NTuple)
3330

3431
@inbounds U.u[i, j, k] -= ∂xᶠᶜᶜ(i, j, k, grid, pNHSΔt)
@@ -37,10 +34,10 @@ Update the predictor velocities u, v, and w with the non-hydrostatic pressure mu
3734
end
3835

3936
"Update the solution variables (velocities and tracers)."
40-
function pressure_correct_velocities!(model::NonhydrostaticModel, Δt)
37+
function make_pressure_correction!(model::NonhydrostaticModel, Δt)
4138

4239
launch!(model.architecture, model.grid, :xyz,
43-
_pressure_correct_velocities!,
40+
_make_pressure_correction!,
4441
model.velocities,
4542
model.grid,
4643
model.pressures.pNHS)

src/Models/NonhydrostaticModels/set_nonhydrostatic_model.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Oceananigans.BoundaryConditions: fill_halo_regions!
2-
using Oceananigans.TimeSteppers: update_state!, calculate_pressure_correction!, pressure_correct_velocities!
2+
using Oceananigans.TimeSteppers: update_state!, compute_pressure_correction!, make_pressure_correction!
33

44
import Oceananigans.Fields: set!
55

@@ -51,8 +51,8 @@ function set!(model::NonhydrostaticModel; enforce_incompressibility=true, kwargs
5151

5252
if enforce_incompressibility
5353
FT = eltype(model.grid)
54-
calculate_pressure_correction!(model, one(FT))
55-
pressure_correct_velocities!(model, one(FT))
54+
compute_pressure_correction!(model, one(FT))
55+
make_pressure_correction!(model, one(FT))
5656
update_state!(model; compute_tendencies = false)
5757
end
5858

src/Models/NonhydrostaticModels/solve_for_pressure.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,84 +9,84 @@ using Oceananigans.Solvers: solve!
99
##### Calculate the right-hand-side of the non-hydrostatic pressure Poisson equation.
1010
#####
1111

12-
@kernel function _compute_source_term!(rhs, grid, Δt, Ũ)
12+
@kernel function _compute_source_term!(rhs, grid, Ũ)
1313
i, j, k = @index(Global, NTuple)
1414
active = !inactive_cell(i, j, k, grid)
15-
δ = divᶜᶜᶜ(i, j, k, grid, Ũ.u, Ũ.v, Ũ.w)
15+
u, v, w = Ũ
16+
δ = divᶜᶜᶜ(i, j, k, grid, u, v, w)
1617
@inbounds rhs[i, j, k] = active * δ
1718
end
1819

19-
@kernel function _fourier_tridiagonal_source_term!(rhs, ::XDirection, grid, Δt, Ũ)
20+
@kernel function _fourier_tridiagonal_source_term!(rhs, ::XDirection, grid, Ũ)
2021
i, j, k = @index(Global, NTuple)
2122
active = !inactive_cell(i, j, k, grid)
22-
δ = divᶜᶜᶜ(i, j, k, grid, Ũ.u, Ũ.v, Ũ.w)
23+
u, v, w = Ũ
24+
δ = divᶜᶜᶜ(i, j, k, grid, u, v, w)
2325
@inbounds rhs[i, j, k] = active * Δxᶜᶜᶜ(i, j, k, grid) * δ
2426
end
2527

26-
@kernel function _fourier_tridiagonal_source_term!(rhs, ::YDirection, grid, Δt, Ũ)
28+
@kernel function _fourier_tridiagonal_source_term!(rhs, ::YDirection, grid, Ũ)
2729
i, j, k = @index(Global, NTuple)
2830
active = !inactive_cell(i, j, k, grid)
29-
δ = divᶜᶜᶜ(i, j, k, grid, Ũ.u, Ũ.v, Ũ.w)
31+
u, v, w = Ũ
32+
δ = divᶜᶜᶜ(i, j, k, grid, u, v, w)
3033
@inbounds rhs[i, j, k] = active * Δyᶜᶜᶜ(i, j, k, grid) * δ
3134
end
3235

33-
@kernel function _fourier_tridiagonal_source_term!(rhs, ::ZDirection, grid, Δt, Ũ)
36+
@kernel function _fourier_tridiagonal_source_term!(rhs, ::ZDirection, grid, Ũ)
3437
i, j, k = @index(Global, NTuple)
3538
active = !inactive_cell(i, j, k, grid)
36-
δ = divᶜᶜᶜ(i, j, k, grid, Ũ.u, Ũ.v, Ũ.w)
39+
u, v, w = Ũ
40+
δ = divᶜᶜᶜ(i, j, k, grid, u, v, w)
3741
@inbounds rhs[i, j, k] = active * Δzᶜᶜᶜ(i, j, k, grid) * δ
3842
end
3943

40-
function compute_source_term!(pressure, solver::DistributedFFTBasedPoissonSolver, Δt, Ũ)
44+
function compute_source_term!(solver::DistributedFFTBasedPoissonSolver, Ũ)
4145
rhs = solver.storage.zfield
4246
arch = architecture(solver)
4347
grid = solver.local_grid
44-
launch!(arch, grid, :xyz, _compute_source_term!, rhs, grid, Δt, Ũ)
48+
launch!(arch, grid, :xyz, _compute_source_term!, rhs, grid, Ũ)
4549
return nothing
4650
end
4751

48-
function compute_source_term!(pressure, solver::DistributedFourierTridiagonalPoissonSolver, Δt, Ũ)
52+
function compute_source_term!(solver::DistributedFourierTridiagonalPoissonSolver, Ũ)
4953
rhs = solver.storage.zfield
5054
arch = architecture(solver)
5155
grid = solver.local_grid
5256
tdir = solver.batched_tridiagonal_solver.tridiagonal_direction
53-
launch!(arch, grid, :xyz, _fourier_tridiagonal_source_term!, rhs, tdir, grid, Δt, Ũ)
57+
launch!(arch, grid, :xyz, _fourier_tridiagonal_source_term!, rhs, tdir, grid, Ũ)
5458
return nothing
5559
end
5660

57-
function compute_source_term!(pressure, solver::FourierTridiagonalPoissonSolver, Δt, Ũ)
61+
function compute_source_term!(solver::FourierTridiagonalPoissonSolver, Ũ)
5862
rhs = solver.source_term
5963
arch = architecture(solver)
6064
grid = solver.grid
6165
tdir = solver.batched_tridiagonal_solver.tridiagonal_direction
62-
launch!(arch, grid, :xyz, _fourier_tridiagonal_source_term!, rhs, tdir, grid, Δt, Ũ)
66+
launch!(arch, grid, :xyz, _fourier_tridiagonal_source_term!, rhs, tdir, grid, Ũ)
6367
return nothing
6468
end
6569

66-
function compute_source_term!(pressure, solver::FFTBasedPoissonSolver, Δt, Ũ)
70+
function compute_source_term!(solver::FFTBasedPoissonSolver, Ũ)
6771
rhs = solver.storage
6872
arch = architecture(solver)
6973
grid = solver.grid
70-
launch!(arch, grid, :xyz, _compute_source_term!, rhs, grid, Δt, Ũ)
74+
launch!(arch, grid, :xyz, _compute_source_term!, rhs, grid, Ũ)
7175
return nothing
7276
end
7377

7478
#####
7579
##### Solve for pressure
7680
#####
7781

78-
function solve_for_pressure!(pressure, solver, Δt, Ũ)
79-
ϵ = eps(eltype(pressure))
80-
Δt⁺ = max(ϵ, Δt)
81-
Δt★ = Δt⁺ * isfinite(Δt)
82-
pressure .*= Δt★
83-
84-
compute_source_term!(pressure, solver, Δt, Ũ)
82+
# Note that Δt is unused here.
83+
function solve_for_pressure!(pressure, solver, Δt, args...)
84+
compute_source_term!(solver, args...)
8585
solve!(pressure, solver)
8686
return pressure
8787
end
8888

89-
function solve_for_pressure!(pressure, solver::ConjugateGradientPoissonSolver, Δt, Ũ)
89+
function solve_for_pressure!(pressure, solver::ConjugateGradientPoissonSolver, Δt, args...)
9090
ϵ = eps(eltype(pressure))
9191
Δt⁺ = max(ϵ, Δt)
9292
Δt★ = Δt⁺ * isfinite(Δt)
@@ -95,7 +95,7 @@ function solve_for_pressure!(pressure, solver::ConjugateGradientPoissonSolver,
9595
rhs = solver.right_hand_side
9696
grid = solver.grid
9797
arch = architecture(grid)
98-
launch!(arch, grid, :xyz, _compute_source_term!, rhs, grid, Δt, Ũ)
98+
launch!(arch, grid, :xyz, _compute_source_term!, rhs, grid, args...)
9999
return solve!(pressure, solver.conjugate_gradient_solver, rhs)
100100
end
101101

0 commit comments

Comments
 (0)