Skip to content

Commit abda75f

Browse files
Locals (NOAA-GFDL#266)
* Temporaries base dataclass Allow `units` to not be specified + unit test * Lint * Hide `_transient` flag of Quantity away * `QuantityFactory` has now a `is_local` option * Update wording to `Local`, trash `Temporaries` state idea * lint * Public API clean up * Simplify and fix test * Oops, restore code for `from_array` allocator * Remove keyword allocation * Introduce `Local` and `NDSLRuntime` * Lint new files * Remove the odd `_transient` and tag transientness correctly in `Local` * Repeat of the Quantity trick to get a proper type hint * Lint `local.py` * Protect against bad init for orchestration Move all unit test into a `test_ndsl_runtime` * Lint * Revert uneeded change to `Quantity` * Correct type hint for Callable * Revert orthogonal changes to this PR * Lint --------- Co-authored-by: Tobias Wicky-Pfund <tobias.wicky@meteoswiss.ch>
1 parent f10c47d commit abda75f

File tree

7 files changed

+315
-11
lines changed

7 files changed

+315
-11
lines changed

ndsl/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
StorageReport,
1616
)
1717
from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater
18+
from .dsl.ndsl_runtime import NDSLRuntime
1819
from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector
1920
from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig
2021
from .exceptions import OutOfBoundsError
@@ -27,7 +28,7 @@
2728
from .performance.collector import NullPerformanceCollector, PerformanceCollector
2829
from .performance.profiler import NullProfiler, Profiler
2930
from .performance.report import Experiment, Report, TimeReport
30-
from .quantity import Quantity, State
31+
from .quantity import Local, Quantity, State
3132
from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import
3233
from .testing.dummy_comm import DummyComm
3334
from .types import Allocator
@@ -86,4 +87,6 @@
8687
"Allocator",
8788
"MetaEnumStr",
8889
"State",
90+
"NDSLRuntime",
91+
"Local",
8992
]

ndsl/dsl/dace/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .dace_config import DaceConfig
2+
from .orchestration import orchestrate, orchestrate_function
3+
4+
5+
__all__ = ["DaceConfig", "orchestrate", "orchestrate_function"]

ndsl/dsl/ndsl_runtime.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from __future__ import annotations
2+
3+
import inspect
4+
import warnings
5+
from collections.abc import Callable
6+
from typing import Any
7+
8+
from ndsl.dsl.dace import DaceConfig, orchestrate
9+
from ndsl.dsl.typing import Float
10+
from ndsl.initialization.allocator import QuantityFactory
11+
from ndsl.quantity import Local, Quantity
12+
13+
14+
_TOP_LEVEL: object | None = None
15+
16+
17+
class NDSLRuntime:
18+
"""Base class to tool runtime code, allows use of Locals, orchestration and
19+
debug tools.
20+
21+
The __call__ function will automatically be orchestrated."""
22+
23+
def __init__(self, dace_config: DaceConfig) -> None:
24+
self._dace_config = dace_config
25+
# Use this flag to detect that the init wasn't done properly
26+
self._base_class_was_properly_super_init = True
27+
28+
def __init_subclass__(cls: type[NDSLRuntime], **kwargs: dict[str, Any]) -> None:
29+
# WARNING: no code outside the `init_decorator` this is cls
30+
# function, it will be called ONLY ONCE for monkey-patching the
31+
# Class - not the instance !
32+
33+
def init_decorator(previous_init: Callable) -> Callable:
34+
def new_init(
35+
self: NDSLRuntime,
36+
*args: list[Any],
37+
**kwargs: dict[str, Any],
38+
) -> None:
39+
global _TOP_LEVEL
40+
if _TOP_LEVEL is None:
41+
_TOP_LEVEL = self
42+
previous_init(self, *args, **kwargs)
43+
self.__post_init__()
44+
45+
return new_init
46+
47+
cls.__init__ = init_decorator(cls.__init__) # type: ignore[method-assign]
48+
49+
def __post_init__(self) -> None:
50+
if not hasattr(self, "_base_class_was_properly_super_init"):
51+
raise RuntimeError(
52+
f"Class {type(self).__name__} inherit from NDSLRuntime but didn't call super().__init__."
53+
)
54+
55+
# Check quantity allocation of NDSLRuntime supervised code
56+
if _TOP_LEVEL == self:
57+
58+
def check_for_quantity(object_: object) -> None:
59+
for key, value in object_.__dict__.items():
60+
if isinstance(value, Quantity) and not isinstance(value, Local):
61+
warnings.warn(
62+
f"{type(self).__name__}.{key} is a Quantity instead of a Locals"
63+
" on a NDSLRuntime - our eyebrows are frowned."
64+
)
65+
elif isinstance(value, NDSLRuntime):
66+
check_for_quantity(value)
67+
68+
check_for_quantity(self)
69+
70+
# Orchestrate __call__ by default
71+
if hasattr(self, "__call__"):
72+
orchestrate(
73+
obj=self,
74+
config=self._dace_config,
75+
)
76+
print(type(self))
77+
78+
def __getattribute__(self, name: str) -> Any:
79+
attr = super().__getattribute__(name)
80+
# We look at the direct caller frame for our own `self`
81+
# in the locals.
82+
# All other cases are forbidden.
83+
if isinstance(attr, Local):
84+
frame = inspect.currentframe()
85+
if frame is None:
86+
raise NotImplementedError(
87+
"Locals check cannot locate frame. Talk to the team."
88+
)
89+
caller_frame = frame.f_back
90+
if (
91+
not caller_frame
92+
or "self" not in caller_frame.f_locals
93+
or not isinstance(caller_frame.f_locals["self"], type(self))
94+
):
95+
# We expect the original class to have been monkey-patched
96+
# See `dace.dsl.orchestration.orchestrate`
97+
unpatched_name = type(self).__name__[: -len("_patched")]
98+
raise RuntimeError(
99+
f"Forbidden Local access: {name} called outside of {unpatched_name}."
100+
)
101+
102+
return attr
103+
104+
def make_local(
105+
self,
106+
quantity_factory: QuantityFactory,
107+
dims: list[str],
108+
dtype: type = Float,
109+
units: str = "unspecified",
110+
*,
111+
allow_mismatch_float_precision: bool = False,
112+
) -> Local:
113+
quantity = quantity_factory.zeros(
114+
dims,
115+
units,
116+
dtype,
117+
allow_mismatch_float_precision=allow_mismatch_float_precision,
118+
)
119+
return Local(
120+
data=quantity.data,
121+
dims=quantity.dims,
122+
units=quantity.units,
123+
origin=quantity.origin,
124+
extent=quantity.extent,
125+
gt4py_backend=quantity.gt4py_backend,
126+
allow_mismatch_float_precision=allow_mismatch_float_precision,
127+
)

ndsl/initialization/allocator.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,41 +77,56 @@ def empty(
7777
dims: Sequence[str],
7878
units: str,
7979
dtype: type = Float,
80+
*,
8081
allow_mismatch_float_precision: bool = False,
8182
) -> Quantity:
8283
"""Allocate a Quantity - values are random.
8384
8485
Equivalent to `numpy.empty`"""
8586
return self._allocate(
86-
self._numpy.empty, dims, units, dtype, allow_mismatch_float_precision
87+
self._numpy.empty,
88+
dims,
89+
units,
90+
dtype,
91+
allow_mismatch_float_precision,
8792
)
8893

8994
def zeros(
9095
self,
9196
dims: Sequence[str],
9297
units: str,
9398
dtype: type = Float,
99+
*,
94100
allow_mismatch_float_precision: bool = False,
95101
) -> Quantity:
96102
"""Allocate a Quantity and fill it with the value 0.
97103
98104
Equivalent to `numpy.zeros`"""
99105
return self._allocate(
100-
self._numpy.zeros, dims, units, dtype, allow_mismatch_float_precision
106+
self._numpy.zeros,
107+
dims,
108+
units,
109+
dtype,
110+
allow_mismatch_float_precision,
101111
)
102112

103113
def ones(
104114
self,
105115
dims: Sequence[str],
106116
units: str,
107117
dtype: type = Float,
118+
*,
108119
allow_mismatch_float_precision: bool = False,
109120
) -> Quantity:
110121
"""Allocate a Quantity and fill it with the value 1.
111122
112123
Equivalent to `numpy.ones`"""
113124
return self._allocate(
114-
self._numpy.ones, dims, units, dtype, allow_mismatch_float_precision
125+
self._numpy.ones,
126+
dims,
127+
units,
128+
dtype,
129+
allow_mismatch_float_precision,
115130
)
116131

117132
def full(
@@ -120,13 +135,18 @@ def full(
120135
units: str,
121136
value: Any, # no type hint because it would be a TypeVar = type[dtype] and mypy says no
122137
dtype: type = Float,
138+
*,
123139
allow_mismatch_float_precision: bool = False,
124140
) -> Quantity:
125141
"""Allocate a Quantity and fill it with the value.
126142
127143
Equivalent to `numpy.full`"""
128144
quantity = self._allocate(
129-
self._numpy.empty, dims, units, dtype, allow_mismatch_float_precision
145+
self._numpy.empty,
146+
dims,
147+
units,
148+
dtype,
149+
allow_mismatch_float_precision,
130150
)
131151
quantity.data[:] = value
132152
return quantity
@@ -136,6 +156,7 @@ def from_array(
136156
data: np.ndarray,
137157
dims: Sequence[str],
138158
units: str,
159+
*,
139160
allow_mismatch_float_precision: bool = False,
140161
) -> Quantity:
141162
"""
@@ -158,6 +179,7 @@ def from_compute_array(
158179
data: np.ndarray,
159180
dims: Sequence[str],
160181
units: str,
182+
*,
161183
allow_mismatch_float_precision: bool = False,
162184
) -> Quantity:
163185
"""

ndsl/quantity/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
from .state import State
44

55

6-
__all__ = [
7-
"Quantity",
8-
"QuantityMetadata",
9-
"QuantityHaloSpec",
10-
"State",
11-
]
6+
from .local import Local # isort: skip
7+
8+
9+
__all__ = ["Local", "Quantity", "QuantityMetadata", "QuantityHaloSpec", "State"]

ndsl/quantity/local.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import Any, Sequence
2+
3+
import dace
4+
import numpy as np
5+
6+
from ndsl.optional_imports import cupy
7+
from ndsl.quantity import Quantity
8+
9+
10+
if cupy is None:
11+
import numpy as cupy
12+
13+
14+
class Local(Quantity):
15+
"""Local is a Quantity that cannot be used outside of the class
16+
it was allocated in."""
17+
18+
def __init__(
19+
self,
20+
data: np.ndarray | cupy.ndarray,
21+
dims: Sequence[str],
22+
units: str,
23+
origin: Sequence[int] | None = None,
24+
extent: Sequence[int] | None = None,
25+
gt4py_backend: str | None = None,
26+
allow_mismatch_float_precision: bool = False,
27+
):
28+
super().__init__(
29+
data,
30+
dims,
31+
units,
32+
origin,
33+
extent,
34+
gt4py_backend,
35+
allow_mismatch_float_precision,
36+
)
37+
self._transient = True
38+
39+
def __descriptor__(self) -> Any:
40+
"""Locals uses `Quantity.__descriptor__` and flag itself as transient."""
41+
data = dace.data.create_datadescriptor(self.data)
42+
data.transient = True
43+
return data

0 commit comments

Comments
 (0)