Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NDSL
16 changes: 16 additions & 0 deletions pace/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
from datetime import datetime, timedelta
from typing import List, Optional, Union

import numpy as np

from ndsl import Quantity
from ndsl.constants import RGRAV, Z_DIM, Z_INTERFACE_DIM
from ndsl.dsl.dace.orchestration import dace_inhibitor
from ndsl.dsl.typing import Float
from ndsl.filesystem import get_fs
from ndsl.grid import GridData
from ndsl.monitor import Monitor, ZarrMonitor
Expand Down Expand Up @@ -86,6 +89,7 @@ class DiagnosticsConfig:
names: List[str] = dataclasses.field(default_factory=list)
derived_names: List[str] = dataclasses.field(default_factory=list)
z_select: List[ZSelect] = dataclasses.field(default_factory=list)
precision: str = "Float"

def __post_init__(self):
if (len(self.names) > 0 or len(self.derived_names) > 0) and self.path is None:
Expand All @@ -97,6 +101,11 @@ def __post_init__(self):
"output_format must be one of 'zarr' or 'netcdf', "
f"got {self.output_format}"
)
if self.precision not in ["Float", "float32", "float64"]:
raise ValueError(
"precision must be one of 'Float', 'float32', or 'float64"
f"got {self.precision}"
)

def diagnostics_factory(self, communicator: Communicator) -> Diagnostics:
"""
Expand All @@ -120,10 +129,17 @@ def diagnostics_factory(self, communicator: Communicator) -> Diagnostics:
mpi_comm=communicator.comm,
)
elif self.output_format == "netcdf":
if self.precision == "Float":
precision = Float
elif self.precision == "float32":
precision = np.float32
elif self.precision == "float64":
precision = np.float64
monitor = NetCDFMonitor(
path=self.path,
communicator=communicator,
time_chunk_size=self.time_chunk_size,
precision=precision,
)
else:
raise ValueError(
Expand Down