Skip to content

Commit 22ff45c

Browse files
authored
Merge pull request #155 from milankl/scale
scale also sst
2 parents 03faffb + 922cbb4 commit 22ff45c

File tree

9 files changed

+27
-33
lines changed

9 files changed

+27
-33
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ShallowWaters.jl - A type-flexible 16-bit shallow water model
22
[![CI](https://github.com/milankl/ShallowWaters.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/milankl/ShallowWaters.jl/actions/workflows/CI.yml)
33
[![DOI](https://zenodo.org/badge/132787050.svg)](https://zenodo.org/badge/latestdoi/132787050)
4-
![sst](figs/sst_posit16.png?raw=true "SST")
4+
![sst](figs/isambard_float16.png?raw=true "Float16 simulation with ShallowWaters.jl on Isambard's A64FX")
55

66
A shallow water model with a focus on type-flexibility and 16-bit number formats. ShallowWaters allows for Float64/32/16,
77
[Posit32/16/8](https://github.com/milankl/SoftPosit.jl), [BFloat16](https://github.com/JuliaComputing/BFloat16s.jl),

figs/isambard_float16.png

779 KB
Loading

src/default_parameters.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
Tprog=T # number format for prognostic variables
66
Tcomm=Tprog # number format for ghost-point copies
7+
Tini=Tprog # number format to reduce precision for initial conditions
78

89
# DOMAIN RESOLUTION AND RATIO
910
nx::Int=100 # number of grid cells in x-direction
@@ -157,6 +158,7 @@ Creates a Parameter struct with following options and default values
157158
158159
Tprog=T # number format for prognostic variables
159160
Tcomm=Tprog # number format for ghost-point copies
161+
Tini=Tprog # number format to reduce precision for initial conditions
160162
161163
# DOMAIN RESOLUTION AND RATIO
162164
nx::Int=100 # number of grid cells in x-direction

src/feedback.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,4 @@ function progress!(feedback::Feedback)
173173
flush(progress_txt)
174174
end
175175
end
176-
end
176+
end

src/ghost_points.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ function add_halo( u::Array{T,2},
2525
@unpack scale = S.constants
2626
u *= scale
2727
v *= scale
28+
sst *= scale
2829

2930
ghost_points!(u,v,η,S)
3031
ghost_points_sst!(sst,S)
@@ -41,10 +42,11 @@ function remove_halo( u::Array{T,2},
4142
@unpack halo,haloη,halosstx,halossty = S.grid
4243
@unpack scale_inv = S.constants
4344

45+
# undo scaling as well
4446
ucut = scale_inv*u[halo+1:end-halo,halo+1:end-halo]
4547
vcut = scale_inv*v[halo+1:end-halo,halo+1:end-halo]
4648
ηcut = η[haloη+1:end-haloη,haloη+1:end-haloη]
47-
sstcut = sst[halosstx+1:end-halosstx,halossty+1:end-halossty]
49+
sstcut = scale_inv*sst[halosstx+1:end-halosstx,halossty+1:end-halossty]
4850

4951
return ucut,vcut,ηcut,sstcut
5052
end

src/gradients.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
"""Calculates the 2nd order centred gradient in x-direction on any grid (u,v,T or q).
22
The size of dudx must be m-1,n compared to m,n = size(u)"""
3-
function ∂x!(∂ₓu::Array{T,2},u::Array{T,2}) where {T<:AbstractFloat}
4-
m,n = size(∂ₓu)
3+
function ∂x!(dudx::Matrix{T},u::Matrix{T}) where {T<:AbstractFloat}
4+
m,n = size(dudx)
55
@boundscheck (m+1,n) == size(u) || throw(BoundsError())
66

7-
@inbounds for j 1:n
8-
for i 1:m
9-
∂ₓu[i,j] = u[i+1,j] - u[i,j]
10-
end
7+
@inbounds for j 1:n, i 1:m
8+
dudx[i,j] = u[i+1,j] - u[i,j]
119
end
1210
end
1311

@@ -17,16 +15,14 @@ function ∂y!(dudy::Array{T,2},u::Array{T,2}) where {T<:AbstractFloat}
1715
m,n = size(dudy)
1816
@boundscheck (m,n+1) == size(u) || throw(BoundsError())
1917

20-
@inbounds for j 1:n
21-
for i 1:m
18+
@inbounds for j 1:n, i 1:m
2219
dudy[i,j] = u[i,j+1] - u[i,j]
23-
end
2420
end
2521
end
2622

2723
""" ∇² is the 2nd order centred Laplace-operator ∂/∂x^2 + ∂/∂y^2.
2824
The 1/Δ²-factor is omitted and moved into the viscosity coefficient."""
29-
function ∇²!(du::Array{T,2},u::Array{T,2}) where {T<:AbstractFloat}
25+
function ∇²!(du::Matrix{T},u::Matrix{T}) where {T<:AbstractFloat}
3026
m, n = size(du)
3127
@boundscheck (m+2,n+2) == size(u) || throw(BoundsError())
3228

@@ -91,4 +87,4 @@ function ∇²(u::Array{T,2},Δ::Real=1) where {T<:AbstractFloat}
9187
end
9288
end
9389
return du
94-
end
90+
end

src/initial_conditions.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ end
2626
function initial_conditions(::Type{T},S::ModelSetup) where {T<:AbstractFloat}
2727

2828
## PROGNOSTIC VARIABLES U,V,η
29-
3029
@unpack nux,nuy,nvx,nvy,nx,ny = S.grid
3130
@unpack initial_cond = S.parameters
31+
@unpack Tini = S.parameters
3232

3333
if initial_cond == "rest"
3434

@@ -53,15 +53,12 @@ function initial_conditions(::Type{T},S::ModelSetup) where {T<:AbstractFloat}
5353
end
5454

5555
u = ncu.vars["u"][:,:,init_starti]
56-
# NetCDF.close(ncu)
5756

5857
ncv = NetCDF.open(joinpath(inirunpath,"v.nc"))
5958
v = ncv.vars["v"][:,:,init_starti]
60-
# NetCDF.close(ncv)
6159

6260
ncη = NetCDF.open(joinpath(inirunpath,"eta.nc"))
6361
η = ncη.vars["eta"][:,:,init_starti]
64-
# NetCDF.close(ncη)
6562

6663
# remove singleton time dimension
6764
u = reshape(u,size(u)[1:2])
@@ -122,7 +119,7 @@ function initial_conditions(::Type{T},S::ModelSetup) where {T<:AbstractFloat}
122119
## SST
123120

124121
@unpack SSTmin, SSTmax, SSTw, SSTϕ = S.parameters
125-
@unpack sst_initial = S.parameters
122+
@unpack sst_initial,scale = S.parameters
126123
@unpack x_T,y_T,Lx,Ly = S.grid
127124

128125
xx_T,yy_T = meshgrid(x_T,y_T)
@@ -145,19 +142,18 @@ function initial_conditions(::Type{T},S::ModelSetup) where {T<:AbstractFloat}
145142
if initial_cond == "ncfile" && sst_initial == "restart"
146143
ncsst = NetCDF.open(joinpath(inirunpath,"sst.nc"))
147144
sst = ncsst.vars["sst"][:,:,init_starti]
148-
# NetCDF.close(ncsst)
149145

150146
sst = reshape(sst,size(sst)[1:2])
151147
end
152148

153149
# Convert to number format T
154-
sst = T.(sst)
155-
u = T.(u)
156-
v = T.(v)
157-
η = T.(η)
150+
# allow for comparable initial conditions via Tini
151+
sst = T.(Tini.(sst))
152+
u = T.(Tini.(u))
153+
v = T.(Tini.(v))
154+
η = T.(Tini.(η))
158155

159156
#TODO SST INTERPOLATION
160-
161157
u,v,η,sst = add_halo(u,v,η,sst,S)
162158

163159
return PrognosticVars{T}(u,v,η,sst)

src/interpolations.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Linear interpolation of a variable u in the x-direction.
22
m,n = size(ux) must be m+1,n = size(u)."""
3-
function Ix!(ux::Array{T,2},u::Array{T,2}) where {T<:AbstractFloat}
3+
function Ix!(ux::Matrix{T},u::Matrix{T}) where {T<:AbstractFloat}
44
m, n = size(ux)
55
@boundscheck (m+1,n) == size(u) || throw(BoundsError())
66

@@ -15,7 +15,7 @@ end
1515

1616
""" Linear interpolation a variable u in the y-direction.
1717
m,n = size(uy) must be m,n+1 = size(u)."""
18-
function Iy!(uy::Array{T,2},u::Array{T,2}) where {T<:AbstractFloat}
18+
function Iy!(uy::Matrix{T},u::Matrix{T}) where {T<:AbstractFloat}
1919
m,n = size(uy)
2020
@boundscheck (m,n+1) == size(u) || throw(BoundsError())
2121

@@ -36,11 +36,9 @@ function Ixy!(uxy::Array{T,2},u::Array{T,2}) where {T<:AbstractFloat}
3636

3737
one_quarter = convert(T,0.25)
3838

39-
@inbounds for j 1:n
40-
for i 1:m
41-
uxy[i,j] = one_quarter*(u[i,j] + u[i+1,j]) +
42-
one_quarter*(u[i,j+1] + u[i+1,j+1])
43-
end
39+
@inbounds for j in 1:n, i in 1:m
40+
uxy[i,j] = one_quarter*(u[i,j] + u[i+1,j]) +
41+
one_quarter*(u[i,j+1] + u[i+1,j+1])
4442
end
4543
end
4644

src/output.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ function output_nc!(i::Int,
105105
NetCDF.putvar(ncs.η,"eta",η,start=[1,1,iout],count=[-1,-1,1])
106106
end
107107
if ncs.sst != nothing
108-
@views sst = Float32.(Prog.sst[halosstx+1:end-halosstx,halossty+1:end-halossty])
108+
@views sst = Float32.(scale_inv*Prog.sst[halosstx+1:end-halosstx,halossty+1:end-halossty])
109109
NetCDF.putvar(ncs.sst,"sst",sst,start=[1,1,iout],count=[-1,-1,1])
110110
end
111111
if ncs.q != nothing

0 commit comments

Comments
 (0)