Skip to content

Commit 6e88c52

Browse files
authored
Removing the overwrite of Base.convert, creating a new function instead (#179)
* Removing the overwrite of Base.convert, creating a new function instead * Removing the function mixprec_convert, using .= * putting two equals into one line in time integration code * making more single line equals * removing the double .=
1 parent cb1aadf commit 6e88c52

File tree

4 files changed

+64
-51
lines changed

4 files changed

+64
-51
lines changed

src/ghost_points.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,34 @@ function add_halo( u::Array{T,2},
6868
return u,v,η,sst
6969
end
7070

71+
function add_halo(u::Array{T,2},
72+
v::Array{T,2},
73+
η::Array{T,2},
74+
S::ModelSetup) where {T<:AbstractFloat}
75+
76+
@unpack nx,ny,nux,nuy,nvx,nvy = S.grid
77+
@unpack halo,haloη,halosstx,halossty = S.grid
78+
79+
# Add zeros to satisfy kinematic boundary conditions
80+
u = cat(zeros(T,halo,nuy),u,zeros(T,halo,nuy),dims=1)
81+
u = cat(zeros(T,nux+2*halo,halo),u,zeros(T,nux+2*halo,halo),dims=2)
82+
83+
v = cat(zeros(T,halo,nvy),v,zeros(T,halo,nvy),dims=1)
84+
v = cat(zeros(T,nvx+2*halo,halo),v,zeros(T,nvx+2*halo,halo),dims=2)
85+
86+
η = cat(zeros(T,haloη,ny),η,zeros(T,haloη,ny),dims=1)
87+
η = cat(zeros(T,nx+2*haloη,haloη),η,zeros(T,nx+2*haloη,haloη),dims=2)
88+
89+
# SCALING
90+
@unpack scale,scale_sst = S.constants
91+
u *= scale
92+
v *= scale
93+
94+
ghost_points!(u,v,η,S)
95+
96+
return u,v,η
97+
end
98+
7199
"""Cut off the halo from the prognostic variables."""
72100
function remove_halo( u::Array{T,2},
73101
v::Array{T,2},
@@ -132,7 +160,7 @@ end
132160
"""Decide on boundary condition P.bc which ghost point function to execute."""
133161
function ghost_points_uv!( u::AbstractMatrix,
134162
v::AbstractMatrix,
135-
P::Parameter,
163+
P::Parameter,
136164
C::Constants)
137165

138166
@unpack bc,Tcomm = P

src/model_setup.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,5 @@ mutable struct ModelSetup{T<:AbstractFloat,Tprog<:AbstractFloat}
2626
forcing::Forcing{T}
2727
Prog::PrognosticVars{Tprog}
2828
Diag::DiagnosticVars{T, Tprog}
29-
t::Int # SW: I believe this has something to do with Checkpointing, need to verify
29+
t::Int
3030
end

src/time_integration.jl

Lines changed: 33 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ function time_integration(S::ModelSetup{T,Tprog}) where {T<:AbstractFloat,Tprog<
2828
Ixy!(Diag.Vorticity.h_q,Diag.VolumeFluxes.h)
2929

3030
# calculate PV terms for initial conditions
31-
urhs = convert(Diag.PrognosticVarsRHS.u,u)
32-
vrhs = convert(Diag.PrognosticVarsRHS.v,v)
33-
ηrhs = convert(Diag.PrognosticVarsRHS.η,η)
31+
urhs = Diag.PrognosticVarsRHS.u .= u
32+
vrhs = Diag.PrognosticVarsRHS.v .= v
33+
ηrhs = Diag.PrognosticVarsRHS.η .= η
34+
3435
advection_coriolis!(urhs,vrhs,ηrhs,Diag,S)
3536
PVadvection!(Diag,S)
3637

@@ -71,9 +72,9 @@ function time_integration(S::ModelSetup{T,Tprog}) where {T<:AbstractFloat,Tprog<
7172
end
7273

7374
# type conversion for mixed precision
74-
u1rhs = convert(Diag.PrognosticVarsRHS.u,u1)
75-
v1rhs = convert(Diag.PrognosticVarsRHS.v,v1)
76-
η1rhs = convert(Diag.PrognosticVarsRHS.η,η1)
75+
u1rhs = Diag.PrognosticVarsRHS.u .= u1
76+
v1rhs = Diag.PrognosticVarsRHS.v .= v1
77+
η1rhs = Diag.PrognosticVarsRHS.η .= η1
7778

7879
rhs!(u1rhs,v1rhs,η1rhs,Diag,S,t) # momentum only
7980
continuity!(u1rhs,v1rhs,η1rhs,Diag,S,t) # continuity equation
@@ -118,9 +119,9 @@ function time_integration(S::ModelSetup{T,Tprog}) where {T<:AbstractFloat,Tprog<
118119
end
119120

120121
# type conversion for mixed precision
121-
u1rhs = convert(Diag.PrognosticVarsRHS.u,u1)
122-
v1rhs = convert(Diag.PrognosticVarsRHS.v,v1)
123-
η1rhs = convert(Diag.PrognosticVarsRHS.η,η1)
122+
u1rhs = Diag.PrognosticVarsRHS.u .= u1
123+
v1rhs = Diag.PrognosticVarsRHS.v .= v1
124+
η1rhs = Diag.PrognosticVarsRHS.η .= η1
124125

125126
rhs!(u1rhs,v1rhs,η1rhs,Diag,S,t) # momentum only
126127

@@ -130,8 +131,9 @@ function time_integration(S::ModelSetup{T,Tprog}) where {T<:AbstractFloat,Tprog<
130131

131132
# semi-implicit for continuity equation, use new u1,v1 to calcualte dη
132133
ghost_points_uv!(u1,v1,S)
133-
u1rhs = convert(Diag.PrognosticVarsRHS.u,u1)
134-
v1rhs = convert(Diag.PrognosticVarsRHS.v,v1)
134+
u1rhs = Diag.PrognosticVarsRHS.u .= u1
135+
v1rhs = Diag.PrognosticVarsRHS.v .= v1
136+
135137
continuity!(u1rhs,v1rhs,η1rhs,Diag,S,t)
136138
axb!(η1,Δt_Δs,dη) # η1 = η1 + Δt/(s-1)*RHS(u1)
137139
end
@@ -158,9 +160,9 @@ function time_integration(S::ModelSetup{T,Tprog}) where {T<:AbstractFloat,Tprog<
158160
end
159161

160162
# type conversion for mixed precision
161-
u1rhs = convert(Diag.PrognosticVarsRHS.u,u1)
162-
v1rhs = convert(Diag.PrognosticVarsRHS.v,v1)
163-
η1rhs = convert(Diag.PrognosticVarsRHS.η,η1)
163+
u1rhs = Diag.PrognosticVarsRHS.u .= u1
164+
v1rhs = Diag.PrognosticVarsRHS.v .= v1
165+
η1rhs = Diag.PrognosticVarsRHS.η .= η1
164166

165167
rhs!(u1rhs,v1rhs,η1rhs,Diag,S,t)
166168

@@ -179,8 +181,9 @@ function time_integration(S::ModelSetup{T,Tprog}) where {T<:AbstractFloat,Tprog<
179181

180182
# semi-implicit for continuity equation, use new u1,v1 to calcualte dη
181183
ghost_points_uv!(u1,v1,S)
182-
u1rhs = convert(Diag.PrognosticVarsRHS.u,u1)
183-
v1rhs = convert(Diag.PrognosticVarsRHS.v,v1)
184+
u1rhs = Diag.PrognosticVarsRHS.u .= u1
185+
v1rhs = Diag.PrognosticVarsRHS.v .= v1
186+
184187
continuity!(u1rhs,v1rhs,η1rhs,Diag,S,t)
185188

186189
if rki == kn
@@ -210,9 +213,9 @@ function time_integration(S::ModelSetup{T,Tprog}) where {T<:AbstractFloat,Tprog<
210213
end
211214

212215
# type conversion for mixed precision
213-
u1rhs = convert(Diag.PrognosticVarsRHS.u,u1)
214-
v1rhs = convert(Diag.PrognosticVarsRHS.v,v1)
215-
η1rhs = convert(Diag.PrognosticVarsRHS.η,η1)
216+
u1rhs = Diag.PrognosticVarsRHS.u .= u1
217+
v1rhs = Diag.PrognosticVarsRHS.v .= v1
218+
η1rhs = Diag.PrognosticVarsRHS.η .= η1
216219

217220
rhs!(u1rhs,v1rhs,η1rhs,Diag,S,t)
218221

@@ -223,8 +226,10 @@ function time_integration(S::ModelSetup{T,Tprog}) where {T<:AbstractFloat,Tprog<
223226

224227
# semi-implicit for continuity equation, use u1,v1 to calcualte dη
225228
ghost_points_uv!(u1,v1,S)
226-
u1rhs = convert(Diag.PrognosticVarsRHS.u,u1)
227-
v1rhs = convert(Diag.PrognosticVarsRHS.v,v1)
229+
230+
u1rhs = Diag.PrognosticVarsRHS.u .= u1
231+
v1rhs = Diag.PrognosticVarsRHS.v .= v1
232+
228233
continuity!(u1rhs,v1rhs,η1rhs,Diag,S,t)
229234

230235
caxb!(η0,η1,Δt_Δ,dη) # store Euler update into η0
@@ -245,9 +250,9 @@ function time_integration(S::ModelSetup{T,Tprog}) where {T<:AbstractFloat,Tprog<
245250
ghost_points!(u0,v0,η0,S)
246251

247252
# type conversion for mixed precision
248-
u0rhs = convert(Diag.PrognosticVarsRHS.u,u0)
249-
v0rhs = convert(Diag.PrognosticVarsRHS.v,v0)
250-
η0rhs = convert(Diag.PrognosticVarsRHS.η,η0)
253+
u0rhs = Diag.PrognosticVarsRHS.u .= u0
254+
v0rhs = Diag.PrognosticVarsRHS.v .= v0
255+
η0rhs = Diag.PrognosticVarsRHS.η .= η0
251256

252257
# ADVECTION and CORIOLIS TERMS
253258
# although included in the tendency of every RK substep,
@@ -270,8 +275,9 @@ function time_integration(S::ModelSetup{T,Tprog}) where {T<:AbstractFloat,Tprog<
270275
t += dtint
271276

272277
# TRACER ADVECTION
273-
u0rhs = convert(Diag.PrognosticVarsRHS.u,u0) # copy back as add_drag_diff_tendencies changed u0,v0
274-
v0rhs = convert(Diag.PrognosticVarsRHS.v,v0)
278+
u0rhs = Diag.PrognosticVarsRHS.u .= u0
279+
v0rhs = Diag.PrognosticVarsRHS.v .= v0
280+
275281
tracer!(i,u0rhs,v0rhs,Prog,Diag,S)
276282

277283
# feedback and output
@@ -391,25 +397,4 @@ function dxaybzc!( d::Array{T,2},
391397
d[i,j] = xT*a[i,j] + yT*b[i,j] + zT*c[i,j]
392398
end
393399
end
394-
end
395-
396-
"""Convert function for two arrays, X1, X2, in case their eltypes differ.
397-
Convert every element from X1 and store it in X2."""
398-
function Base.convert(X2::Array{T2,N},X1::Array{T1,N}) where {T1,T2,N}
399-
400-
@boundscheck size(X2) == size(X1) || throw(BoundsError())
401-
402-
@inbounds for i in eachindex(X1)
403-
X2[i] = convert(T2,X1[i])
404-
end
405-
406-
return X2
407-
end
408-
409-
410-
"""Convert function for two arrays, X1, X2, in case their eltypes are identical.
411-
Just pass X1, such that X2 is pointed to the same place in memory."""
412-
function Base.convert(X2::Array{T,N},X1::Array{T,N}) where {T,N}
413-
@boundscheck size(X2) == size(X1) || throw(BoundsError())
414-
return X1
415-
end
400+
end

src/tracer_advection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function tracer!( i::Integer,
2323
@unpack ssti,sst_ref,dsst_comp = Diag.SemiLagrange
2424

2525
# convert to type T for mixed precision
26-
sstrhs = convert(Diag.PrognosticVarsRHS.sst,sst)
26+
sstrhs = Diag.PrognosticVarsRHS.sst .= sst
2727

2828
departure!(u,v,Diag,S)
2929
adv_sst!(sstrhs,Diag,S)

0 commit comments

Comments
 (0)