|
22 | 22 |
|
23 | 23 | """Module for active stress curve.""" |
24 | 24 |
|
25 | | -from typing import Literal |
| 25 | +from typing import Literal, Tuple |
26 | 26 |
|
27 | | -import matplotlib.pyplot as plt |
28 | 27 | import numpy as np |
| 28 | +from pydantic import ( |
| 29 | + BaseModel, |
| 30 | + ConfigDict, |
| 31 | + Field, |
| 32 | + field_serializer, |
| 33 | + field_validator, |
| 34 | + model_validator, |
| 35 | +) |
29 | 36 |
|
30 | 37 | from ansys.health.heart import LOG as LOGGER |
31 | 38 |
|
@@ -126,146 +133,127 @@ def constant_ca2(tb: float = 800, ca2ionm: float = 4.35) -> tuple[np.ndarray, np |
126 | 133 | return (t, v) |
127 | 134 |
|
128 | 135 |
|
129 | | -# TODO: Use pydantic to easily (de)serialize the curve. |
130 | | -class ActiveCurve: |
131 | | - """Active stress or Ca2+ curve.""" |
132 | | - |
133 | | - def __init__( |
134 | | - self, |
135 | | - func: tuple[np.ndarray, np.ndarray], |
136 | | - type: Literal["stress", "ca2"] = "ca2", |
137 | | - threshold: float = 0.5e-6, |
138 | | - n: int = 5, |
139 | | - ) -> None: |
140 | | - """Define a curve for active behavior of MAT295. |
141 | | -
|
142 | | - Parameters |
143 | | - ---------- |
144 | | - func : tuple[np.ndarray, np.ndarray] |
145 | | - (time, stress or ca2) array for one heart beat |
146 | | - type : Literal["stress", "ca2"], optional |
147 | | - type of curve, by default "ca2" |
148 | | - threshold : float, optional |
149 | | - threshold of des/active active stress, by default 0.5e-6. |
150 | | - n : int, optional |
151 | | - No. of heart beat will be written for LS-DYNA, by default 5 |
152 | | -
|
153 | | - Notes |
154 | | - ----- |
155 | | - - If type=='stress', threshold is always 0.5e-6 and ca2+ will be shifted up with 1.0e-6 |
156 | | - except t=0. This ensures a continuous activation during simulation. |
157 | | - """ |
158 | | - self.type = type |
159 | | - self.n_beat = n |
160 | | - |
161 | | - if type == "stress": |
162 | | - LOGGER.warning("Threshold will be reset.") |
163 | | - threshold = 0.5e-6 |
164 | | - self.threshold = threshold |
165 | | - |
166 | | - self.time = func[0] |
167 | | - self.t_beat = self.time[-1] |
168 | | - |
169 | | - if self.type == "ca2": |
170 | | - self.ca2 = func[1] |
| 136 | +class ActiveCurve(BaseModel): |
| 137 | + """Pydantic-backed ActiveCurve.""" |
| 138 | + |
| 139 | + model_config = ConfigDict(arbitrary_types_allowed=True) |
| 140 | + |
| 141 | + func: Tuple[np.ndarray, np.ndarray] = None |
| 142 | + type: Literal["stress", "ca2"] = "ca2" |
| 143 | + threshold: float = 0.5e-6 |
| 144 | + n_beat: int = 5 |
| 145 | + |
| 146 | + # Derived values. exclude these from |
| 147 | + # json serialization. |
| 148 | + time: np.ndarray | None = Field(default=None, exclude=True) |
| 149 | + t_beat: float | None = Field(default=None, exclude=True) |
| 150 | + ca2: np.ndarray | None = Field(default=None, exclude=True) |
| 151 | + stress: np.ndarray | None = Field(default=None, exclude=True) |
| 152 | + |
| 153 | + @field_validator("func", mode="before") |
| 154 | + def _func_validator(cls, v): # noqa: N805 |
| 155 | + """Accept lists/tuples or numpy arrays and return tuple[np.ndarray, np.ndarray].""" |
| 156 | + if v is None: |
| 157 | + raise ValueError("func must be provided as (time, values) arrays") |
| 158 | + |
| 159 | + # Expect a sequence of length 2 |
| 160 | + if not (isinstance(v, (list, tuple)) and len(v) == 2): |
| 161 | + raise ValueError("func must be a tuple/list of (time, values)") |
| 162 | + |
| 163 | + t, y = v |
| 164 | + t_arr = np.asarray(t) |
| 165 | + y_arr = np.asarray(y) |
| 166 | + |
| 167 | + if t_arr.ndim != 1 or y_arr.ndim != 1: |
| 168 | + raise ValueError("func arrays must be 1-dimensional") |
| 169 | + if t_arr.shape != y_arr.shape: |
| 170 | + raise ValueError("func arrays must have the same shape") |
| 171 | + if t_arr.size == 0: |
| 172 | + raise ValueError("func arrays must not be empty") |
| 173 | + if np.any(np.diff(t_arr) <= 0): |
| 174 | + raise ValueError("func time array must be strictly increasing") |
| 175 | + |
| 176 | + return (t_arr, y_arr) |
| 177 | + |
| 178 | + @model_validator(mode="after") |
| 179 | + def _post_init(self): |
| 180 | + # preserve public API names used by callers |
| 181 | + self.time = self.func[0] |
| 182 | + self.t_beat = float(self.time[-1]) |
| 183 | + |
| 184 | + if self.type == "stress": |
| 185 | + # reset threshold as current implementation does |
| 186 | + self.threshold = 0.5e-6 |
| 187 | + self.stress = self.func[1] |
| 188 | + self.ca2 = self._stress_to_ca2(self.stress) |
| 189 | + else: |
| 190 | + self.ca2 = self.func[1] |
171 | 191 | self.stress = None |
172 | | - elif self.type == "stress": |
173 | | - self.stress = func[1] |
174 | | - self.ca2 = self._stress_to_ca2(func[1]) |
175 | 192 |
|
| 193 | + # run the same checks |
176 | 194 | self._check_threshold() |
| 195 | + return self |
| 196 | + |
| 197 | + # optional: serialize numpy arrays to lists for model_dump / JSON |
| 198 | + @field_serializer("func") |
| 199 | + def _serialize_func(self, func: tuple[np.ndarray, np.ndarray], info): |
| 200 | + if isinstance(func[0], np.ndarray) and isinstance(func[1], np.ndarray): |
| 201 | + return (func[0].tolist(), func[1].tolist()) |
| 202 | + else: |
| 203 | + LOGGER.error("Failed to serialize func") |
| 204 | + return None |
177 | 205 |
|
178 | 206 | def _check_threshold(self): |
179 | | - # maybe better to check it cross 1 or 2 times |
180 | 207 | if np.max(self.ca2) < self.threshold or np.min(self.ca2) > self.threshold: |
181 | 208 | raise ValueError("Threshold must cross ca2+ curve at least once") |
182 | 209 |
|
183 | | - @property |
184 | | - def dyna_input(self): |
185 | | - """Return x,y input for k files.""" |
186 | | - return self._repeat((self.time, self.ca2)) |
187 | | - |
188 | | - def plot_time_vs_ca2(self): |
189 | | - """Plot Ca2+ with threshold.""" |
190 | | - fig, ax = plt.subplots(figsize=(8, 4)) |
191 | | - t, v = self._repeat((self.time, self.ca2)) |
192 | | - ax.plot(t, v, label="Ca2+") |
193 | | - ax.hlines(self.threshold, xmin=t[0], xmax=t[-1], label="threshold", colors="red") |
194 | | - ax.set_xlabel("time (ms)") |
195 | | - ax.set_ylabel("Ca2+") |
196 | | - # ax.set_title('Ca2+') |
197 | | - ax.legend() |
198 | | - return fig |
199 | | - |
200 | | - def plot_time_vs_stress(self): |
201 | | - """Plot stress.""" |
202 | | - if self.stress is None: |
203 | | - LOGGER.error("Only support stress curve.") |
204 | | - # self._estimate_stress() |
205 | | - return None |
206 | | - t, v = self._repeat((self.time, self.stress)) |
207 | | - fig, ax = plt.subplots(figsize=(8, 4)) |
208 | | - ax.plot(t, v) |
209 | | - ax.set_xlabel("time (ms)") |
210 | | - ax.set_ylabel("Normalized active stress") |
211 | | - # ax.set_title('Ca2+') |
212 | | - # ax.legend() |
213 | | - return fig |
214 | | - |
215 | | - def _stress_to_ca2(self, stress): |
| 210 | + def _stress_to_ca2(self, stress: np.ndarray) -> np.ndarray: |
216 | 211 | if np.min(stress) < 0 or np.max(stress) > 1.0: |
217 | | - LOGGER.error("Stress curve is not between 0-1.") |
218 | 212 | raise ValueError("Stress curve must be between 0-1.") |
219 | | - |
220 | | - # assuming actype=3, eta=0; n=1; Ca2+50=1 |
221 | 213 | ca2 = 1 / (1 - 0.999 * stress) - 1 |
222 | | - |
223 | | - # offset about threshold |
224 | 214 | ca2[0] = 0.0 |
225 | 215 | ca2[1:] += 2 * self.threshold |
226 | | - |
227 | 216 | return ca2 |
228 | 217 |
|
229 | | - def _repeat(self, curve): |
| 218 | + def _repeat(self, curve: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: |
230 | 219 | t = np.copy(curve[0]) |
231 | 220 | v = np.copy(curve[1]) |
232 | | - |
233 | 221 | for ii in range(1, self.n_beat): |
234 | 222 | t = np.append(t, curve[0][1:] + ii * self.t_beat) |
235 | 223 | v = np.append(v, curve[1][1:]) |
236 | 224 | return (t, v) |
237 | 225 |
|
238 | | - def _estimate_stress(self): |
239 | | - # TODO: only with 1 |
240 | | - # TODO: @wenfengye ensure ruff compatibility, see the noqa's |
241 | | - ca2ionmax = 4.35 |
242 | | - ca2ion = 4.35 |
243 | | - n = 2 |
244 | | - mr = 1048.9 |
245 | | - dtmax = 150 |
246 | | - tr = -1429 |
247 | | - # Range of L 1.78-1.91 |
248 | | - L = 1.85 # noqa N806 |
249 | | - l0 = 1.58 |
250 | | - b = 4.75 |
251 | | - lam = 1 |
252 | | - cf = (np.exp(b * (lam * L - l0)) - 1) ** 0.5 |
253 | | - ca2ion50 = ca2ionmax / cf |
254 | | - dtr = mr * lam * L + tr |
255 | | - self.stress = np.zeros(self.ca2.shape) |
256 | | - for i, t in enumerate(self.time): |
257 | | - if t < dtmax: |
258 | | - w = np.pi * t / dtmax |
259 | | - elif dtmax <= t <= dtmax + dtr: |
260 | | - w = np.pi * (t - dtmax + dtr) / dtr |
261 | | - else: |
262 | | - w = 0 |
263 | | - c = 0.5 * (1 - np.cos(w)) |
264 | | - self.stress[i] = c * ca2ion**n / (ca2ion**n + ca2ion50**n) |
265 | | - |
266 | | - |
267 | | -if __name__ == "__main__": |
268 | | - a = ActiveCurve(constant_ca2(), threshold=0.1, type="ca2") |
269 | | - # a = Ca2Curve(unit_constant_ca2(), type="ca2") |
270 | | - a.plot_time_vs_ca2() |
271 | | - a.plot_time_vs_stress() |
| 226 | + @property |
| 227 | + def dyna_input(self) -> Tuple[np.ndarray, np.ndarray]: |
| 228 | + """Return LS-DYNA input arrays.""" |
| 229 | + return self._repeat((self.time, self.ca2)) |
| 230 | + |
| 231 | + def plot_time_vs_ca2(self): |
| 232 | + """Plot time vs ca2.""" |
| 233 | + import matplotlib.pyplot as plt |
| 234 | + |
| 235 | + t, v = self.dyna_input |
| 236 | + fig, ax = plt.subplots() |
| 237 | + ax.plot(t, v, label="ca2") |
| 238 | + ax.axhline(self.threshold, color="r", linestyle="--", label="threshold") |
| 239 | + ax.set_xlabel("Time (ms)") |
| 240 | + ax.set_ylabel("Ca2+") |
| 241 | + ax.set_title("Active Ca2+ Curve") |
| 242 | + ax.legend() |
| 243 | + return fig |
| 244 | + |
| 245 | + def plot_time_vs_stress(self): |
| 246 | + """Plot time vs stress.""" |
| 247 | + if self.type != "stress": |
| 248 | + raise ValueError("Curve type is not 'stress', cannot plot stress.") |
| 249 | + |
| 250 | + import matplotlib.pyplot as plt |
| 251 | + |
| 252 | + t, v = self._repeat((self.time, self.stress)) |
| 253 | + fig, ax = plt.subplots() |
| 254 | + ax.plot(t, v, label="stress") |
| 255 | + ax.set_xlabel("Time (ms)") |
| 256 | + ax.set_ylabel("Stress (normalized)") |
| 257 | + ax.set_title("Active Stress Curve") |
| 258 | + ax.legend() |
| 259 | + return fig |
0 commit comments