diff --git a/CondaPkg.toml b/CondaPkg.toml index dde757656..f7b6a5d50 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -1,7 +1,8 @@ [pip.deps] -copernicusmarine = ">=2.0.0" xarray = ">=2024.7.0" -numpy = ">=2.0.0" jax = ">=0.6" +copernicusmarine = ">=2.0.0" +numpy = "==2.3.1" tensorflow = ">=2.17" +veros = "" diff --git a/experiments/veros_forced_simulation.jl b/experiments/veros_forced_simulation.jl new file mode 100644 index 000000000..52d0ed853 --- /dev/null +++ b/experiments/veros_forced_simulation.jl @@ -0,0 +1,113 @@ +using ClimaOcean +using PythonCall +using Oceananigans +using Printf + +##### +##### A Prognostic Python Ocean (Veros) Simulation +##### + +# We import the Veros 4 degree ocean simulation setup, which consists of a near-global ocean +# with a uniform resolution of 4 degrees in both latitude and longitude and a latitude range spanning +# from 80S to 80N. The setup is defined in the `veros.setups.global_4deg` module. + +# Before importing the setup, we need to ensure that the Veros module is loaded +# and that every output is removed to avoid conflicts. + +VerosModule = Base.get_extension(ClimaOcean, :ClimaOceanPythonCallExt) +VerosModule.remove_outputs(:global_4deg) + +# Actually loading and instantiating the Veros setup in the variable `ocean`. +# This setup uses by default a different time-step for tracers and momentum, +# so we set it to the same value (1800 seconds) for both. + +ocean = VerosModule.VerosOceanSimulation("global_4deg", :GlobalFourDegreeSetup) + +set!(ocean, "dt_tracer", 1800.0; path=:settings) +set!(ocean, "dt_mom", 1800.0; path=:settings) + +##### +##### A Prescribed Atmosphere (JRA55) +##### + +atmos = JRA55PrescribedAtmosphere(; backend = JRA55NetCDFBackend(10)) + +##### +##### An ice-free ocean forced by a prescribed atmosphere +##### + +radiation = Radiation() +coupled_model = OceanSeaIceModel(ocean, nothing; atmosphere=atmos, radiation) +simulation = Simulation(coupled_model; Δt = 1800, stop_iteration = 100000) + +##### +##### A simple progress callback +##### + +# We set up a progress callback that will print the current time, iteration, and maximum velocities +# at every 5 iterations. It also collects the surface velocity fields and the net fluxes +# into the arrays `s`, `tx`, and `ty` for later visualization. + +wall_time = Ref(time_ns()) + +s = [] +tx = [] +ty = [] + +us = coupled_model.interfaces.exchanger.exchange_ocean_state.u +vs = coupled_model.interfaces.exchanger.exchange_ocean_state.v + +stmp = Field(sqrt(us^2 + vs^2)) + +function progress(sim) + ocean = sim.model.ocean + umax = maximum(PyArray(ocean.setup.state.variables.u)) + vmax = maximum(PyArray(ocean.setup.state.variables.v)) + wmax = maximum(PyArray(ocean.setup.state.variables.w)) + + step_time = 1e-9 * (time_ns() - wall_time[]) + + msg1 = @sprintf("time: %s, iteration: %d, Δt: %s, ", prettytime(sim), iteration(sim), prettytime(sim.Δt)) + msg5 = @sprintf("maximum(u): (%.2f, %.2f, %.2f) m/s, ", umax, vmax, wmax) + msg6 = @sprintf("wall time: %s \n", prettytime(step_time)) + + @info msg1 * msg5 * msg6 + + wall_time[] = time_ns() + + compute!(stmp) + push!(s, deepcopy(interior(stmp, :, :, 1))) + push!(tx, deepcopy(interior(coupled_model.interfaces.net_fluxes.ocean_surface.u, :, :, 1) .* 1020)) + push!(ty, deepcopy(interior(coupled_model.interfaces.net_fluxes.ocean_surface.v, :, :, 1) .* 1020)) + + return nothing +end + +add_callback!(simulation, progress, IterationInterval(5)) + +##### +##### Let's go!! +##### + +run!(simulation) + +##### +##### Visualize +##### + +iter = Observable(1) +si = @lift(s[$iter]) +txi = @lift(tx[$iter]) +tyi = @lift(ty[$iter]) + +fig = Figure(resolution = (1200, 300)) +ax1 = Axis(fig[1, 1]; title = "Surface speed (m/s)", xlabel = "Longitude", ylabel = "Latitude") +ax2 = Axis(fig[1, 2]; title = "Zonal wind stress (N/m²)", xlabel = "Longitude") +ax3 = Axis(fig[1, 3]; title = "Meridional wind stress (N/m²)", xlabel = "Longitude") + +grid = coupled_model.interfaces.exchanger.exchange_grid +λ = λnodes(grid, Center()) +φ = φnodes(grid, Center()) +heatmap!(ax1, λ, φ, si, colormap = :ice, colorrange = (0, 0.15)) +heatmap!(ax2, λ, φ, txi, colormap = :bwr, colorrange = (-0.2, 0.2)) +heatmap!(ax3, λ, φ, tyi, colormap = :bwr, colorrange = (-0.2, 0.2)) \ No newline at end of file diff --git a/ext/ClimaOceanPythonCallExt/ClimaOceanPythonCallExt.jl b/ext/ClimaOceanPythonCallExt/ClimaOceanPythonCallExt.jl new file mode 100644 index 000000000..d00b4b172 --- /dev/null +++ b/ext/ClimaOceanPythonCallExt/ClimaOceanPythonCallExt.jl @@ -0,0 +1,15 @@ +module ClimaOceanPythonCallExt + +using ClimaOcean +using CondaPkg +using PythonCall +using Oceananigans +using Oceananigans.DistributedComputations: @root + +using Dates: DateTime + +include("copernicus.jl") +include("veros_ocean_simulation.jl") +include("veros_state_exchanger.jl") + +end # module ClimaOceanPythonCallExt diff --git a/ext/ClimaOceanPythonCallExt.jl b/ext/ClimaOceanPythonCallExt/copernicus.jl similarity index 99% rename from ext/ClimaOceanPythonCallExt.jl rename to ext/ClimaOceanPythonCallExt/copernicus.jl index 52e71fe4e..1cafae4d1 100644 --- a/ext/ClimaOceanPythonCallExt.jl +++ b/ext/ClimaOceanPythonCallExt/copernicus.jl @@ -119,5 +119,3 @@ function depth_bounds_kw(z) maximum_depth = - z[1] return (; minimum_depth, maximum_depth) end - -end # module ClimaOceanPythonCallExt diff --git a/ext/ClimaOceanPythonCallExt/veros_ocean_simulation.jl b/ext/ClimaOceanPythonCallExt/veros_ocean_simulation.jl new file mode 100644 index 000000000..2e6b99115 --- /dev/null +++ b/ext/ClimaOceanPythonCallExt/veros_ocean_simulation.jl @@ -0,0 +1,202 @@ +using CondaPkg + +using Oceananigans.Grids: topology +using ClimaOcean.OceanSeaIceModels: reference_density, heat_capacity, SeaIceSimulation + +import Oceananigans.Fields: set! +import Oceananigans.TimeSteppers: time_step!, initialize! + +import ClimaOcean.OceanSeaIceModels: OceanSeaIceModel, default_nan_checker +import Oceananigans.Architectures: architecture + +import Base: eltype + +""" + install_veros() + +Install the Veros ocean model Marine CLI using CondaPkg. +Returns a NamedTuple containing package information if successful. +""" +function install_veros() + CondaPkg.add("veros"; channel = "conda-forge") + cli = CondaPkg.which("veros") + @info "... the veros CLI has been installed at $(cli)." + return cli +end + +struct VerosOceanSimulation{S} + setup :: S +end + +default_nan_checker(model::OceanSeaIceModel{<:Any, <:Any, <:VerosOceanSimulation}) = nothing + +initialize!(::ClimaOceanPythonCallExt.VerosOceanSimulation{Py}) = nothing +time_step!(ocean::VerosOceanSimulation, Δt) = ocean.setup.step(ocean.setup.state) +architecture(model::OceanSeaIceModel{<:Any, <:Any, <:VerosOceanSimulation}) = CPU() +eltype(model::OceanSeaIceModel{<:Any, <:Any, <:VerosOceanSimulation}) = Float64 + +function remove_outputs(setup::Symbol) + rm("$(setup).averages.nc", force=true) + rm("$(setup).energy.nc", force=true) + rm("$(setup).overturning.nc", force=true) + rm("$(setup).snapshot.nc", force=true) + return nothing +end + +const CCField2D = Field{<:Center, <:Center, <:Nothing} +const FCField2D = Field{<:Face, <:Center, <:Nothing} +const CFField2D = Field{<:Center, <:Face, <:Nothing} + +function set!(field::CCField2D, pyarray::Py, k=pyconvert(Int, pyarray.shape[2])) + array = PyArray(pyarray) + Nx, Ny, Nz = size(array) + set!(field, view(array, 3:Nx-2, 3:Ny-2, k, 1)) + return field +end + +function set!(field::FCField2D, pyarray::Py, k=pyconvert(Int, pyarray.shape[2])) + array = PyArray(pyarray) + Nx, Ny, Nz = size(array) + TX, TY, _ = topology(field.grid) + i_indices = TX == Periodic ? UnitRange(3, Nx-2) : UnitRange(2, Nx-2) + set!(field, view(array, i_indices, 3:Ny-2, k, 1)) + return field +end + +function set!(field::CFField2D, pyarray::Py, k=pyconvert(Int, pyarray.shape[2])) + array = PyArray(pyarray) + Nx, Ny, Nz = size(array) + set!(field, view(array, 3:Nx-2, 2:Ny-2, k, 1)) + return field +end + +""" + VerosOceanSimulation(setup, setup_name::Symbol) + +Creates and initializes a preconfigured Veros ocean simulation using the +specified setup module and setup name. + +Arguments +========== +- `setup::AbstractString`: The name of the Veros setup module to import (e.g., `"global_4deg"`). +- `setup_name::Symbol`: The name of the setup class or function within the module to instantiate (e.g., `:GlobalFourDegreeSetup`). +""" +function VerosOceanSimulation(setup, setup_name::Symbol) + setups = pyimport("veros.setups." * setup) + setup = @eval $setups.$setup_name() + + # instantiate the setup + setup.setup() + + return VerosOceanSimulation(setup) +end + +""" + surface_grid(ocean::VerosOceanSimulation) + +Constructs a `LatitudeLongitudeGrid` representing the surface grid of the given `VerosOceanSimulation` object. +Notes: Veros always uses a LatitudeLongitudeGrid with 2 halos in both the latitude and longitude directions. +Both latitude and longitude can be either stretched or uniform, depending on the setup, and while the meridional +direction (latitude) is always Bounded, the zonal direction (longitude) can be either Periodic or Bounded. + +Arguments +========== +- `ocean::VerosOceanSimulation`: The ocean simulation object containing the grid state variables. +""" +function surface_grid(ocean::VerosOceanSimulation) + + xf = Array(PyArray(ocean.setup.state.variables.xu)) + yf = Array(PyArray(ocean.setup.state.variables.yu)) + + xc = Array(PyArray(ocean.setup.state.variables.xt)) + yc = Array(PyArray(ocean.setup.state.variables.yt)) + + xf = xf[2:end-2] + yf = yf[2:end-2] + + xc = xc[3:end-2] + yc = yc[3:end-2] + + xf[1] = xf[2] - 2xc[1] + yf[1] = sign(yf[2]) * (yf[2] - 2yc[1]) + + TX = if xf[1] == 0 && xf[end] == 360 + Periodic + else + Bounded + end + + Nx = length(xc) + Ny = length(yc) + + return LatitudeLongitudeGrid(size=(Nx, Ny), longitude=xf, latitude=yf, topology=(TX, Bounded, Flat), halo=(2, 2)) +end + +""" + set!(ocean, v, x; path = :variable) + +Set the `v` variable in the `ocean` model to the value of `x`. +the path corresponds to the path inside the class where to locate the +variable `v` to set. It can be either `:variables` or `:settings`. +""" +function set!(ocean::VerosOceanSimulation, v, x; path = :variables) + setup = ocean.setup + if path == :variables + pyexec(""" + with setup.state.variables.unlock(): + setup.state.variables.__setattr__(y, t) + """, Main, (y=v, t=x, setup=setup)) + elseif path == :settings + pyexec(""" + with setup.state.settings.unlock(): + setup.state.settings.__setattr__(y, t) + """, Main, (y=v, t=x, setup=setup)) + else + error("path must be either :variable or :settings.") + end +end + +function OceanSeaIceModel(ocean::VerosOceanSimulation, sea_ice=nothing; + atmosphere = nothing, + radiation = Radiation(), + clock = Clock(time=0), + ocean_reference_density = 1020.0, + ocean_heat_capacity = 3998.0, + sea_ice_reference_density = reference_density(sea_ice), + sea_ice_heat_capacity = heat_capacity(sea_ice), + interfaces = nothing) + + if sea_ice isa SeaIceSimulation + if !isnothing(sea_ice.callbacks) + pop!(sea_ice.callbacks, :stop_time_exceeded, nothing) + pop!(sea_ice.callbacks, :stop_iteration_exceeded, nothing) + pop!(sea_ice.callbacks, :wall_time_limit_exceeded, nothing) + pop!(sea_ice.callbacks, :nan_checker, nothing) + end + end + + # Contains information about flux contributions: bulk formula, prescribed fluxes, etc. + if isnothing(interfaces) && !(isnothing(atmosphere) && isnothing(sea_ice)) + interfaces = ComponentInterfaces(atmosphere, ocean, sea_ice; + ocean_reference_density, + ocean_heat_capacity, + sea_ice_reference_density, + sea_ice_heat_capacity, + radiation) + end + + arch = CPU() + + ocean_sea_ice_model = OceanSeaIceModel(arch, + clock, + atmosphere, + sea_ice, + ocean, + interfaces) + + # Make sure the initial temperature of the ocean + # is not below freezing and above melting near the surface + initialization_update_state!(ocean_sea_ice_model) + + return ocean_sea_ice_model +end \ No newline at end of file diff --git a/ext/ClimaOceanPythonCallExt/veros_state_exchanger.jl b/ext/ClimaOceanPythonCallExt/veros_state_exchanger.jl new file mode 100644 index 000000000..ab1a8ac34 --- /dev/null +++ b/ext/ClimaOceanPythonCallExt/veros_state_exchanger.jl @@ -0,0 +1,114 @@ +using Oceananigans.Models: initialization_update_state! + +using ClimaOcean.OceanSeaIceModels.InterfaceComputations: ExchangeAtmosphereState, + atmosphere_exchanger, + SimilarityTheoryFluxes, + Radiation + +import ClimaOcean.OceanSeaIceModels.InterfaceComputations: + state_exchanger, + sea_ice_ocean_interface, + atmosphere_ocean_interface, + initialize!, + get_ocean_state, + ocean_surface_fluxes, + get_radiative_forcing, + fill_net_fluxes! + +mutable struct VerosStateExchanger{G, OST, AST, AEX} + exchange_grid :: G + exchange_ocean_state :: OST + exchange_atmosphere_state :: AST + atmosphere_exchanger :: AEX +end + +mutable struct ExchangeOceanState{FC, CF, CC} + u :: FC + v :: CF + T :: CC + S :: CC +end + +ExchangeOceanState(grid) = ExchangeOceanState(Field{Face, Center, Nothing}(grid), + Field{Center, Face, Nothing}(grid), + Field{Center, Center, Nothing}(grid), + Field{Center, Center, Nothing}(grid)) + +function state_exchanger(ocean::VerosOceanSimulation, atmosphere) + exchange_grid = surface_grid(ocean) + exchange_ocean_state = ExchangeOceanState(exchange_grid) + exchange_atmosphere_state = ExchangeAtmosphereState(exchange_grid) + + atmos_exchanger = atmosphere_exchanger(atmosphere, exchange_grid) + + return VerosStateExchanger(exchange_grid, + exchange_ocean_state, + exchange_atmosphere_state, + atmos_exchanger) +end + +atmosphere_ocean_interface(ocean::VerosOceanSimulation, args...) = + atmosphere_ocean_interface(surface_grid(ocean), args...) + +sea_ice_ocean_interface(ocean::VerosOceanSimulation, args...) = + sea_ice_ocean_interface(surface_grid(ocean), args...) + +initialize!(exchanger::VerosStateExchanger, atmosphere) = + initialize!(exchanger.atmosphere_exchanger, exchanger.exchange_grid, atmosphere) + +@inline function get_ocean_state(ocean::VerosOceanSimulation, exchanger::VerosStateExchanger) + u = exchanger.exchange_ocean_state.u + v = exchanger.exchange_ocean_state.v + T = exchanger.exchange_ocean_state.T + S = exchanger.exchange_ocean_state.S + + set!(u, ocean.setup.state.variables.u) + set!(v, ocean.setup.state.variables.v) + set!(T, ocean.setup.state.variables.temp) + set!(S, ocean.setup.state.variables.salt) + + return (; u, v, T, S) +end + +@inline function ocean_surface_fluxes(ocean::VerosOceanSimulation, ρₒ, cₒ) + grid = surface_grid(ocean) + u = Field{Face, Center, Nothing}(grid) + v = Field{Center, Face, Nothing}(grid) + T = Field{Center, Center, Nothing}(grid) + S = Field{Center, Center, Nothing}(grid) + Q = ρₒ * cₒ * T + + return (; u, v, T, S, Q) +end + +@inline get_radiative_forcing(ocean::VerosOceanSimulation) = nothing + +function fill_net_fluxes!(ocean::VerosOceanSimulation, net_ocean_fluxes) + nx = pyconvert(Int, ocean.setup.state.settings.nx) + 4 + ny = pyconvert(Int, ocean.setup.state.settings.ny) + 4 + + ρₒ = pyconvert(eltype(ocean), ocean.setup.state.settings.rho_0) + taux = view(parent(net_ocean_fluxes.u), 1:nx, 1:ny, 1) .* ρₒ + tauy = view(parent(net_ocean_fluxes.v), 1:nx, 1:ny, 1) .* ρₒ + + # TODO: Do not do this 12 thingy when we can make sure + # that veros supports it + tx = zeros(size(taux)..., 12) + ty = zeros(size(tauy)..., 12) + for t in 1:12 + tx[:, :, t] .= taux + ty[:, :, t] .= tauy + end + + set!(ocean, "taux", tx; path=:variables) + set!(ocean, "tauy", ty; path=:variables) + + # TODO: uncomment below when veros supports prescribed fluxes BC for tracers + # temp_flux = view(parent(net_ocean_fluxes.T), 1:nx, 1:ny, 1) + # salt_flux = view(parent(net_ocean_fluxes.S), 1:nx, 1:ny, 1) + + # set!(ocean, "temp_flux", temp_flux; path=:variables) + # set!(ocean, "salt_flux", salt_flux; path=:variables) + + return nothing +end diff --git a/src/OceanSeaIceModels/InterfaceComputations/assemble_net_fluxes.jl b/src/OceanSeaIceModels/InterfaceComputations/assemble_net_fluxes.jl index 1532a3be7..1fefa089f 100644 --- a/src/OceanSeaIceModels/InterfaceComputations/assemble_net_fluxes.jl +++ b/src/OceanSeaIceModels/InterfaceComputations/assemble_net_fluxes.jl @@ -21,18 +21,22 @@ using ClimaOcean.OceanSeaIceModels: sea_ice_concentration return zero(Iˢʷ) end -get_radiative_forcing(FT) = FT -function get_radiative_forcing(FT::MultipleForcings) +@inline get_radiative_forcing(ocean::OceananigansSimulation) = get_radiative_forcing(ocean.model.forcing.T) +@inline get_radiative_forcing(FT) = FT + +@inline function get_radiative_forcing(FT::MultipleForcings) for forcing in FT.forcings forcing isa TwoColorRadiation && return forcing end return nothing end +# No need to do this for an Oceananigans Simulation +fill_net_fluxes!(ocean, net_ocean_fluxes) = nothing + function compute_net_ocean_fluxes!(coupled_model) - ocean = coupled_model.ocean sea_ice = coupled_model.sea_ice - grid = ocean.model.grid + grid = coupled_model.interfaces.exchanger.exchange_grid arch = architecture(grid) clock = coupled_model.clock @@ -56,13 +60,14 @@ function compute_net_ocean_fluxes!(coupled_model) freshwater_flux = atmosphere_fields.Mp.data ice_concentration = sea_ice_concentration(sea_ice) - ocean_salinity = ocean.model.tracers.S + ocean_state = get_ocean_state(coupled_model.ocean, coupled_model.interfaces.exchanger) + ocean_salinity = ocean_state.S atmos_ocean_properties = coupled_model.interfaces.atmosphere_ocean_interface.properties ocean_properties = coupled_model.interfaces.ocean_properties kernel_parameters = interface_kernel_parameters(grid) ocean_surface_temperature = coupled_model.interfaces.atmosphere_ocean_interface.temperature - penetrating_radiation = get_radiative_forcing(ocean.model.forcing.T) + penetrating_radiation = get_radiative_forcing(coupled_model.ocean) launch!(arch, grid, kernel_parameters, _assemble_net_ocean_fluxes!, @@ -80,6 +85,8 @@ function compute_net_ocean_fluxes!(coupled_model) atmos_ocean_properties, ocean_properties) + fill_net_fluxes!(coupled_model.ocean, net_ocean_fluxes) + return nothing end @@ -261,8 +268,8 @@ end @inbounds begin Ts = surface_temperature[i, j, kᴺ] Ts = convert_to_kelvin(sea_ice_properties.temperature_units, Ts) - ℵi = ice_concentration[i, j, 1] - + ℵi = ice_concentration[i, j, kᴺ] + Qs = downwelling_radiation.Qs[i, j, 1] Qℓ = downwelling_radiation.Qℓ[i, j, 1] Qc = atmosphere_sea_ice_fluxes.sensible_heat[i, j, 1] # sensible or "conductive" heat flux diff --git a/src/OceanSeaIceModels/InterfaceComputations/atmosphere_ocean_fluxes.jl b/src/OceanSeaIceModels/InterfaceComputations/atmosphere_ocean_fluxes.jl index 319ba5366..041e81a43 100644 --- a/src/OceanSeaIceModels/InterfaceComputations/atmosphere_ocean_fluxes.jl +++ b/src/OceanSeaIceModels/InterfaceComputations/atmosphere_ocean_fluxes.jl @@ -4,19 +4,21 @@ using ClimaOcean.OceanSeaIceModels.PrescribedAtmospheres: thermodynamics_paramet surface_layer_height, boundary_layer_height -function compute_atmosphere_ocean_fluxes!(coupled_model) - ocean = coupled_model.ocean - atmosphere = coupled_model.atmosphere - grid = ocean.model.grid - arch = architecture(grid) - clock = coupled_model.clock - - ocean_state = (u = ocean.model.velocities.u, +@inline get_ocean_state(ocean::OceananigansSimulation, coupled_model) = + (u = ocean.model.velocities.u, v = ocean.model.velocities.v, T = ocean.model.tracers.T, S = ocean.model.tracers.S) - atmosphere_fields = coupled_model.interfaces.exchanger.exchange_atmosphere_state +function compute_atmosphere_ocean_fluxes!(coupled_model) + ocean = coupled_model.ocean + atmosphere = coupled_model.atmosphere + exchanger = coupled_model.interfaces.exchanger + grid = exchanger.exchange_grid + arch = architecture(grid) + clock = coupled_model.clock + ocean_state = get_ocean_state(ocean, exchanger) + atmosphere_fields = exchanger.exchange_atmosphere_state # Simplify NamedTuple to reduce parameter space consumption. # See https://github.com/CliMA/ClimaOcean.jl/issues/116. diff --git a/src/OceanSeaIceModels/InterfaceComputations/component_interfaces.jl b/src/OceanSeaIceModels/InterfaceComputations/component_interfaces.jl index 52ba53207..129a9cc8e 100644 --- a/src/OceanSeaIceModels/InterfaceComputations/component_interfaces.jl +++ b/src/OceanSeaIceModels/InterfaceComputations/component_interfaces.jl @@ -8,7 +8,8 @@ using ..OceanSeaIceModels: reference_density, sea_ice_thickness, downwelling_radiation, freshwater_flux, - SeaIceSimulation + SeaIceSimulation, + OceananigansSimulation using ..OceanSeaIceModels.PrescribedAtmospheres: PrescribedAtmosphere, @@ -17,7 +18,7 @@ using ..OceanSeaIceModels.PrescribedAtmospheres: using ClimaSeaIce: SeaIceModel using Oceananigans: HydrostaticFreeSurfaceModel, architecture -using Oceananigans.Grids: inactive_node, node, topology +using Oceananigans.Grids: inactive_node, node, topology, AbstractGrid using Oceananigans.BoundaryConditions: fill_halo_regions! using Oceananigans.Fields: ConstantField, interpolate, FractionalIndices using Oceananigans.Utils: launch!, Time, KernelParameters @@ -88,10 +89,9 @@ ExchangeAtmosphereState(grid) = ExchangeAtmosphereState(Field{Center, Center, No fractional_index_type(FT, Topo) = FT fractional_index_type(FT, ::Flat) = Nothing +state_exchanger(ocean::Simulation, ::Nothing) = nothing -StateExchanger(ocean::Simulation, ::Nothing) = nothing - -function StateExchanger(ocean::Simulation, atmosphere) +function state_exchanger(ocean::Simulation, atmosphere) # TODO: generalize this exchange_grid = ocean.model.grid exchange_atmosphere_state = ExchangeAtmosphereState(exchange_grid) @@ -118,12 +118,11 @@ function atmosphere_exchanger(atmosphere::PrescribedAtmosphere, exchange_grid) end initialize!(exchanger::StateExchanger, ::Nothing) = nothing +initialize!(exchanger::StateExchanger, atmosphere) = initialize!(exchanger.atmosphere_exchanger, exchanger.exchange_grid, atmosphere) -function initialize!(exchanger::StateExchanger, atmosphere) +function initialize!(frac_indices, exchange_grid::AbstractGrid, atmosphere) atmos_grid = atmosphere.grid - exchange_grid = exchanger.exchange_grid arch = architecture(exchange_grid) - frac_indices = exchanger.atmosphere_exchanger kernel_parameters = interface_kernel_parameters(exchange_grid) launch!(arch, exchange_grid, kernel_parameters, _compute_fractional_indices!, frac_indices, exchange_grid, atmos_grid) @@ -167,26 +166,28 @@ Base.summary(crf::ComponentInterfaces) = "ComponentInterfaces" Base.show(io::IO, crf::ComponentInterfaces) = print(io, summary(crf)) atmosphere_ocean_interface(::Nothing, args...) = nothing +atmosphere_ocean_interface(ocean::OceananigansSimulation, args...) = + atmosphere_ocean_interface(ocean.model.grid, args...) -function atmosphere_ocean_interface(atmos, - ocean, +function atmosphere_ocean_interface(grid::AbstractGrid, + atmos, radiation, ao_flux_formulation, temperature_formulation, velocity_formulation, specific_humidity_formulation) - water_vapor = Field{Center, Center, Nothing}(ocean.model.grid) - latent_heat = Field{Center, Center, Nothing}(ocean.model.grid) - sensible_heat = Field{Center, Center, Nothing}(ocean.model.grid) - x_momentum = Field{Center, Center, Nothing}(ocean.model.grid) - y_momentum = Field{Center, Center, Nothing}(ocean.model.grid) - friction_velocity = Field{Center, Center, Nothing}(ocean.model.grid) - temperature_scale = Field{Center, Center, Nothing}(ocean.model.grid) - water_vapor_scale = Field{Center, Center, Nothing}(ocean.model.grid) - upwelling_longwave = Field{Center, Center, Nothing}(ocean.model.grid) - downwelling_longwave = Field{Center, Center, Nothing}(ocean.model.grid) - downwelling_shortwave = Field{Center, Center, Nothing}(ocean.model.grid) + water_vapor = Field{Center, Center, Nothing}(grid) + latent_heat = Field{Center, Center, Nothing}(grid) + sensible_heat = Field{Center, Center, Nothing}(grid) + x_momentum = Field{Center, Center, Nothing}(grid) + y_momentum = Field{Center, Center, Nothing}(grid) + friction_velocity = Field{Center, Center, Nothing}(grid) + temperature_scale = Field{Center, Center, Nothing}(grid) + water_vapor_scale = Field{Center, Center, Nothing}(grid) + upwelling_longwave = Field{Center, Center, Nothing}(grid) + downwelling_longwave = Field{Center, Center, Nothing}(grid) + downwelling_shortwave = Field{Center, Center, Nothing}(grid) ao_fluxes = (; latent_heat, sensible_heat, @@ -210,7 +211,7 @@ function atmosphere_ocean_interface(atmos, temperature_formulation, velocity_formulation) - interface_temperature = Field{Center, Center, Nothing}(ocean.model.grid) + interface_temperature = Field{Center, Center, Nothing}(grid) return AtmosphereInterface(ao_fluxes, ao_flux_formulation, interface_temperature, ao_properties) end @@ -252,16 +253,18 @@ function atmosphere_sea_ice_interface(atmos, return AtmosphereInterface(fluxes, ai_flux_formulation, interface_temperature, properties) end -sea_ice_ocean_interface(sea_ice, ocean) = nothing +sea_ice_ocean_interface(ocean, sea_ice) = nothing +sea_ice_ocean_interface(ocean, sea_ice::SeaIceSimulation; kw...) = + sea_ice_ocean_interface(ocean.grid, sea_ice; kw...) -function sea_ice_ocean_interface(sea_ice::SeaIceSimulation, ocean; +function sea_ice_ocean_interface(grid::AbstractGrid, sea_ice::SeaIceSimulation; characteristic_melting_speed = 1e-5) - io_bottom_heat_flux = Field{Center, Center, Nothing}(ocean.model.grid) - io_frazil_heat_flux = Field{Center, Center, Nothing}(ocean.model.grid) - io_salt_flux = Field{Center, Center, Nothing}(ocean.model.grid) - x_momentum = Field{Face, Center, Nothing}(ocean.model.grid) - y_momentum = Field{Center, Face, Nothing}(ocean.model.grid) + io_bottom_heat_flux = Field{Center, Center, Nothing}(grid) + io_frazil_heat_flux = Field{Center, Center, Nothing}(grid) + io_salt_flux = Field{Center, Center, Nothing}(grid) + x_momentum = Field{Face, Center, Nothing}(grid) + y_momentum = Field{Center, Face, Nothing}(grid) @assert io_frazil_heat_flux isa Field{Center, Center, Nothing} @assert io_bottom_heat_flux isa Field{Center, Center, Nothing} @@ -286,12 +289,23 @@ function default_ai_temperature(sea_ice::SeaIceSimulation) end function default_ao_specific_humidity(ocean) - FT = eltype(ocean.model.grid) + FT = eltype(ocean) phase = AtmosphericThermodynamics.Liquid() x_H₂O = convert(FT, 0.98) return ImpureSaturationSpecificHumidity(phase, x_H₂O) end +function ocean_surface_fluxes(ocean::OceananigansSimulation, ρₒ, cₒ) + τx = surface_flux(ocean.model.velocities.u) + τy = surface_flux(ocean.model.velocities.v) + tracers = ocean.model.tracers + Qₒ = ρₒ * cₒ * surface_flux(ocean.model.tracers.T) + net_ocean_surface_fluxes = (u=τx, v=τy, Q=Qₒ) + + ocean_surface_tracer_fluxes = NamedTuple(name => surface_flux(tracers[name]) for name in keys(tracers)) + return merge(ocean_surface_tracer_fluxes, net_ocean_surface_fluxes) +end + """ ComponentInterfaces(atmosphere, ocean, sea_ice=nothing; radiation = Radiation(), @@ -312,8 +326,8 @@ end function ComponentInterfaces(atmosphere, ocean, sea_ice=nothing; radiation = Radiation(), freshwater_density = 1000, - atmosphere_ocean_fluxes = SimilarityTheoryFluxes(eltype(ocean.model.grid)), - atmosphere_sea_ice_fluxes = SimilarityTheoryFluxes(eltype(ocean.model.grid)), + atmosphere_ocean_fluxes = SimilarityTheoryFluxes(), + atmosphere_sea_ice_fluxes = SimilarityTheoryFluxes(), atmosphere_ocean_interface_temperature = BulkTemperature(), atmosphere_ocean_velocity_difference = RelativeVelocity(), atmosphere_ocean_interface_specific_humidity = default_ao_specific_humidity(ocean), @@ -327,8 +341,7 @@ function ComponentInterfaces(atmosphere, ocean, sea_ice=nothing; sea_ice_heat_capacity = heat_capacity(sea_ice), gravitational_acceleration = g_Earth) - ocean_grid = ocean.model.grid - FT = eltype(ocean_grid) + FT = eltype(ocean) ocean_reference_density = convert(FT, ocean_reference_density) ocean_heat_capacity = convert(FT, ocean_heat_capacity) @@ -344,8 +357,8 @@ function ComponentInterfaces(atmosphere, ocean, sea_ice=nothing; freshwater_density = freshwater_density, temperature_units = ocean_temperature_units) - ao_interface = atmosphere_ocean_interface(atmosphere, - ocean, + ao_interface = atmosphere_ocean_interface(ocean, + atmosphere, radiation, atmosphere_ocean_fluxes, atmosphere_ocean_interface_temperature, @@ -386,23 +399,14 @@ function ComponentInterfaces(atmosphere, ocean, sea_ice=nothing; net_bottom_sea_ice_fluxes = nothing end - τx = surface_flux(ocean.model.velocities.u) - τy = surface_flux(ocean.model.velocities.v) - tracers = ocean.model.tracers - ρₒ = ocean_reference_density - cₒ = ocean_heat_capacity - Qₒ = ρₒ * cₒ * surface_flux(ocean.model.tracers.T) - net_ocean_surface_fluxes = (u=τx, v=τy, Q=Qₒ) - - ocean_surface_tracer_fluxes = NamedTuple(name => surface_flux(tracers[name]) for name in keys(tracers)) - net_ocean_surface_fluxes = merge(ocean_surface_tracer_fluxes, net_ocean_surface_fluxes) + net_ocean_surface_fluxes = ocean_surface_fluxes(ocean, ocean_reference_density, ocean_heat_capacity) # Total interface fluxes net_fluxes = (ocean_surface = net_ocean_surface_fluxes, sea_ice_top = net_top_sea_ice_fluxes, sea_ice_bottom = net_bottom_sea_ice_fluxes) - exchanger = StateExchanger(ocean, atmosphere) + exchanger = state_exchanger(ocean, atmosphere) properties = (; gravitational_acceleration) diff --git a/src/OceanSeaIceModels/InterfaceComputations/interpolate_atmospheric_state.jl b/src/OceanSeaIceModels/InterfaceComputations/interpolate_atmospheric_state.jl index 3c8ffa0b9..afd63817b 100644 --- a/src/OceanSeaIceModels/InterfaceComputations/interpolate_atmospheric_state.jl +++ b/src/OceanSeaIceModels/InterfaceComputations/interpolate_atmospheric_state.jl @@ -26,7 +26,7 @@ function interpolate_atmosphere_state!(interfaces, atmosphere::PrescribedAtmosph atmosphere_grid = atmosphere.grid # Basic model properties - grid = ocean.model.grid + grid = coupled_model.interfaces.exchanger.exchange_grid arch = architecture(grid) clock = coupled_model.clock @@ -125,8 +125,8 @@ function interpolate_atmosphere_state!(interfaces, atmosphere::PrescribedAtmosph # # TODO: find a better design for this that doesn't have redundant # arrays for the barotropic potential - u_potential = forcing_barotropic_potential(ocean.model.forcing.u) - v_potential = forcing_barotropic_potential(ocean.model.forcing.v) + u_potential = forcing_barotropic_potential(ocean) + v_potential = forcing_barotropic_potential(ocean) ρₒ = coupled_model.interfaces.ocean_properties.reference_density if !isnothing(u_potential) diff --git a/src/OceanSeaIceModels/InterfaceComputations/sea_ice_ocean_fluxes.jl b/src/OceanSeaIceModels/InterfaceComputations/sea_ice_ocean_fluxes.jl index 4d37e1d86..95ecb7c33 100644 --- a/src/OceanSeaIceModels/InterfaceComputations/sea_ice_ocean_fluxes.jl +++ b/src/OceanSeaIceModels/InterfaceComputations/sea_ice_ocean_fluxes.jl @@ -23,10 +23,13 @@ function compute_sea_ice_ocean_fluxes!(sea_ice_ocean_fluxes, ocean, sea_ice, mel ℵᵢ = sea_ice.model.ice_concentration hᵢ = sea_ice.model.ice_thickness Gh = sea_ice.model.ice_thermodynamics.thermodynamic_tendency + Δt = sea_ice.Δt + + ocean_state = get_ocean_state(ocean, coupled_model) liquidus = sea_ice.model.ice_thermodynamics.phase_transitions.liquidus - grid = ocean.model.grid - clock = ocean.model.clock + grid = sea_ice.model.grid + clock = sea_ice.model.clock arch = architecture(grid) uᵢ, vᵢ = sea_ice.model.velocities @@ -41,7 +44,7 @@ function compute_sea_ice_ocean_fluxes!(sea_ice_ocean_fluxes, ocean, sea_ice, mel # What about the latent heat removed from the ocean when ice forms? # Is it immediately removed from the ocean? Or is it stored in the ice? launch!(arch, grid, :xy, _compute_sea_ice_ocean_fluxes!, - sea_ice_ocean_fluxes, grid, clock, hᵢ, ℵᵢ, Sᵢ, Gh, Tₒ, Sₒ, uᵢ, vᵢ, + sea_ice_ocean_fluxes, grid, clock, hᵢ, ℵᵢ, Sᵢ, Gh, ocean_state, uᵢ, vᵢ, τs, liquidus, ocean_properties, melting_speed, Δt) return nothing @@ -54,8 +57,7 @@ end ice_concentration, ice_salinity, thermodynamic_tendency, - ocean_temperature, - ocean_salinity, + ocean_state, sea_ice_u_velocity, sea_ice_v_velocity, sea_ice_ocean_stresses, @@ -74,8 +76,8 @@ end τy = sea_ice_ocean_fluxes.y_momentum uᵢ = sea_ice_u_velocity vᵢ = sea_ice_v_velocity - Tₒ = ocean_temperature - Sₒ = ocean_salinity + Tₒ = ocean_state.T + Sₒ = ocean_state.S Sᵢ = ice_salinity hᵢ = ice_thickness ℵᵢ = ice_concentration diff --git a/src/OceanSeaIceModels/OceanSeaIceModels.jl b/src/OceanSeaIceModels/OceanSeaIceModels.jl index d44f7429d..78f139a54 100644 --- a/src/OceanSeaIceModels/OceanSeaIceModels.jl +++ b/src/OceanSeaIceModels/OceanSeaIceModels.jl @@ -41,6 +41,9 @@ const default_gravitational_acceleration = 9.80665 const default_freshwater_density = 1000 const SeaIceSimulation = Simulation{<:SeaIceModel} +const OceananigansSimulation = Simulation{<:HydrostaticFreeSurfaceModel} + +Base.eltype(ocean::OceananigansSimulation) = eltype(ocean.model.grid) sea_ice_thickness(::Nothing) = ZeroField() sea_ice_thickness(sea_ice::SeaIceSimulation) = sea_ice.model.ice_thickness diff --git a/src/OceanSeaIceModels/PrescribedAtmospheres.jl b/src/OceanSeaIceModels/PrescribedAtmospheres.jl index e7a3ca18c..15764b14c 100644 --- a/src/OceanSeaIceModels/PrescribedAtmospheres.jl +++ b/src/OceanSeaIceModels/PrescribedAtmospheres.jl @@ -5,6 +5,7 @@ using Oceananigans.Fields: Center using Oceananigans.Grids: grid_name using Oceananigans.OutputReaders: FieldTimeSeries, update_field_time_series!, extract_field_time_series using Oceananigans.TimeSteppers: Clock, tick! +using Oceananigans.Simulations: TimeStepWizard using Oceananigans.Utils: prettysummary, Time using Adapt @@ -371,6 +372,8 @@ end return nothing end +(wizard::TimeStepWizard)(atmos::PrescribedAtmosphere) = Inf + @inline thermodynamics_parameters(atmos::Nothing) = nothing @inline thermodynamics_parameters(atmos::PrescribedAtmosphere) = atmos.thermodynamics_parameters @inline surface_layer_height(atmos::PrescribedAtmosphere) = atmos.surface_layer_height diff --git a/src/OceanSeaIceModels/ocean_sea_ice_model.jl b/src/OceanSeaIceModels/ocean_sea_ice_model.jl index 9a4cd0f7a..6a64aa10e 100644 --- a/src/OceanSeaIceModels/ocean_sea_ice_model.jl +++ b/src/OceanSeaIceModels/ocean_sea_ice_model.jl @@ -19,6 +19,8 @@ import Oceananigans.TimeSteppers: time_step!, update_state!, time import Oceananigans.Utils: prettytime import Oceananigans.Models: timestepper, NaNChecker, default_nan_checker, initialization_update_state! +import Base + mutable struct OceanSeaIceModel{I, A, O, F, C, Arch} <: AbstractModel{Nothing, Arch} architecture :: Arch clock :: C @@ -45,7 +47,7 @@ function Base.show(io::IO, cm::OSIM) end print(io, summary(cm), "\n") - print(io, "├── ocean: ", summary(cm.ocean.model), "\n") + print(io, "├── ocean: ", summary(cm.ocean), "\n") print(io, "├── atmosphere: ", summary(cm.atmosphere), "\n") print(io, "├── sea_ice: ", sea_ice_summary, "\n") print(io, "└── interfaces: ", summary(cm.interfaces))