From eb90cfef1ae5f878a9fd91baeab45e5094726d07 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Tue, 25 Feb 2025 12:54:19 -0500 Subject: [PATCH 1/3] Added configurable options for precision of NetCDF output --- NDSL | 2 +- pace/diagnostics.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/NDSL b/NDSL index b7db2592..a75a1d79 160000 --- a/NDSL +++ b/NDSL @@ -1 +1 @@ -Subproject commit b7db25926c9258045457c73df3b560803a90449c +Subproject commit a75a1d793a6b1c5f2a723f9eca718a6197b3b814 diff --git a/pace/diagnostics.py b/pace/diagnostics.py index b395b353..1d39902b 100644 --- a/pace/diagnostics.py +++ b/pace/diagnostics.py @@ -86,6 +86,7 @@ class DiagnosticsConfig: names: List[str] = dataclasses.field(default_factory=list) derived_names: List[str] = dataclasses.field(default_factory=list) z_select: List[ZSelect] = dataclasses.field(default_factory=list) + precision: str = "Float" def __post_init__(self): if (len(self.names) > 0 or len(self.derived_names) > 0) and self.path is None: @@ -97,6 +98,11 @@ def __post_init__(self): "output_format must be one of 'zarr' or 'netcdf', " f"got {self.output_format}" ) + if self.precision not in ["Float", "float32", "float64"]: + raise ValueError( + "precision must be one of 'Float', 'float32', or 'float64" + f"got {self.precision}" + ) def diagnostics_factory(self, communicator: Communicator) -> Diagnostics: """ @@ -124,6 +130,7 @@ def diagnostics_factory(self, communicator: Communicator) -> Diagnostics: path=self.path, communicator=communicator, time_chunk_size=self.time_chunk_size, + precision=self.precision, ) else: raise ValueError( From d52ca50f6f7c89f2d1826b7ba3e0b48cd766380a Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 5 Mar 2025 09:46:25 -0500 Subject: [PATCH 2/3] Added in method for setting NetCDF output precision --- pace/diagnostics.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pace/diagnostics.py b/pace/diagnostics.py index 1d39902b..9b38d621 100644 --- a/pace/diagnostics.py +++ b/pace/diagnostics.py @@ -4,9 +4,12 @@ from datetime import datetime, timedelta from typing import List, Optional, Union +import numpy as np + from ndsl import Quantity from ndsl.constants import RGRAV, Z_DIM, Z_INTERFACE_DIM from ndsl.dsl.dace.orchestration import dace_inhibitor +from ndsl.dsl.typing import Float from ndsl.filesystem import get_fs from ndsl.grid import GridData from ndsl.monitor import Monitor, ZarrMonitor @@ -126,11 +129,17 @@ def diagnostics_factory(self, communicator: Communicator) -> Diagnostics: mpi_comm=communicator.comm, ) elif self.output_format == "netcdf": + if self.precision == "Float": + precision = Float + elif self.precision == "float32": + precision = np.float32 + elif self.precision == "float64": + precision = np.float64 monitor = NetCDFMonitor( path=self.path, communicator=communicator, time_chunk_size=self.time_chunk_size, - precision=self.precision, + precision=precision, ) else: raise ValueError( From 21a1b39085ce1e7ec74c0a5b523bd085761acd78 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 5 Mar 2025 10:52:42 -0500 Subject: [PATCH 3/3] Updated NDSL to use NetCDF precision setting feature --- NDSL | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NDSL b/NDSL index c93da232..0740e26e 160000 --- a/NDSL +++ b/NDSL @@ -1 +1 @@ -Subproject commit c93da232bf7ab7020b57722b454f8eaeb2ff8340 +Subproject commit 0740e26ef974a50bd29372600294d72acc9f98c8