|
19 | 19 | Define Forecast base class. |
20 | 20 | """ |
21 | 21 |
|
22 | | -from abc import ABC, abstractmethod |
23 | | -from typing import Literal, Self |
24 | 22 |
|
25 | | -import numpy as np |
26 | | -from scipy.sparse import block_array, csr_matrix |
27 | | - |
28 | | - |
29 | | -def reduce_mat(mat: csr_matrix, func): |
30 | | - """Reduce a matrix and return the CSR representation""" |
31 | | - return func(mat).tocsr() |
32 | | - |
33 | | - |
34 | | -def concat_matrices_per_event(*matrices: csr_matrix): |
35 | | - """Concatenate matrices by event""" |
36 | | - return block_array([[mat] for mat in matrices], format="csr") |
37 | | - |
38 | | - |
39 | | -def matrix_event_padding(mat: csr_matrix, num_events): |
40 | | - """Pad zero events""" |
41 | | - pad_events = mat.shape[0] - num_events |
42 | | - if pad_events < 1: |
43 | | - return mat |
44 | | - return block_array( |
45 | | - [[mat], csr_matrix((pad_events, mat.shape[1], mat.dtype))], format="csr" |
46 | | - ) |
47 | | - |
48 | | - |
49 | | -class Forecast(ABC): |
50 | | - lead_time: np.ndarray[np.timedelta64] |
51 | | - member: np.ndarray[int] |
52 | | - forecast_date: np.datetime64 | None |
53 | | - num_members: int |
54 | | - num_lead_times: int |
55 | | - |
56 | | - def __init__(self, lead_time, member, forecast_date: np.datetime64 | None = None): |
57 | | - """Store members""" |
58 | | - pass |
59 | | - |
60 | | - # --- Selection --- # |
61 | | - |
62 | | - @abstractmethod |
63 | | - def _select_by_index(self, index: tuple[np.ndarray, ...]) -> Self: |
64 | | - """Return a new object with the index used for selecting events""" |
65 | | - raise NotImplementedError |
66 | | - |
67 | | - def _select_member(self, member: int | None) -> np.ndarray: |
68 | | - """Return boolean array where self.member == member""" |
69 | | - ... |
70 | | - |
71 | | - def _select_lead_time(self, lead_time: np.timedelta64 | None) -> np.ndarray: |
72 | | - """Return boolean array where self.lead_time == lead_time""" |
73 | | - ... |
74 | | - |
75 | | - def select(self, *, member: int | None, lead_time: np.timedelta64 | None) -> Self: |
76 | | - index = np.nonzero( |
77 | | - self._select_member(member) & self._select_lead_time(lead_time) |
78 | | - ) |
79 | | - return self._select_by_index(index) |
80 | | - |
81 | | - # --- Generic reduction --- # |
82 | | - |
83 | | - @classmethod |
84 | | - @abstractmethod |
85 | | - def concat(cls, *obj: Self) -> Self: |
86 | | - """Concatenate multiple object instances""" |
87 | | - raise NotImplementedError |
88 | | - |
89 | | - @abstractmethod |
90 | | - def _reduce(self, func) -> Self: |
91 | | - """Apply the reduction function in the derived class and return the result |
92 | | -
|
93 | | - Note: The derived class will likely need to pad matrices! |
94 | | - """ |
95 | | - raise NotImplementedError |
96 | | - |
97 | | - def reduce(self, func, dim: Literal["member", "lead_time"] | None = None) -> Self: |
98 | | - """Reduce along a given dimension with func""" |
99 | | - if dim is None: |
100 | | - # TODO: Check if we selected a specific member or lead time. |
101 | | - # Pad events accordingly! |
102 | | - return self._reduce(func=func) # Derived class specialization |
103 | | - |
104 | | - return self.concat( |
105 | | - *( |
106 | | - self.select(**{dim: val}).reduce(func=func, dim=None) |
107 | | - for val in np.unique(getattr(self, dim)) |
108 | | - ) |
109 | | - ) |
110 | | - |
111 | | - # --- Specializations --- # |
112 | | - |
113 | | - @abstractmethod |
114 | | - def _max(self) -> Self: |
115 | | - """Apply the maximum function in the derived class and return the result""" |
116 | | - raise NotImplementedError |
117 | | - |
118 | | - def _reduce_attr( |
119 | | - self, attr: str, dim: Literal["member", "lead_time"] | None = None |
120 | | - ) -> Self: |
121 | | - """Reduce along a given dimension with attribute attr""" |
122 | | - if dim is None: |
123 | | - # TODO: Check if we selected a specific member or lead time. |
124 | | - # Pad events accordingly! |
125 | | - return getattr(self, "_" + attr)() # Derived class specialization |
126 | | - |
127 | | - return self.concat( |
128 | | - *( |
129 | | - getattr(self.select(**{dim: val}), attr)(dim=None) |
130 | | - for val in np.unique(getattr(self, dim)) |
131 | | - ) |
132 | | - ) |
133 | | - |
134 | | - def max(self, dim): |
135 | | - return self._reduce_attr(attr="max", dim=dim) |
| 23 | +class Forecast: |
| 24 | + pass |
0 commit comments