Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from gt4py.cartesian.gtscript import PARALLEL, computation, exp, float64, interval, log, sqrt

import pyMoist.constants as constants
from ndsl import QuantityFactory, StencilFactory, orchestrate
from ndsl import NDSLRuntime, QuantityFactory, StencilFactory
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, Int
from pyMoist.field_types import FloatField_NModes
Expand All @@ -21,8 +21,8 @@
DENSIC = float64(917.0) # Ice crystal density in kgm-3

# Default precision
NN_MIN = 100.0e6
NN_MAX = 1000.0e6
NN_MIN = Float(100.0e6)
NN_MAX = Float(1000.0e6)

# ACTFRAC_Mat constants - all 64 bit
PI = float64(3.141592653589793e00)
Expand Down Expand Up @@ -282,7 +282,7 @@ def aer_activation_stencil(
nacti = NN_MAX


class AerActivation:
class AerActivation(NDSLRuntime):
"""
Class for aerosol activation computation.

Expand Down Expand Up @@ -313,48 +313,28 @@ def __init__(
NotImplementedError: If the number of modes is not equal to the expected number.
NotImplementedError: If the neural network for aerosol is not implemented.
"""
orchestrate(obj=self, config=stencil_factory.config.dace_config)
super().__init__(dace_config=stencil_factory.config.dace_config)

if constants.N_MODES != n_modes:
raise NotImplementedError(
f"Coding limitation: {constants.N_MODES} modes are expected, " f"getting {n_modes}"
f"Coding limitation: {constants.N_MODES} modes are expected, getting {n_modes}"
)

if not USE_AERSOL_NN:
raise NotImplementedError("Non NN Aerosol not implemented")

# Temporary buffers
# Locals
quantity_factory.add_data_dimensions(
**{
{
"n_modes": constants.N_MODES,
}
)

self._nact = quantity_factory.zeros(
[X_DIM, Y_DIM, Z_DIM, "n_modes"],
units="n/a",
dtype=Float,
)
self._ni = quantity_factory.zeros(
[X_DIM, Y_DIM, Z_DIM, "n_modes"],
units="n/a",
dtype=Float,
)
self._rg = quantity_factory.zeros(
[X_DIM, Y_DIM, Z_DIM, "n_modes"],
units="n/a",
dtype=Float,
)
self._sig0 = quantity_factory.zeros(
[X_DIM, Y_DIM, Z_DIM, "n_modes"],
units="n/a",
dtype=Float,
)
self._bibar = quantity_factory.zeros(
[X_DIM, Y_DIM, Z_DIM, "n_modes"],
units="n/a",
dtype=Float,
)
self._nact = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM, "n_modes"])
self._ni = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM, "n_modes"])
self._rg = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM, "n_modes"])
self._sig0 = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM, "n_modes"])
self._bibar = self.make_local(quantity_factory, [X_DIM, Y_DIM, Z_DIM, "n_modes"])

# Stencil
self.aer_activation = stencil_factory.from_dims_halo(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import time

from f90nml import Namelist

from ndsl import Quantity, StencilFactory
Expand All @@ -14,8 +16,6 @@ def __init__(self, grid, namelist: Namelist, stencil_factory: StencilFactory):
self.quantity_factory = grid.quantity_factory
self._grid = grid

self.nmodes_quantity_factory = AerActivation.make_nmodes_quantity_factory(self.quantity_factory)

# FloatField Inputs
self.in_vars["data_vars"] = {
"AERO_F_DUST": {},
Expand Down Expand Up @@ -77,7 +77,7 @@ def make_ij_field(self, data) -> Quantity:
return qty

def make_nmodes_ijk_field(self, data) -> Quantity:
qty = self.nmodes_quantity_factory.empty(
qty = self.quantity_factory.empty(
[X_DIM, Y_DIM, Z_DIM, "n_modes"],
"n/a",
)
Expand Down Expand Up @@ -116,7 +116,6 @@ def compute(self, inputs):
qlcn = self.make_ijk_field(inputs["QLCN"])
qicn = self.make_ijk_field(inputs["QICN"])

n_modes = inputs["n_modes"]
ccn_lnd = Float(inputs["CCN_LND"])
ccn_ocn = Float(inputs["CCN_OCN"])

Expand Down Expand Up @@ -155,8 +154,40 @@ def compute(self, inputs):
aero_sigma=aero_sigma,
)

return {
"NACTL": nactl.view[:, :, :],
"NACTI": nacti.view[:, :, :],
"NWFA": nwfa.view[:, :, :],
output = {
"NACTL": nactl.field[:, :, :].copy(),
"NACTI": nacti.field[:, :, :].copy(),
"NWFA": nwfa.field[:, :, :].copy(),
}

# Inline benchmarking - because Aer Activation is a small enough code

s = time.perf_counter()

bench_runs = 100
for _ in range(0, bench_runs):
aer_activation(
aero_dgn=aero_dgn,
aero_num=aero_num,
nacti=nacti,
t=t,
plo=plmb,
qicn=qicn,
qils=qils,
qlcn=qlcn,
qlls=qlls,
nn_land=Float(ccn_lnd * 1.0e6),
frland=frland,
nn_ocean=Float(ccn_ocn * 1.0e6),
aero_hygroscopicity=aero_hygroscopicity,
nwfa=nwfa,
nactl=nactl,
vvel=tmp3d,
tke=tke,
aero_sigma=aero_sigma,
)

e = time.perf_counter()
print(f"Aer Activation inline bench: {e - s:.2f}s for {bench_runs} tries")

return output