diff --git a/NDSL b/NDSL index c93da232..0740e26e 160000 --- a/NDSL +++ b/NDSL @@ -1 +1 @@ -Subproject commit c93da232bf7ab7020b57722b454f8eaeb2ff8340 +Subproject commit 0740e26ef974a50bd29372600294d72acc9f98c8 diff --git a/pace/diagnostics.py b/pace/diagnostics.py index b395b353..9b38d621 100644 --- a/pace/diagnostics.py +++ b/pace/diagnostics.py @@ -4,9 +4,12 @@ from datetime import datetime, timedelta from typing import List, Optional, Union +import numpy as np + from ndsl import Quantity from ndsl.constants import RGRAV, Z_DIM, Z_INTERFACE_DIM from ndsl.dsl.dace.orchestration import dace_inhibitor +from ndsl.dsl.typing import Float from ndsl.filesystem import get_fs from ndsl.grid import GridData from ndsl.monitor import Monitor, ZarrMonitor @@ -86,6 +89,7 @@ class DiagnosticsConfig: names: List[str] = dataclasses.field(default_factory=list) derived_names: List[str] = dataclasses.field(default_factory=list) z_select: List[ZSelect] = dataclasses.field(default_factory=list) + precision: str = "Float" def __post_init__(self): if (len(self.names) > 0 or len(self.derived_names) > 0) and self.path is None: @@ -97,6 +101,11 @@ def __post_init__(self): "output_format must be one of 'zarr' or 'netcdf', " f"got {self.output_format}" ) + if self.precision not in ["Float", "float32", "float64"]: + raise ValueError( + "precision must be one of 'Float', 'float32', or 'float64" + f"got {self.precision}" + ) def diagnostics_factory(self, communicator: Communicator) -> Diagnostics: """ @@ -120,10 +129,17 @@ def diagnostics_factory(self, communicator: Communicator) -> Diagnostics: mpi_comm=communicator.comm, ) elif self.output_format == "netcdf": + if self.precision == "Float": + precision = Float + elif self.precision == "float32": + precision = np.float32 + elif self.precision == "float64": + precision = np.float64 monitor = NetCDFMonitor( path=self.path, communicator=communicator, time_chunk_size=self.time_chunk_size, + precision=precision, ) else: raise ValueError(