Skip to content

Commit fc6bbd7

Browse files
xkykaitomchorsimone-silvestri
authored
Symmetrizing Laplacian in order to use conjugate gradient in nonuniform grids (#4563)
Co-authored-by: Tomás Chor <[email protected]> Co-authored-by: Simone Silvestri <[email protected]>
1 parent ff6d022 commit fc6bbd7

File tree

5 files changed

+326
-43
lines changed

5 files changed

+326
-43
lines changed

src/Models/NonhydrostaticModels/solve_for_pressure.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ end
4141
@inbounds rhs[i, j, k] = active * Δzᶜᶜᶜ(i, j, k, grid) * δ
4242
end
4343

44+
@kernel function _cg_source_term!(rhs, grid, Ũ)
45+
i, j, k = @index(Global, NTuple)
46+
active = !inactive_cell(i, j, k, grid)
47+
δ = divᶜᶜᶜ(i, j, k, grid, Ũ.u, Ũ.v, Ũ.w)
48+
V = Vᶜᶜᶜ(i, j, k, grid)
49+
@inbounds rhs[i, j, k] = active * δ * V
50+
end
51+
4452
function compute_source_term!(solver::DistributedFFTBasedPoissonSolver, Ũ)
4553
rhs = solver.storage.zfield
4654
arch = architecture(solver)
@@ -95,7 +103,6 @@ function solve_for_pressure!(pressure, solver::ConjugateGradientPoissonSolver,
95103
rhs = solver.right_hand_side
96104
grid = solver.grid
97105
arch = architecture(grid)
98-
launch!(arch, grid, :xyz, _compute_source_term!, rhs, grid, args...)
106+
launch!(arch, grid, :xyz, _cg_source_term!, rhs, grid, args...)
99107
return solve!(pressure, solver.conjugate_gradient_solver, rhs)
100108
end
101-

src/Solvers/conjugate_gradient_poisson_solver.jl

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Oceananigans.Operators
2+
using Oceananigans.Operators: Ax_∂xᶠᵃᵃ, Ay_∂yᵃᶠᵃ, Az_∂zᵃᵃᶠ, ∇²ᶜᶜᶜ, Vᶜᶜᶜ
23
using Oceananigans.ImmersedBoundaries: ImmersedBoundaryGrid
34
using Statistics: mean
45

@@ -30,23 +31,29 @@ function Base.show(io::IO, ips::ConjugateGradientPoissonSolver)
3031
" └── iteration: ", prettysummary(ips.conjugate_gradient_solver.iteration))
3132
end
3233

33-
@kernel function laplacian!(∇²ϕ, grid, ϕ)
34+
@inline function V∇²ᶜᶜᶜ(i, j, k, grid, c)
35+
return δxᶜᶜᶜ(i, j, k, grid, Ax_∂xᶠᶜᶜ, c) +
36+
δyᶜᶜᶜ(i, j, k, grid, Ay_∂yᶜᶠᶜ, c) +
37+
δzᶜᶜᶜ(i, j, k, grid, Az_∂zᶜᶜᶠ, c)
38+
end
39+
40+
@kernel function _symmetric_laplacian_operator!(∇²ϕ, grid, ϕ)
3441
i, j, k = @index(Global, NTuple)
35-
@inbounds ∇²ϕ[i, j, k] = ∇²ᶜᶜᶜ(i, j, k, grid, ϕ)
42+
@inbounds ∇²ϕ[i, j, k] = V∇²ᶜᶜᶜ(i, j, k, grid, ϕ)
3643
end
3744

38-
function compute_laplacian!(∇²ϕ, ϕ)
45+
function compute_symmetric_laplacian!(∇²ϕ, ϕ)
3946
grid = ϕ.grid
4047
arch = architecture(grid)
4148
fill_halo_regions!(ϕ)
42-
launch!(arch, grid, :xyz, laplacian!, ∇²ϕ, grid, ϕ)
49+
launch!(arch, grid, :xyz, _symmetric_laplacian_operator!, ∇²ϕ, grid, ϕ)
4350
return nothing
4451
end
4552

4653
@kernel function subtract_and_mask!(a, grid, b)
4754
i, j, k = @index(Global, NTuple)
4855
active = !inactive_cell(i, j, k, grid)
49-
a[i, j, k] = (a[i, j, k] - b) * active
56+
@inbounds a[i, j, k] = (a[i, j, k] - b) * active
5057
end
5158

5259
function enforce_zero_mean_gauge!(x, r)
@@ -60,6 +67,18 @@ function enforce_zero_mean_gauge!(x, r)
6067
launch!(arch, grid, :xyz, subtract_and_mask!, r, grid, mean_r)
6168
end
6269

70+
@kernel function cell_volume!(V, grid)
71+
i, j, k = @index(Global, NTuple)
72+
@inbounds V[i, j, k] = Vᶜᶜᶜ(i, j, k, grid)
73+
end
74+
75+
function minimum_cell_volume(grid)
76+
V = CenterField(grid)
77+
arch = architecture(grid)
78+
launch!(arch, grid, :xyz, cell_volume!, V, grid)
79+
return minimum(V)
80+
end
81+
6382
struct DefaultPreconditioner end
6483

6584
"""
@@ -81,8 +100,8 @@ is a common choice to remove this degree of freedom.
81100
"""
82101
function ConjugateGradientPoissonSolver(grid;
83102
preconditioner = DefaultPreconditioner(),
84-
reltol = sqrt(eps(grid)),
85-
abstol = sqrt(eps(grid)),
103+
reltol = min(100 * eps(grid), 100 * eps(grid) * minimum_cell_volume(grid)^2, sqrt(eps(grid))),
104+
abstol = min(100 * eps(grid), sqrt(eps(grid))),
86105
enforce_gauge_condition! = enforce_zero_mean_gauge!,
87106
kw...)
88107

@@ -96,7 +115,7 @@ function ConjugateGradientPoissonSolver(grid;
96115

97116
rhs = CenterField(grid)
98117

99-
conjugate_gradient_solver = ConjugateGradientSolver(compute_laplacian!;
118+
conjugate_gradient_solver = ConjugateGradientSolver(compute_symmetric_laplacian!;
100119
reltol,
101120
abstol,
102121
preconditioner,
@@ -111,30 +130,30 @@ end
111130
##### A preconditioner based on the FFT solver
112131
#####
113132

114-
@kernel function fft_preconditioner_rhs!(preconditioner_rhs, rhs)
133+
@kernel function fft_preconditioner_rhs!(preconditioner_rhs, rhs, grid)
115134
i, j, k = @index(Global, NTuple)
116-
@inbounds preconditioner_rhs[i, j, k] = rhs[i, j, k]
135+
@inbounds preconditioner_rhs[i, j, k] = rhs[i, j, k] * V⁻¹ᶜᶜᶜ(i, j, k, grid)
117136
end
118137

119138
@kernel function fourier_tridiagonal_preconditioner_rhs!(preconditioner_rhs, ::XDirection, grid, rhs)
120139
i, j, k = @index(Global, NTuple)
121-
@inbounds preconditioner_rhs[i, j, k] = Δxᶜᶜᶜ(i, j, k, grid) * rhs[i, j, k]
140+
@inbounds preconditioner_rhs[i, j, k] = rhs[i, j, k] * V⁻¹ᶜᶜᶜ(i, j, k, grid)
122141
end
123142

124143
@kernel function fourier_tridiagonal_preconditioner_rhs!(preconditioner_rhs, ::YDirection, grid, rhs)
125144
i, j, k = @index(Global, NTuple)
126-
@inbounds preconditioner_rhs[i, j, k] = Δyᶜᶜᶜ(i, j, k, grid) * rhs[i, j, k]
145+
@inbounds preconditioner_rhs[i, j, k] = rhs[i, j, k] * V⁻¹ᶜᶜᶜ(i, j, k, grid)
127146
end
128147

129148
@kernel function fourier_tridiagonal_preconditioner_rhs!(preconditioner_rhs, ::ZDirection, grid, rhs)
130149
i, j, k = @index(Global, NTuple)
131-
@inbounds preconditioner_rhs[i, j, k] = Δzᶜᶜᶜ(i, j, k, grid) * rhs[i, j, k]
150+
@inbounds preconditioner_rhs[i, j, k] = rhs[i, j, k] * V⁻¹ᶜᶜᶜ(i, j, k, grid)
132151
end
133152

134153
function compute_preconditioner_rhs!(solver::FFTBasedPoissonSolver, rhs)
135154
grid = solver.grid
136155
arch = architecture(grid)
137-
launch!(arch, grid, :xyz, fft_preconditioner_rhs!, solver.storage, rhs)
156+
launch!(arch, grid, :xyz, fft_preconditioner_rhs!, solver.storage, rhs, grid)
138157
return nothing
139158
end
140159

@@ -149,7 +168,7 @@ end
149168

150169
const FFTBasedPreconditioner = Union{FFTBasedPoissonSolver, FourierTridiagonalPoissonSolver}
151170

152-
function precondition!(p, preconditioner::FFTBasedPreconditioner, r, args...)
171+
@inline function precondition!(p, preconditioner::FFTBasedPreconditioner, r, args...)
153172
compute_preconditioner_rhs!(preconditioner, r)
154173
solve!(p, preconditioner, preconditioner.storage)
155174
return p

src/Solvers/fourier_tridiagonal_poisson_solver.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,15 +255,15 @@ end
255255

256256
@kernel function multiply_by_spacing!(b, ::XDirection, grid)
257257
i, j, k = @index(Global, NTuple)
258-
@inbounds b[i, j, k] *= Δxᶜᵃᵃ(i, j, k, grid)
258+
@inbounds b[i, j, k] *= Δxᶜᶜᶜ(i, j, k, grid)
259259
end
260260

261261
@kernel function multiply_by_spacing!(b, ::YDirection, grid)
262262
i, j, k = @index(Global, NTuple)
263-
@inbounds b[i, j, k] *= Δyᵃᶜᵃ(i, j, k, grid)
263+
@inbounds b[i, j, k] *= Δyᶜᶜᶜ(i, j, k, grid)
264264
end
265265

266266
@kernel function multiply_by_spacing!(b, ::ZDirection, grid)
267267
i, j, k = @index(Global, NTuple)
268-
@inbounds b[i, j, k] *= Δzᵃᵃᶜ(i, j, k, grid)
268+
@inbounds b[i, j, k] *= Δzᶜᶜᶜ(i, j, k, grid)
269269
end
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
using Oceananigans
2+
using Printf
3+
using JLD2
4+
using Oceananigans.Models.NonhydrostaticModels: ConjugateGradientPoissonSolver, FFTBasedPoissonSolver, FourierTridiagonalPoissonSolver
5+
using Oceananigans.Models.NonhydrostaticModels: nonhydrostatic_pressure_solver
6+
using Oceananigans.Solvers: DiagonallyDominantPreconditioner, compute_laplacian!
7+
using Oceananigans.Grids: with_number_type, XYZRegularRG
8+
using Oceananigans.Utils: launch!
9+
using KernelAbstractions: @kernel, @index
10+
using Oceananigans.Architectures: architecture
11+
using Oceananigans.Operators
12+
using Statistics
13+
using CairoMakie
14+
using Random
15+
16+
rng = Xoshiro(123)
17+
18+
function initial_conditions!(model)
19+
h = 0.05
20+
x₀ = 0.5
21+
y₀ = 0.5
22+
z₀ = 0.55
23+
bᵢ(x, y, z) = - exp(-((x - x₀)^2 + (y - y₀)^2 + (z - z₀)^2) / 2h^2)
24+
25+
set!(model, b=bᵢ)
26+
end
27+
28+
function setup_grid(Nx, Ny, Nz, arch)
29+
zs = collect(range(0, 1, length=Nz+1))
30+
zs[2:end-1] .+= randn(length(zs[2:end-1])) * (1 / Nz) / 10
31+
32+
grid = RectilinearGrid(arch, Float64,
33+
size = (Nx, Ny, Nz),
34+
halo = (4, 4, 4),
35+
x = (0, 1),
36+
y = (0, 1),
37+
z = (0, 1),
38+
# z = zs,
39+
topology = (Bounded, Bounded, Bounded))
40+
41+
slope(x, y) = (5 + tanh(40*(x - 1/6)) + tanh(40*(x - 2/6)) + tanh(40*(x - 3/6)) + tanh(40*(x - 4/6)) + tanh(40*(x - 5/6))) / 20 +
42+
(5 + tanh(40*(y - 1/6)) + tanh(40*(y - 2/6)) + tanh(40*(y - 3/6)) + tanh(40*(y - 4/6)) + tanh(40*(y - 5/6))) / 20
43+
44+
# grid = ImmersedBoundaryGrid(grid, GridFittedBottom(slope))
45+
grid = ImmersedBoundaryGrid(grid, PartialCellBottom(slope))
46+
return grid
47+
end
48+
49+
function setup_model(grid, pressure_solver)
50+
model = NonhydrostaticModel(; grid, pressure_solver,
51+
advection = WENO(),
52+
coriolis = FPlane(f=0.1),
53+
tracers = :b,
54+
buoyancy = BuoyancyTracer())
55+
56+
initial_conditions!(model)
57+
return model
58+
end
59+
60+
@kernel function _divergence!(target_field, u, v, w, grid)
61+
i, j, k = @index(Global, NTuple)
62+
@inbounds target_field[i, j, k] = divᶜᶜᶜ(i, j, k, grid, u, v, w)
63+
end
64+
65+
function compute_flow_divergence!(target_field, model)
66+
grid = model.grid
67+
u, v, w = model.velocities
68+
arch = architecture(grid)
69+
launch!(arch, grid, :xyz, _divergence!, target_field, u, v, w, grid)
70+
return nothing
71+
end
72+
73+
74+
function setup_simulation(model)
75+
Δt = 1e-3
76+
simulation = Simulation(model; Δt = Δt, stop_time = 10, minimum_relative_step = 1e-10)
77+
conjure_time_step_wizard!(simulation, cfl=0.7, IterationInterval(1))
78+
79+
wall_time = Ref(time_ns())
80+
81+
d = Field{Center, Center, Center}(grid)
82+
83+
function progress(sim)
84+
pressure_solver = sim.model.pressure_solver
85+
86+
if pressure_solver isa ConjugateGradientPoissonSolver
87+
pressure_iters = iteration(pressure_solver)
88+
else
89+
pressure_iters = 0
90+
end
91+
92+
msg = @sprintf("iter: %d, time: %s, Δt: %.4f, Poisson iters: %d",
93+
iteration(sim), prettytime(time(sim)), sim.Δt, pressure_iters)
94+
95+
elapsed = 1e-9 * (time_ns() - wall_time[])
96+
97+
compute_flow_divergence!(d, sim.model)
98+
99+
msg *= @sprintf(", max u: %6.3e, max v: %6.3e, max w: %6.3e, max b: %6.3e, max d: %6.3e, max pressure: %6.3e, wall time: %s",
100+
maximum(sim.model.velocities.u),
101+
maximum(sim.model.velocities.v),
102+
maximum(sim.model.velocities.w),
103+
maximum(sim.model.tracers.b),
104+
maximum(d),
105+
maximum(sim.model.pressures.pNHS),
106+
prettytime(elapsed))
107+
108+
@info msg
109+
wall_time[] = time_ns()
110+
111+
return nothing
112+
end
113+
114+
simulation.callbacks[:progress] = Callback(progress, IterationInterval(1))
115+
116+
compute_flow_divergence!(d, model)
117+
118+
B = Field(Integral(model.tracers.b))
119+
120+
outputs = merge(model.velocities, model.tracers, (; p=model.pressures.pNHS, d, B))
121+
122+
if grid.underlying_grid isa XYZRegularRG
123+
file_prefix = "uniform_"
124+
else
125+
file_prefix = "nonuniform_"
126+
end
127+
128+
file_prefix *= "staircase_2D_convection"
129+
130+
if preconditioner isa FFTBasedPoissonSolver
131+
file_prefix *= "_cgfft"
132+
elseif preconditioner isa FourierTridiagonalPoissonSolver
133+
file_prefix *= "_cgftri"
134+
else
135+
file_prefix *= "_cgnoprec"
136+
end
137+
138+
if grid.immersed_boundary isa PartialCellBottom
139+
file_prefix *= "_partialcellbottom"
140+
else
141+
file_prefix *= "_gridfittedbottom"
142+
end
143+
144+
filename = "./$(file_prefix)"
145+
simulation.output_writers[:jld2] = JLD2Writer(model, outputs;
146+
filename = filename,
147+
schedule = TimeInterval(0.1),
148+
overwrite_existing = true)
149+
150+
return simulation, file_prefix
151+
end
152+
153+
arch = GPU()
154+
Nx = Ny = Nz = 32
155+
grid = setup_grid(Nx, Ny, Nz, arch)
156+
157+
@info "Create pressure solver"
158+
159+
# preconditioner = nonhydrostatic_pressure_solver(grid)
160+
preconditioner = nonhydrostatic_pressure_solver(with_number_type(Float32, grid.underlying_grid))
161+
# preconditioner = nothing
162+
163+
pressure_solver = ConjugateGradientPoissonSolver(grid, maxiter=10000; preconditioner)
164+
165+
model = setup_model(grid, pressure_solver)
166+
167+
simulation, filename = setup_simulation(model)
168+
169+
run!(simulation)
170+
171+
#%%
172+
bt = FieldTimeSeries(filename, "b")
173+
ut = FieldTimeSeries(filename, "u")
174+
wt = FieldTimeSeries(filename, "w")
175+
pt = FieldTimeSeries(filename, "p")
176+
δt = FieldTimeSeries(filename, "d")
177+
times = bt.times
178+
Nt = length(times)
179+
180+
Bt = FieldTimeSeries(filename, "B")
181+
#%%
182+
yloc = Nz ÷ 2
183+
184+
fig = Figure(size=(1200, 1200))
185+
186+
n = Observable(1)
187+
188+
B₀ = sum(interior(bt[1], :, 1, :)) / (Nx * Nz)
189+
btitlestr = @lift @sprintf("Buoyancy at t = %.2f", times[$n])
190+
utitlestr = @lift @sprintf("Horizontal velocity at t = %.2f", times[$n])
191+
wtitlestr = @lift @sprintf("Vertical velocity at t = %.2f", times[$n])
192+
193+
δlim = 1e-9
194+
195+
axb = Axis(fig[1, 1], title=btitlestr)
196+
axu = Axis(fig[1, 2], title=utitlestr)
197+
axw = Axis(fig[1, 3], title=wtitlestr)
198+
axp = Axis(fig[2, 1], title="Pressure")
199+
axd = Axis(fig[2, 2], title="Divergence, lim = $(δlim)")
200+
axt = Axis(fig[3, 1:3], xlabel="Time", ylabel="Fractional remaining tracer")
201+
202+
bn = @lift interior(bt[$n], :, yloc, :)
203+
un = @lift interior(ut[$n], :, yloc, :)
204+
wn = @lift interior(wt[$n], :, yloc, :)
205+
pn = @lift interior(pt[$n], :, yloc, :)
206+
δn = @lift interior(δt[$n], :, yloc, :)
207+
208+
ulim = maximum(abs, ut) / 2
209+
wlim = maximum(abs, wt) / 2
210+
plim = maximum(abs, pt) / 2
211+
212+
heatmap!(axb, bn, colormap=:balance, colorrange=(-0.5, 0.5))
213+
heatmap!(axu, un, colormap=:balance, colorrange=(-ulim, ulim))
214+
heatmap!(axw, wn, colormap=:balance, colorrange=(-wlim, wlim))
215+
heatmap!(axp, pn, colormap=:balance, colorrange=(-plim, plim))
216+
heatmap!(axd, δn, colormap=:balance, colorrange=(-δlim, δlim))
217+
218+
ΔB = Bt.data[1, 1, 1, :] .- Bt.data[1, 1, 1, 1]
219+
t = @lift times[$n]
220+
lines!(axt, times, ΔB)
221+
vlines!(axt, t, color=:black)
222+
# display(fig)
223+
224+
CairoMakie.record(fig, "./$(filename).mp4", 1:Nt, framerate=15) do nn
225+
n[] = nn
226+
end

0 commit comments

Comments
 (0)