Skip to content

Commit e7afd03

Browse files
refactor: use pydantic-based materials module (#1244)
Co-authored-by: pyansys-ci-bot <[email protected]>
1 parent 8e1ee83 commit e7afd03

File tree

14 files changed

+826
-1155
lines changed

14 files changed

+826
-1155
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Use pydantic-based materials module

examples/preprocessor/demo-material_pr.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import ansys.health.heart.settings.material.cell_models as cell_models
3939
from ansys.health.heart.settings.material.curve import kumaraswamy_active
4040
import ansys.health.heart.settings.material.ep_material as ep_materials
41-
from ansys.health.heart.settings.material.material_pyd import (
41+
from ansys.health.heart.settings.material.material import (
4242
ACTIVE,
4343
ANISO,
4444
ISO,
@@ -88,7 +88,9 @@
8888
# create active model 1
8989
ac_model1 = ActiveModel1()
9090
# create Ca2+ curve
91-
ac_curve1 = ActiveCurve(constant_ca2(tb=800, ca2ionm=ac_model1.ca2ionm), type="ca2", threshold=0.5)
91+
ac_curve1 = ActiveCurve(
92+
func=constant_ca2(tb=800, ca2ionm=ac_model1.ca2ionm), type="ca2", threshold=0.5
93+
)
9294
# build active module
9395
active = ACTIVE(model=ac_model1, ca2_curve=ac_curve1)
9496

@@ -106,7 +108,7 @@
106108
# create active model 3
107109
ac_model3 = ActiveModel3()
108110
# create a stress curve and show
109-
ac_curve3 = ActiveCurve(kumaraswamy_active(t_end=800), type="stress")
111+
ac_curve3 = ActiveCurve(func=kumaraswamy_active(t_end=800), type="stress")
110112
fig = ac_curve3.plot_time_vs_stress()
111113
plt.show()
112114

@@ -186,7 +188,7 @@
186188
heartmodel.left_ventricle.ep_material = ep_mat_active
187189

188190
# Print it. You should see the following:
189-
# MAT295(rho=1, iso=ISO(itype=-3, beta=0.0, nu=0.499, k1=1, k2=1), aopt=2.0, aniso=ANISO(atype=-1, fibers=[ANISO.HGOFiber(k1=1, k2=1, a=0.0, b=1.0, _theta=0.0, _ftype=1, _fcid=0)], k1fs=None, k2fs=None, vec_a=(1.0, 0.0, 0.0), vec_d=(0.0, 1.0, 0.0), nf=1, intype=0), active=ActiveModel.Model1(t0=None, ca2ion=None, ca2ionm=4.35, n=2, taumax=0.125, stf=0.0, b=4.75, l0=1.58, l=1.85, dtmax=150, mr=1048.9, tr=-1629.0)) # noqa
191+
# MAT295(rho=1, iso=ISO(itype=-3, beta=0.0, nu=0.499, k1=1, k2=1), aopt=2.0, aniso=ANISO(atype=-1, fibers=[HGOFiber(k1=1, k2=1, a=0.0, b=1.0, _theta=0.0, _ftype=1, _fcid=0)], k1fs=None, k2fs=None, vec_a=(1.0, 0.0, 0.0), vec_d=(0.0, 1.0, 0.0), nf=1, intype=0), active=ActiveModel.Model1(t0=None, ca2ion=None, ca2ionm=4.35, n=2, taumax=0.125, stf=0.0, b=4.75, l0=1.58, l=1.85, dtmax=150, mr=1048.9, tr=-1629.0)) # noqa
190192
print(heartmodel.left_ventricle.meca_material)
191193
print(heartmodel.left_ventricle.ep_material)
192194
###############################################################################

examples/simulator/mechanics-simulator-leftventricle_pr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
from ansys.health.heart.examples import get_input_leftventricle
7575
import ansys.health.heart.models as models
7676
from ansys.health.heart.post.dpf_utils import ICVoutReader
77-
from ansys.health.heart.settings.material.material import ACTIVE, ANISO, ISO, Mat295
77+
from ansys.health.heart.settings.material.material import ACTIVE, ANISO, ISO, HGOFiber, Mat295
7878
from ansys.health.heart.simulator import DynaSettings, MechanicsSimulator
7979

8080
###############################################################################
@@ -187,7 +187,7 @@
187187
iso=ISO(itype=-3, beta=2, kappa=1.0, k1=0.20e-3, k2=6.55),
188188
aniso=ANISO(
189189
atype=-1,
190-
fibers=[ANISO.HGOFiber(k1=0.00305, k2=29.05), ANISO.HGOFiber(k1=1.25e-3, k2=36.65)],
190+
fibers=[HGOFiber(k1=0.00305, k2=29.05), HGOFiber(k1=1.25e-3, k2=36.65)],
191191
k1fs=0.15e-3,
192192
k2fs=6.28,
193193
),

src/ansys/health/heart/parts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def __init__(self, name: str = None, part_type: _PartType = _PartType.UNDEFINED)
137137
self.active: bool = False
138138
"""Flag indicating if active stress is established."""
139139

140-
self.meca_material: MechanicalMaterialModel = MechanicalMaterialModel.DummyMaterial()
140+
self.meca_material: MechanicalMaterialModel = None
141141
"""Material model to assign in the simulator."""
142142

143143
self.ep_material: ep_materials.EPMaterialModel = None

src/ansys/health/heart/settings/material/curve.py

Lines changed: 111 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,17 @@
2222

2323
"""Module for active stress curve."""
2424

25-
from typing import Literal
25+
from typing import Literal, Tuple
2626

27-
import matplotlib.pyplot as plt
2827
import numpy as np
28+
from pydantic import (
29+
BaseModel,
30+
ConfigDict,
31+
Field,
32+
field_serializer,
33+
field_validator,
34+
model_validator,
35+
)
2936

3037
from ansys.health.heart import LOG as LOGGER
3138

@@ -126,146 +133,127 @@ def constant_ca2(tb: float = 800, ca2ionm: float = 4.35) -> tuple[np.ndarray, np
126133
return (t, v)
127134

128135

129-
# TODO: Use pydantic to easily (de)serialize the curve.
130-
class ActiveCurve:
131-
"""Active stress or Ca2+ curve."""
132-
133-
def __init__(
134-
self,
135-
func: tuple[np.ndarray, np.ndarray],
136-
type: Literal["stress", "ca2"] = "ca2",
137-
threshold: float = 0.5e-6,
138-
n: int = 5,
139-
) -> None:
140-
"""Define a curve for active behavior of MAT295.
141-
142-
Parameters
143-
----------
144-
func : tuple[np.ndarray, np.ndarray]
145-
(time, stress or ca2) array for one heart beat
146-
type : Literal[&quot;stress&quot;, &quot;ca2&quot;], optional
147-
type of curve, by default "ca2"
148-
threshold : float, optional
149-
threshold of des/active active stress, by default 0.5e-6.
150-
n : int, optional
151-
No. of heart beat will be written for LS-DYNA, by default 5
152-
153-
Notes
154-
-----
155-
- If type=='stress', threshold is always 0.5e-6 and ca2+ will be shifted up with 1.0e-6
156-
except t=0. This ensures a continuous activation during simulation.
157-
"""
158-
self.type = type
159-
self.n_beat = n
160-
161-
if type == "stress":
162-
LOGGER.warning("Threshold will be reset.")
163-
threshold = 0.5e-6
164-
self.threshold = threshold
165-
166-
self.time = func[0]
167-
self.t_beat = self.time[-1]
168-
169-
if self.type == "ca2":
170-
self.ca2 = func[1]
136+
class ActiveCurve(BaseModel):
137+
"""Pydantic-backed ActiveCurve."""
138+
139+
model_config = ConfigDict(arbitrary_types_allowed=True)
140+
141+
func: Tuple[np.ndarray, np.ndarray] = None
142+
type: Literal["stress", "ca2"] = "ca2"
143+
threshold: float = 0.5e-6
144+
n_beat: int = 5
145+
146+
# Derived values. exclude these from
147+
# json serialization.
148+
time: np.ndarray | None = Field(default=None, exclude=True)
149+
t_beat: float | None = Field(default=None, exclude=True)
150+
ca2: np.ndarray | None = Field(default=None, exclude=True)
151+
stress: np.ndarray | None = Field(default=None, exclude=True)
152+
153+
@field_validator("func", mode="before")
154+
def _func_validator(cls, v): # noqa: N805
155+
"""Accept lists/tuples or numpy arrays and return tuple[np.ndarray, np.ndarray]."""
156+
if v is None:
157+
raise ValueError("func must be provided as (time, values) arrays")
158+
159+
# Expect a sequence of length 2
160+
if not (isinstance(v, (list, tuple)) and len(v) == 2):
161+
raise ValueError("func must be a tuple/list of (time, values)")
162+
163+
t, y = v
164+
t_arr = np.asarray(t)
165+
y_arr = np.asarray(y)
166+
167+
if t_arr.ndim != 1 or y_arr.ndim != 1:
168+
raise ValueError("func arrays must be 1-dimensional")
169+
if t_arr.shape != y_arr.shape:
170+
raise ValueError("func arrays must have the same shape")
171+
if t_arr.size == 0:
172+
raise ValueError("func arrays must not be empty")
173+
if np.any(np.diff(t_arr) <= 0):
174+
raise ValueError("func time array must be strictly increasing")
175+
176+
return (t_arr, y_arr)
177+
178+
@model_validator(mode="after")
179+
def _post_init(self):
180+
# preserve public API names used by callers
181+
self.time = self.func[0]
182+
self.t_beat = float(self.time[-1])
183+
184+
if self.type == "stress":
185+
# reset threshold as current implementation does
186+
self.threshold = 0.5e-6
187+
self.stress = self.func[1]
188+
self.ca2 = self._stress_to_ca2(self.stress)
189+
else:
190+
self.ca2 = self.func[1]
171191
self.stress = None
172-
elif self.type == "stress":
173-
self.stress = func[1]
174-
self.ca2 = self._stress_to_ca2(func[1])
175192

193+
# run the same checks
176194
self._check_threshold()
195+
return self
196+
197+
# optional: serialize numpy arrays to lists for model_dump / JSON
198+
@field_serializer("func")
199+
def _serialize_func(self, func: tuple[np.ndarray, np.ndarray], info):
200+
if isinstance(func[0], np.ndarray) and isinstance(func[1], np.ndarray):
201+
return (func[0].tolist(), func[1].tolist())
202+
else:
203+
LOGGER.error("Failed to serialize func")
204+
return None
177205

178206
def _check_threshold(self):
179-
# maybe better to check it cross 1 or 2 times
180207
if np.max(self.ca2) < self.threshold or np.min(self.ca2) > self.threshold:
181208
raise ValueError("Threshold must cross ca2+ curve at least once")
182209

183-
@property
184-
def dyna_input(self):
185-
"""Return x,y input for k files."""
186-
return self._repeat((self.time, self.ca2))
187-
188-
def plot_time_vs_ca2(self):
189-
"""Plot Ca2+ with threshold."""
190-
fig, ax = plt.subplots(figsize=(8, 4))
191-
t, v = self._repeat((self.time, self.ca2))
192-
ax.plot(t, v, label="Ca2+")
193-
ax.hlines(self.threshold, xmin=t[0], xmax=t[-1], label="threshold", colors="red")
194-
ax.set_xlabel("time (ms)")
195-
ax.set_ylabel("Ca2+")
196-
# ax.set_title('Ca2+')
197-
ax.legend()
198-
return fig
199-
200-
def plot_time_vs_stress(self):
201-
"""Plot stress."""
202-
if self.stress is None:
203-
LOGGER.error("Only support stress curve.")
204-
# self._estimate_stress()
205-
return None
206-
t, v = self._repeat((self.time, self.stress))
207-
fig, ax = plt.subplots(figsize=(8, 4))
208-
ax.plot(t, v)
209-
ax.set_xlabel("time (ms)")
210-
ax.set_ylabel("Normalized active stress")
211-
# ax.set_title('Ca2+')
212-
# ax.legend()
213-
return fig
214-
215-
def _stress_to_ca2(self, stress):
210+
def _stress_to_ca2(self, stress: np.ndarray) -> np.ndarray:
216211
if np.min(stress) < 0 or np.max(stress) > 1.0:
217-
LOGGER.error("Stress curve is not between 0-1.")
218212
raise ValueError("Stress curve must be between 0-1.")
219-
220-
# assuming actype=3, eta=0; n=1; Ca2+50=1
221213
ca2 = 1 / (1 - 0.999 * stress) - 1
222-
223-
# offset about threshold
224214
ca2[0] = 0.0
225215
ca2[1:] += 2 * self.threshold
226-
227216
return ca2
228217

229-
def _repeat(self, curve):
218+
def _repeat(self, curve: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
230219
t = np.copy(curve[0])
231220
v = np.copy(curve[1])
232-
233221
for ii in range(1, self.n_beat):
234222
t = np.append(t, curve[0][1:] + ii * self.t_beat)
235223
v = np.append(v, curve[1][1:])
236224
return (t, v)
237225

238-
def _estimate_stress(self):
239-
# TODO: only with 1
240-
# TODO: @wenfengye ensure ruff compatibility, see the noqa's
241-
ca2ionmax = 4.35
242-
ca2ion = 4.35
243-
n = 2
244-
mr = 1048.9
245-
dtmax = 150
246-
tr = -1429
247-
# Range of L 1.78-1.91
248-
L = 1.85 # noqa N806
249-
l0 = 1.58
250-
b = 4.75
251-
lam = 1
252-
cf = (np.exp(b * (lam * L - l0)) - 1) ** 0.5
253-
ca2ion50 = ca2ionmax / cf
254-
dtr = mr * lam * L + tr
255-
self.stress = np.zeros(self.ca2.shape)
256-
for i, t in enumerate(self.time):
257-
if t < dtmax:
258-
w = np.pi * t / dtmax
259-
elif dtmax <= t <= dtmax + dtr:
260-
w = np.pi * (t - dtmax + dtr) / dtr
261-
else:
262-
w = 0
263-
c = 0.5 * (1 - np.cos(w))
264-
self.stress[i] = c * ca2ion**n / (ca2ion**n + ca2ion50**n)
265-
266-
267-
if __name__ == "__main__":
268-
a = ActiveCurve(constant_ca2(), threshold=0.1, type="ca2")
269-
# a = Ca2Curve(unit_constant_ca2(), type="ca2")
270-
a.plot_time_vs_ca2()
271-
a.plot_time_vs_stress()
226+
@property
227+
def dyna_input(self) -> Tuple[np.ndarray, np.ndarray]:
228+
"""Return LS-DYNA input arrays."""
229+
return self._repeat((self.time, self.ca2))
230+
231+
def plot_time_vs_ca2(self):
232+
"""Plot time vs ca2."""
233+
import matplotlib.pyplot as plt
234+
235+
t, v = self.dyna_input
236+
fig, ax = plt.subplots()
237+
ax.plot(t, v, label="ca2")
238+
ax.axhline(self.threshold, color="r", linestyle="--", label="threshold")
239+
ax.set_xlabel("Time (ms)")
240+
ax.set_ylabel("Ca2+")
241+
ax.set_title("Active Ca2+ Curve")
242+
ax.legend()
243+
return fig
244+
245+
def plot_time_vs_stress(self):
246+
"""Plot time vs stress."""
247+
if self.type != "stress":
248+
raise ValueError("Curve type is not 'stress', cannot plot stress.")
249+
250+
import matplotlib.pyplot as plt
251+
252+
t, v = self._repeat((self.time, self.stress))
253+
fig, ax = plt.subplots()
254+
ax.plot(t, v, label="stress")
255+
ax.set_xlabel("Time (ms)")
256+
ax.set_ylabel("Stress (normalized)")
257+
ax.set_title("Active Stress Curve")
258+
ax.legend()
259+
return fig

0 commit comments

Comments
 (0)