diff --git a/doc/source/changelog/1259.miscellaneous.md b/doc/source/changelog/1259.miscellaneous.md new file mode 100644 index 000000000..9cc794ba4 --- /dev/null +++ b/doc/source/changelog/1259.miscellaneous.md @@ -0,0 +1 @@ +Use pydantic for the settings module diff --git a/src/ansys/health/heart/post/system_model_post.py b/src/ansys/health/heart/post/system_model_post.py index 532551e39..78e00b704 100644 --- a/src/ansys/health/heart/post/system_model_post.py +++ b/src/ansys/health/heart/post/system_model_post.py @@ -181,15 +181,15 @@ def __init__(self, dir: str): s = SimulationSettings() s.load(os.path.join(self.dir, "simulation_settings.yml")) l_ed_pressure = ( - s.mechanics.boundary_conditions.end_diastolic_cavity_pressure.left_ventricle.to( - "kilopascal" - ).m + s.mechanics.boundary_conditions.end_diastolic_cavity_pressure.get("left_ventricle") + .to("kilopascal") + .m ) if self.model_type == "BV": r_ed_pressure = ( - s.mechanics.boundary_conditions.end_diastolic_cavity_pressure.right_ventricle.to( - "kilopascal" - ).m + s.mechanics.boundary_conditions.end_diastolic_cavity_pressure.get("right_ventricle") + .to("kilopascal") + .m ) # get EOD volume diff --git a/src/ansys/health/heart/settings/settings.py b/src/ansys/health/heart/settings/settings.py index 85bb60c33..9499083e0 100644 --- a/src/ansys/health/heart/settings/settings.py +++ b/src/ansys/health/heart/settings/settings.py @@ -20,18 +20,42 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -"""Module that defines classes that hold settings relevant for PyAnsys-Heart.""" +"""Module that defines classes that hold settings relevant for PyAnsys-Heart. + +Examples +-------- +Create and configure simulation settings: + +>>> from ansys.health.heart.settings.settings import SimulationSettings +>>> settings = SimulationSettings() +>>> settings.load_defaults() +>>> settings.mechanics.analysis.end_time = Quantity(1000, "ms") +>>> settings.save("config.yml") + +Load existing configuration: + +>>> settings = SimulationSettings() +>>> settings.load("config.yml") +>>> print(settings.mechanics.analysis.end_time) +1000.0 millisecond +""" -import copy -from dataclasses import asdict, dataclass, field import json import os import pathlib from pathlib import Path import shutil -from typing import List, Literal +from typing import Any, Literal from pint import Quantity, UnitRegistry +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + field_serializer, + field_validator, +) import yaml from ansys.health.heart import LOG as LOGGER @@ -48,311 +72,720 @@ zeropressure as zero_pressure_defaults, ) +ureg = UnitRegistry() -class AttrDict(dict): - """Dictionary subclass whose entries can be accessed by attributes as well as normally.""" - def __init__(self, *args, **kwargs): - """Construct nested AttrDicts from nested dictionaries.""" +class BaseSettings(BaseModel): + """Base class for all settings with Pydantic validation and serialization. 
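As a standalone illustration of the pattern this hunk introduces (a sketch for orientation only, not code from this change set), a pydantic v2 model can hold pint Quantities directly once arbitrary types are allowed, which is exactly what the `model_config` below enables:

```python
# Minimal sketch of the Quantity-in-BaseModel pattern adopted by BaseSettings.
# Not part of this diff; assumes only pint and pydantic v2 as imported above.
from pint import Quantity
from pydantic import BaseModel, ConfigDict, Field


class TinySettings(BaseModel):
    # Quantity is not a pydantic-native type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

    end_time: Quantity = Field(default=Quantity(0, "ms"), description="End time")


s = TinySettings(end_time=Quantity(3000, "millisecond"))
print(s.end_time.to("second"))  # 3.0 second
```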
- def from_nested_dict(data): - """Construct nested AttrDicts from nested dictionaries.""" - if not isinstance(data, dict): - return data - else: - return AttrDict({key: from_nested_dict(data[key]) for key in data}) + Features + -------- + - Automatic validation of types and values + - Built-in JSON/YAML serialization with Pint Quantity support + - Unit conversion to consistent unit system ["MPa", "mm", "N", "ms", "g"] + - Nested model validation and type safety - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self - for key in self.keys(): - self[key] = from_nested_dict(self[key]) + Examples + -------- + >>> settings = BaseSettings() + >>> settings.to_consistent_unit_system() + >>> data = settings.model_dump_json() + """ + model_config = ConfigDict( + arbitrary_types_allowed=True, + validate_assignment=True, + extra="forbid", + str_strip_whitespace=True, + ) -class Settings: - """Generic settings class.""" + def __repr__(self) -> str: + """Represent object in YAML-style format using Pydantic v2 serialization. - def __repr__(self): - """Represent object in dictionary in YAML style.""" - d = self.serialize() - d = {self.__class__.__name__: d} - return yaml.dump(json.loads(json.dumps(d)), sort_keys=False) - - def set_values(self, defaults: dict): - """Read default settings from dictionary.""" - for key, value in self.__dict__.items(): - if key in defaults.keys(): - # set as AttrDict - if isinstance(defaults[key], dict): - setattr(self, key, AttrDict(defaults[key])) - else: - setattr(self, key, defaults[key]) + Returns + ------- + str + YAML-formatted string representation of the object. + """ + data = self.model_dump(mode="json", exclude_none=True) + data = {self.__class__.__name__: data} + return yaml.dump(json.loads(json.dumps(data)), sort_keys=False) - def serialize(self, remove_units: bool = False) -> dict: - """Serialize the settings, that is formats the Quantity as str( ).""" - dictionary = copy.deepcopy(asdict(self)) - _serialize_quantity(dictionary, remove_units) - return dictionary + @field_serializer("*", when_used="json") + def serialize_quantities_for_json(self, value: Any, _info) -> str | float | Any: + """Serialize Quantity objects for JSON output. - def to_consistent_unit_system(self): - """Convert units to a consistent unit system. + This serializer handles Quantity objects during JSON serialization, + providing string representation. + Handles nested Quantity objects in dictionaries and lists. + + Parameters + ---------- + value : Any + The field value to serialize. + _info : SerializationInfo + Pydantic serialization context (unused but required by signature). - Notes - ----- - Currently the only supported unit system is ["MPa", "mm", "N", "ms", "g"] - For instance: - Quantity(10, "mm/s") --> Quantity(0.01, "mm/ms") + Returns + ------- + str | float | Any + String representation if Quantity, otherwise unchanged. 
""" - def _to_consitent_units(d): - """Convert units to a consistent unit system.""" - if isinstance(d, Settings): - d = d.__dict__ - for k, v in d.items(): - if isinstance(v, (dict, AttrDict, Settings)): - _to_consitent_units(v) - elif isinstance(v, Quantity) and not (v.unitless): - # print(f"key: {k} | units {v.units}") - if "[substance]" in list(v.dimensionality): + def _serialize_recursive(obj: Any) -> Any: + """Recursively serialize Quantity objects in nested structures.""" + if isinstance(obj, Quantity): + return str(obj) + elif isinstance(obj, dict): + return {key: _serialize_recursive(val) for key, val in obj.items()} + elif isinstance(obj, (list, tuple)): + return [_serialize_recursive(item) for item in obj] + return obj + + return _serialize_recursive(value) + + @field_validator("*", mode="before") + def parse_quantity(cls, v, info): # noqa D102 + """Parse string values to Quantity objects for fields annotated as Quantity. + + This validator applies to all fields and attempts to parse string values + as Quantity objects when the field is annotated with Quantity type. + For nested models, it ensures proper Quantity parsing across all levels. + + Parameters + ---------- + v : Any + The value to validate and potentially convert to a Quantity. + info : ValidationInfo + Pydantic validation context containing field information. + + Returns + ------- + Any + Quantity object if conversion successful, otherwise the original value. + """ + # If it's already a Quantity, return as-is + if isinstance(v, Quantity): + return v + + # Only attempt parsing for string values + if not isinstance(v, str): + return v + + # Get field annotation if available + field_name = getattr(info, "field_name", None) + if not field_name: + return v + + # Check if this field should be a Quantity based on annotation + if hasattr(cls, "__annotations__") and field_name in cls.__annotations__: + field_annotation = cls.__annotations__[field_name] + + # Check if the field is annotated as Quantity + if field_annotation == Quantity: + try: + return ureg(v) + except Exception as e: + LOGGER.warning( + f"Failed to parse quantity from string '{v}' for field '{field_name}': {e}" + ) + return v + + # Handle generic aliases (for Python 3.9+ compatibility) + if hasattr(field_annotation, "__origin__") and field_annotation.__origin__ == Quantity: + try: + return ureg(v) + except Exception as e: + LOGGER.warning( + f"Failed to parse quantity from string '{v}' for field '{field_name}': {e}" + ) + return v + + # For non-Quantity fields, pass through unchanged + return v + + def to_consistent_unit_system(self) -> None: + """Convert units to consistent system ["MPa", "mm", "N", "ms", "g"]. + + This method converts all Quantity objects to use the PyAnsys Heart + standard unit system for cardiac simulations. 
+ + Examples + -------- + >>> from pint import Quantity + >>> settings = BaseSettings() + >>> # Assuming a Quantity field exists + >>> settings.to_consistent_unit_system() + """ + + def _to_consistent_units(obj: Any) -> None: + """Convert units recursively.""" + if isinstance(obj, BaseSettings): + obj_dict = obj.__dict__ + elif isinstance(obj, dict): + obj_dict = obj + else: + return + + for key, value in obj_dict.items(): + if isinstance(value, (dict, BaseSettings)): + _to_consistent_units(value) + elif isinstance(value, Quantity) and not value.unitless: + if "[substance]" in list(value.dimensionality): LOGGER.warning("Not converting [substance] / [length]^3") continue - d.update({k: v.to(_get_consistent_units_str(v.dimensionality))}) - return + new_quantity = value.to(_get_consistent_units_str(value.dimensionality)) + if isinstance(obj, BaseSettings): + setattr(obj, key, new_quantity) + else: + obj[key] = new_quantity + + _to_consistent_units(self) + + +class Analysis(BaseSettings): + """Class for analysis settings. + + Defines core simulation analysis parameters including time stepping, + output intervals, and damping parameters for cardiac simulations. + + Attributes + ---------- + end_time : Quantity + End time of simulation in time units. + dtmin : Quantity + Minimum time-step of simulation in time units. + dtmax : Quantity + Maximum time-step of simulation in time units. + dt_d3plot : Quantity + Time-step of d3plot export in time units. + dt_icvout : Quantity + Time-step of icvout export in time units. + global_damping : Quantity + Global damping constant in 1/time units. + stiffness_damping : Quantity + Stiffness damping constant in time units. + + Examples + -------- + >>> from pint import Quantity + >>> analysis = Analysis( + ... end_time=Quantity(1000, "ms"), dtmin=Quantity(0.1, "ms"), dtmax=Quantity(10, "ms") + ... 
) + >>> analysis.to_consistent_unit_system() + >>> print(analysis.end_time) + 1000.0 millisecond + """ + + end_time: Quantity = Field(default=Quantity(0, "s"), description="End time of simulation") + dtmin: Quantity = Field(default=Quantity(0, "s"), description="Minimum time-step of simulation") + dtmax: Quantity = Field(default=Quantity(0, "s"), description="Maximum time-step of simulation") + dt_d3plot: Quantity = Field(default=Quantity(0, "s"), description="Time-step of d3plot export") + dt_icvout: Quantity = Field(default=Quantity(0, "s"), description="Time-step of icvout export") + global_damping: Quantity = Field( + default=Quantity(0, "1/s"), description="Global damping constant" + ) + stiffness_damping: Quantity = Field( + default=Quantity(0, "s"), description="Stiffness damping constant" + ) - _to_consitent_units(self) - return - def _remove_units(self): - """Remove all units from Quantity objects.""" - - def __remove_units(d): - units = [] - if isinstance(d, Settings): - d = d.__dict__ - for k, v in d.items(): - if isinstance(v, (dict, AttrDict, Settings)): - units += __remove_units(v) - elif isinstance(v, Quantity): - # LOGGER.debug(f"key: {k} | units {v.units}") - units.append(v.units) - d.update({k: v.m}) - return units - - removed_units = __remove_units(self) - return removed_units - - -@dataclass(repr=False) -class Analysis(Settings): - """Class for analysis settings.""" - - end_time: Quantity = Quantity(0, "s") - """End time of simulation.""" - dtmin: Quantity = Quantity(0, "s") - """Minimum time-step of simulation.""" - dtmax: Quantity = Quantity(0, "s") - """Maximum time-step of simulation.""" - dt_d3plot: Quantity = Quantity(0, "s") - """Time-step of d3plot export.""" - dt_icvout: Quantity = Quantity(0, "s") - """Time-step of icvout export.""" - global_damping: Quantity = Quantity(0, "1/s") - """Global damping constant.""" - stiffness_damping: Quantity = Quantity(0, "s") - """Stiffness damping constant.""" - - -@dataclass(repr=False) class EPAnalysis(Analysis): - """Class for EP analysis settings.""" + """Class for EP analysis settings. - solvertype: Literal["Monodomain", "Eikonal", "ReactionEikonal"] = "Monodomain" + Extends Analysis with electrophysiology-specific solver configuration. + Supports different EP solver types for cardiac electrical simulation. + Attributes + ---------- + solvertype : Literal["Monodomain", "Eikonal", "ReactionEikonal"] + Type of electrophysiology solver to use. -@dataclass(repr=False) -class BoundaryConditions(Settings): - """Stores settings/parameters for boundary conditions.""" + Examples + -------- + >>> ep_analysis = EPAnalysis(solvertype="Monodomain") + >>> print(ep_analysis.solvertype) + Monodomain + """ - robin: AttrDict = None - """Parameters for pericardium spring/damper b.c.""" - valve: AttrDict = None - """Parameters for valve spring b.c.""" - end_diastolic_cavity_pressure: AttrDict = None - """End-diastolic pressure.""" + solvertype: Literal["Monodomain", "Eikonal", "ReactionEikonal"] = Field( + default="Monodomain", description="Type of electrophysiology solver" + ) -@dataclass(repr=False) -class SystemModel(Settings): - """Stores settings/parameters for system model.""" +class BoundaryConditions(BaseSettings): + """Stores settings/parameters for boundary conditions. + + Manages boundary condition parameters for cardiac simulation including + pericardium constraints, valve mechanics, and pressure loading. + + Attributes + ---------- + robin : dict[str, Any] | None + Parameters for pericardium spring/damper boundary conditions. 
+ valve : dict[str, Any] | None + Parameters for valve spring boundary conditions. + end_diastolic_cavity_pressure : dict[str, Any] | None + End-diastolic pressure configuration. + + Examples + -------- + >>> bc = BoundaryConditions( + ... robin={"stiffness": 1.0, "damping": 0.1}, valve={"spring_constant": 100.0} + ... ) + >>> print(bc.robin["stiffness"]) + 1.0 + """ - name: str = "ConstantPreloadWindkesselAfterload" - """Name of the system model.""" + robin: dict[str, Any] | None = Field( + default=None, description="Parameters for pericardium spring/damper b.c." + ) + valve: dict[str, Any] | None = Field( + default=None, description="Parameters for valve spring b.c." + ) + end_diastolic_cavity_pressure: dict[str, Any] | None = Field( + default=None, description="End-diastolic pressure" + ) - left_ventricle: AttrDict = None - """Parameters for the left ventricle.""" - right_ventricle: AttrDict = None - """Parameters for the right ventricle.""" +class SystemModel(BaseSettings): + """Stores settings/parameters for system model. + + Manages system-level model configuration including circulatory system + models and ventricular-specific parameters. + + Attributes + ---------- + name : str + Name of the system model implementation. + left_ventricle : dict[str, Any] | None + Parameters specific to left ventricle modeling. + right_ventricle : dict[str, Any] | None + Parameters specific to right ventricle modeling. + + Examples + -------- + >>> system = SystemModel( + ... name="ConstantPreloadWindkesselAfterload", + ... left_ventricle={"volume": 150.0}, + ... right_ventricle={"volume": 120.0}, + ... ) + >>> print(system.name) + ConstantPreloadWindkesselAfterload + """ -@dataclass(repr=False) -class Mechanics(Settings): - """Class for keeping track of settings.""" + name: str = Field( + default="ConstantPreloadWindkesselAfterload", description="Name of the system model" + ) + left_ventricle: dict[str, Any] | None = Field( + default=None, description="Parameters for the left ventricle" + ) + right_ventricle: dict[str, Any] | None = Field( + default=None, description="Parameters for the right ventricle" + ) - analysis: Analysis = field(default_factory=lambda: Analysis()) - """Generic analysis settings.""" - boundary_conditions: BoundaryConditions = field(default_factory=lambda: BoundaryConditions()) - """Boundary condition specifications.""" - system: SystemModel = field(default_factory=lambda: SystemModel()) - """System model settings.""" + +class Mechanics(BaseSettings): + """Class for keeping track of mechanical simulation settings. + + Complete mechanical simulation configuration including analysis parameters, + boundary conditions, and system model settings. + + Attributes + ---------- + analysis : Analysis + Generic analysis settings for time stepping and output. + boundary_conditions : BoundaryConditions + Boundary condition specifications and parameters. + system : SystemModel + System model settings and configurations. 
+ + Examples + -------- + >>> mechanics = Mechanics() + >>> mechanics.analysis.end_time = Quantity(1000, "ms") + >>> mechanics.boundary_conditions.robin = {"stiffness": 1.0} + >>> print(mechanics.analysis.end_time) + 1000.0 millisecond + """ + + analysis: Analysis = Field(default_factory=Analysis, description="Generic analysis settings") + boundary_conditions: BoundaryConditions = Field( + default_factory=BoundaryConditions, description="Boundary condition specifications" + ) + system: SystemModel = Field(default_factory=SystemModel, description="System model settings") -@dataclass(repr=False) class AnalysisZeroPressure(Analysis): - """Class for keeping track of zero-pressure analysis settings.""" + """Class for keeping track of zero-pressure analysis settings. + + Extends Analysis with specific settings for stress-free configuration + computation, including iterative solver parameters. + + Attributes + ---------- + dt_nodout : Quantity + Time interval of nodeout export. + max_iters : int + Maximum iterations for stress-free-configuration algorithm. + method : int + Method identifier to use for computation. + tolerance : float + Tolerance for iterative algorithm convergence. + + Examples + -------- + >>> zero_p = AnalysisZeroPressure(max_iters=5, tolerance=1.0, method=2) + >>> print(zero_p.max_iters) + 5 + """ - dt_nodout: Quantity = 0 - """Time interval of nodeout export.""" + dt_nodout: Quantity = Field( + default=Quantity(0, "s"), description="Time interval of nodeout export" + ) + max_iters: int = Field( + default=3, description="Maximum iterations for stress-free-configuration algorithm" + ) + method: int = Field(default=2, description="Method to use") + tolerance: float = Field(default=5.0, description="Tolerance to use for iterative algorithm") - max_iters: int = 3 - """Maximum iterations for stress-free-configuration algorithm.""" - method: int = 2 - """Method to use.""" - # TODO: this should be a Quantity type - tolerance: float = 5.0 - """Tolerance to use for iterative algorithm.""" +class ZeroPressure(BaseSettings): + """Class for keeping track of settings for stress-free-configuration computation. -@dataclass(repr=False) -class ZeroPressure(Settings): - """Class for keeping track of settings for stress-free-configuration computation.""" + Configuration for computing the stress-free (unloaded) configuration + of cardiac geometry, essential for accurate mechanical simulations. - analysis: AnalysisZeroPressure = field(default_factory=lambda: AnalysisZeroPressure()) - """Generic analysis settings.""" + Attributes + ---------- + analysis : AnalysisZeroPressure + Analysis settings specific to zero-pressure computation. + Examples + -------- + >>> zero_pressure = ZeroPressure() + >>> zero_pressure.analysis.max_iters = 5 + >>> zero_pressure.analysis.tolerance = 1.0 + >>> print(zero_pressure.analysis.method) + 2 + """ -@dataclass -class Stimulation(Settings): - """Stimulation settings.""" + analysis: AnalysisZeroPressure = Field( + default_factory=AnalysisZeroPressure, description="Generic analysis settings" + ) - node_ids: List[int] = None - t_start: Quantity = Quantity(0.0, "ms") - period: Quantity = Quantity(800, "ms") - duration: Quantity = Quantity(2, "ms") - amplitude: Quantity = Quantity(50, "uF/mm^3") - def __setattr__(self, __name: str, __value) -> None: - """Set attributes. +class Stimulation(BaseSettings): + """Stimulation settings for electrophysiology simulations. 
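A usage sketch for the zero-pressure settings defined above (illustrative, not part of this change set); it assumes the declared `model_config` (`validate_assignment=True`, `extra="forbid"`) behaves as in stock pydantic v2:

```python
# Assignments are re-validated by pydantic, so bad values fail fast.
# Illustrative sketch only; not part of this diff.
from pydantic import ValidationError

from ansys.health.heart.settings.settings import ZeroPressure

zp = ZeroPressure()
zp.analysis.max_iters = 5                   # accepted: int field
try:
    zp.analysis.max_iters = "not-a-number"  # rejected on assignment
except ValidationError as exc:
    print(exc.error_count(), "validation error(s)")
```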
+ + Defines electrical stimulation parameters including timing, location, + and amplitude for cardiac electrophysiology simulations. + + Attributes + ---------- + node_ids : list[int] | None + List of node IDs where stimulation is applied. + t_start : Quantity + Start time of stimulation. + period : Quantity + Period between stimulation cycles. + duration : Quantity + Duration of each stimulation pulse. + amplitude : Quantity + Stimulation amplitude. + + Examples + -------- + >>> from pint import Quantity + >>> stim = Stimulation( + ... node_ids=[1, 2, 3], + ... t_start=Quantity(10, "ms"), + ... period=Quantity(800, "ms"), + ... duration=Quantity(2, "ms"), + ... amplitude=Quantity(50, "uF/mm^3"), + ... ) + >>> print(stim.node_ids) + [1, 2, 3] + """ - Parameters - ---------- - __name : str - Attribute name. - __value : _type_ - Attribute value. + node_ids: list[int] | None = Field(default=None, description="Node IDs for stimulation") + t_start: Quantity = Field(default=Quantity(0.0, "ms"), description="Start time of stimulation") + period: Quantity = Field(default=Quantity(800, "ms"), description="Period between cycles") + duration: Quantity = Field(default=Quantity(2, "ms"), description="Duration of pulse") + amplitude: Quantity = Field( + default=Quantity(50, "uF/mm^3"), description="Stimulation amplitude" + ) - """ - if __name == "node_ids": - if isinstance(__value, list): - try: - __value = [int(x) for x in __value] - except ValueError: - print("Failed to cast node_ids to list of integers.") + @field_validator("node_ids") + @classmethod + def validate_node_ids(cls, v: Any) -> list[int] | None: + """Validate and convert node_ids to list of integers. - return super().__setattr__(__name, __value) - elif __name == "t_start" or __name == "period" or __name == "duration": - return super().__setattr__(__name, Quantity(__value, "ms")) - elif __name == "amplitude": - return super().__setattr__(__name, Quantity(__value, "uF/mm^3")) + Parameters + ---------- + v : Any + Input value to validate. + Returns + ------- + list[int] | None + Validated list of integer node IDs or None. -@dataclass(repr=False) -class Electrophysiology(Settings): - """Class for keeping track of EP settings.""" + Raises + ------ + ValueError + If node_ids cannot be converted to list of integers. + """ + if v is None: + return None + if isinstance(v, list): + try: + return [int(x) for x in v] + except (ValueError, TypeError) as e: + raise ValueError("Failed to cast node_ids to list of integers") from e + raise ValueError("node_ids must be a list of integers or None") + + +class Electrophysiology(BaseSettings): + """Class for keeping track of EP settings. + + Complete electrophysiology simulation configuration including analysis settings, + stimulation protocols, layer definitions, and conductivity parameters. + + Attributes + ---------- + analysis : EPAnalysis + Generic analysis settings for EP simulation. + stimulation : dict[str, Stimulation] | None + Dictionary of stimulation settings by name. + layers : dict[str, Quantity] + Layer definitions for material assignment of myocardium. + lambda_ratio : Quantity + Intra to extracellular conductivity ratio for EP solve. 
+ + Examples + -------- + >>> ep = Electrophysiology() + >>> ep.analysis.solvertype = "Monodomain" + >>> ep.stimulation = {"apex": Stimulation(node_ids=[1, 2, 3])} + >>> print(ep.analysis.solvertype) + Monodomain + """ - analysis: EPAnalysis = field(default_factory=lambda: EPAnalysis()) - """Generic analysis settings.""" - stimulation: AttrDict[str, Stimulation] = None - """Stimulation settings.""" + analysis: EPAnalysis = Field( + default_factory=EPAnalysis, description="Generic analysis settings" + ) + stimulation: dict[str, Stimulation] | None = Field( + default=None, description="Stimulation settings" + ) - _layers: dict = field( + layers: dict[str, Quantity] = Field( default_factory=lambda: { "percent_endo": Quantity(0.17, "dimensionless"), # thickness of endocardial layer "percent_mid": Quantity(0.41, "dimensionless"), # thickness of midmyocardial layer - } + }, + description="Layers for material assignment of the myocardium", ) - """Layers for material assignment of the myocardium.""" - _lambda: Quantity = Quantity(0.2, "dimensionless") # activate extracellular potential solve - """Intra to extracellular conductivity ratio.""" + lambda_ratio: Quantity = Field( + default=Quantity(0.2, "dimensionless"), + description="Intra to extracellular conductivity ratio", + ) -@dataclass(repr=False) -class Fibers(Settings): - """Class for keeping track of fiber settings.""" +class Fibers(BaseSettings): + """Class for keeping track of fiber settings. + + Defines fiber orientation parameters for ventricular myocardium, + including helical angles and transmural variations. + + Attributes + ---------- + alpha_endo : Quantity + Helical angle in endocardium. + alpha_epi : Quantity + Helical angle in epicardium. + beta_endo : Quantity + Angle to the outward transmural axis in endocardium. + beta_epi : Quantity + Angle to the outward transmural axis in epicardium. + beta_endo_septum : Quantity + Angle to the outward transmural axis in left septum. + beta_epi_septum : Quantity + Angle to the outward transmural axis in right septum. + + Examples + -------- + >>> from pint import Quantity + >>> fibers = Fibers( + ... alpha_endo=Quantity(-60, "degree"), + ... alpha_epi=Quantity(60, "degree"), + ... beta_endo=Quantity(0, "degree"), + ... beta_epi=Quantity(0, "degree"), + ... ) + >>> print(fibers.alpha_endo) + -60.0 degree + """ - alpha_endo: Quantity = 0 - "Helical angle in endocardium." - alpha_epi: Quantity = 0 - "Helical angle in epicardium." - beta_endo: Quantity = 0 - "Angle to the outward transmural axis of the heart in endocardium." - beta_epi: Quantity = 0 - "Angle to the outward transmural axis of the heart in epicardium." - beta_endo_septum: Quantity = 0 - "Angle to the outward transmural axis of the heart in left septum." - beta_epi_septum: Quantity = 0 - "Angle to the outward transmural axis of the heart in right septum." 
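The formerly private `_layers` and `_lambda` entries are now the public `layers` and `lambda_ratio` fields, which is how `ep_writer.py` reads them further down in this diff; a brief access sketch (illustrative, not part of the change set):

```python
# Reading the renamed public fields; values shown are the declared defaults.
from ansys.health.heart.settings.settings import Electrophysiology

ep = Electrophysiology()
print(ep.layers["percent_endo"].m)  # 0.17 (dimensionless magnitude)
print(ep.lambda_ratio.m)            # 0.2
```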
+ alpha_endo: Quantity = Field( + default=Quantity(0, "degree"), description="Helical angle in endocardium" + ) + alpha_epi: Quantity = Field( + default=Quantity(0, "degree"), description="Helical angle in epicardium" + ) + beta_endo: Quantity = Field( + default=Quantity(0, "degree"), + description="Angle to the outward transmural axis in endocardium", + ) + beta_epi: Quantity = Field( + default=Quantity(0, "degree"), + description="Angle to the outward transmural axis in epicardium", + ) + beta_endo_septum: Quantity = Field( + default=Quantity(0, "degree"), + description="Angle to the outward transmural axis in left septum", + ) + beta_epi_septum: Quantity = Field( + default=Quantity(0, "degree"), + description="Angle to the outward transmural axis in right septum", + ) -@dataclass(repr=False) -class AtrialFiber(Settings): - """ - Class for keeping track of atrial fiber settings. +class AtrialFiber(BaseSettings): + """Class for keeping track of atrial fiber settings. Default parameters are from doi.org/10.1016/j.cma.2020.113468 for idealized geometry. + Defines atrial fiber bundle parameters and orientations. + + Attributes + ---------- + tau_mv : float + Mitral valve parameter. + tau_lpv : float + Left pulmonary vein parameter. + tau_rpv : float + Right pulmonary vein parameter. + tau_tv : float + Tricuspid valve parameter. + tau_raw : float + Right atrial wall parameter. + tau_ct_minus : float + Crista terminalis minus parameter. + tau_ct_plus : float + Crista terminalis plus parameter. + tau_icv : float + Inferior vena cava parameter. + tau_scv : float + Superior vena cava parameter. + tau_ib : float + Isthmus bundle parameter. + tau_ras : float + Right atrial septum parameter. + + Examples + -------- + >>> atrial = AtrialFiber(tau_mv=0.5, tau_tv=0.3) + >>> print(atrial.tau_mv) + 0.5 """ - tau_mv: float = 0 - tau_lpv: float = 0 - tau_rpv: float = 0 - - tau_tv: float = 0 - tau_raw: float = 0 - tau_ct_minus: float = 0 - tau_ct_plus: float = 0 - tau_icv: float = 0 - tau_scv: float = 0 - tau_ib: float = 0 - tau_ras: float = 0 - - -@dataclass(repr=False) -class Purkinje(Settings): - """Class for keeping track of Purkinje settings.""" - - node_id_origin_left: int = None - """Left Purkinje origin ID.""" - node_id_origin_right: int = None - """Right Purkinje origin id.""" - edgelen: Quantity = 0 - """Edge length.""" - ngen: Quantity = 0 - """Number of generations.""" - nbrinit: Quantity = 0 - """Number of beams from origin point.""" - nsplit: Quantity = 0 - """Number of splits at each leaf.""" - pmjtype: Quantity = 0 - """Purkinje muscle junction type.""" - pmjradius: Quantity = 0 - """Purkinje muscle junction radius.""" - pmjrestype: Quantity = Quantity(1) - """Purkinje muscle junction resistance type.""" - pmjres: Quantity = Quantity(0.001, "1/mS") # 1/mS - """Purkinje muscle junction resistance.""" + tau_mv: float = Field(default=0.0, description="Mitral valve parameter") + tau_lpv: float = Field(default=0.0, description="Left pulmonary vein parameter") + tau_rpv: float = Field(default=0.0, description="Right pulmonary vein parameter") + tau_tv: float = Field(default=0.0, description="Tricuspid valve parameter") + tau_raw: float = Field(default=0.0, description="Right atrial wall parameter") + tau_ct_minus: float = Field(default=0.0, description="Crista terminalis minus parameter") + tau_ct_plus: float = Field(default=0.0, description="Crista terminalis plus parameter") + tau_icv: float = Field(default=0.0, description="Inferior vena cava parameter") + tau_scv: float = Field(default=0.0, 
description="Superior vena cava parameter") + tau_ib: float = Field(default=0.0, description="Isthmus bundle parameter") + tau_ras: float = Field(default=0.0, description="Right atrial septum parameter") + + +class Purkinje(BaseSettings): + """Class for keeping track of Purkinje settings. + + Defines parameters for Purkinje network generation and electrical + properties including geometry, branching, and junction characteristics. + + Attributes + ---------- + node_id_origin_left : int | None + Left Purkinje origin node ID. + node_id_origin_right : int | None + Right Purkinje origin node ID. + edgelen : Quantity + Edge length for Purkinje segments. + ngen : Quantity + Number of generations in the network. + nbrinit : Quantity + Number of initial branches from origin. + nsplit : Quantity + Number of splits at each leaf. + pmjtype : Quantity + Purkinje muscle junction type identifier. + pmjradius : Quantity + Purkinje muscle junction radius. + pmjrestype : Quantity + Purkinje muscle junction resistance type. + pmjres : Quantity + Purkinje muscle junction resistance value. + + Examples + -------- + >>> from pint import Quantity + >>> purkinje = Purkinje( + ... node_id_origin_left=1, + ... node_id_origin_right=2, + ... edgelen=Quantity(1.0, "mm"), + ... ngen=Quantity(5, "dimensionless"), + ... ) + >>> print(purkinje.node_id_origin_left) + 1 + """ + + node_id_origin_left: int | None = Field(default=None, description="Left Purkinje origin ID") + node_id_origin_right: int | None = Field(default=None, description="Right Purkinje origin id") + edgelen: Quantity = Field(default=Quantity(0, "mm"), description="Edge length") + ngen: Quantity = Field( + default=Quantity(0, "dimensionless"), description="Number of generations" + ) + nbrinit: Quantity = Field( + default=Quantity(0, "dimensionless"), description="Number of beams from origin point" + ) + nsplit: Quantity = Field( + default=Quantity(0, "dimensionless"), description="Number of splits at each leaf" + ) + pmjtype: Quantity = Field( + default=Quantity(0, "dimensionless"), description="Purkinje muscle junction type" + ) + pmjradius: Quantity = Field( + default=Quantity(0, "mm"), description="Purkinje muscle junction radius" + ) + pmjrestype: Quantity = Field( + default=Quantity(1, "dimensionless"), description="Purkinje muscle junction resistance type" + ) + pmjres: Quantity = Field( + default=Quantity(0.001, "1/mS"), description="Purkinje muscle junction resistance" + ) class SimulationSettings: - """Class for keeping track of settings.""" + """Class for keeping track of settings. + + Attributes are conditionally created based on initialization parameters. + All parameters default to True, so all attributes exist by default. + """ + + # Type annotations for conditionally created attributes + # Note: These attributes will only exist if the corresponding boolean parameter is True + # All parameters default to True, so these attributes exist in the default case + mechanics: Mechanics # Exists when mechanics=True (default) + electrophysiology: Electrophysiology # Exists when electrophysiology=True (default) + fibers: Fibers # Exists when fiber=True (default) + atrial_fibers: AtrialFiber # Exists when fiber=True (default) + purkinje: Purkinje # Exists when purkinje=True (default) + stress_free: ZeroPressure # Exists when stress_free=True (default) def __init__( self, @@ -426,22 +859,39 @@ def __init__( pass def __repr__(self): - """Represent object as list of relevant attribute names.""" + """Represent object as list of relevant attribute names. 
+ + Returns + ------- + str + String representation showing the class name and active + settings attribute names. + + Examples + -------- + >>> settings = SimulationSettings() + >>> print(repr(settings)) + SimulationSettings + mechanics + electrophysiology + fibers + atrial_fibers + purkinje + stress_free + """ repr_str = "\n ".join( - [attr for attr in self.__dict__ if isinstance(getattr(self, attr), Settings)] + [attr for attr in self.__dict__ if isinstance(getattr(self, attr), BaseSettings)] ) repr_str = self.__class__.__name__ + "\n " + repr_str return repr_str - def save(self, filename: pathlib.Path, remove_units: bool = False): + def save(self, filename: pathlib.Path): """Save simulation settings to disk. Parameters ---------- filename : pathlib.Path Path to target .json or .yml file - remove_units : bool, optional - Flag indicating whether to remove units before writing, by default False Examples -------- @@ -459,20 +909,23 @@ def save(self, filename: pathlib.Path, remove_units: bool = False): if filename.suffix not in [".yml", ".json"]: raise ValueError(f"Data format {filename.suffix} not supported") - # serialize each of the settings. + # Serialize each of the settings using Pydantic v2's enhanced model_dump serialized_settings = {} for attribute_name in self.__dict__.keys(): - if not isinstance(getattr(self, attribute_name), Settings): + if not isinstance(getattr(self, attribute_name), BaseSettings): continue else: - setting: Settings = getattr(self, attribute_name) - serialized_settings[attribute_name] = setting.serialize(remove_units=remove_units) + setting: BaseSettings = getattr(self, attribute_name) + # Use the simplified model dump method (no unit removal) + serialized_settings[attribute_name] = setting.model_dump( + mode="json", exclude_none=False + ) serialized_settings = {"Simulation Settings": serialized_settings} with open(filename, "w") as f: if filename.suffix == ".yml": - # NOTE: this suppress writing of tags from AttrDict + # Serialize settings using modern Pydantic serialization yaml.dump(json.loads(json.dumps(serialized_settings)), f, sort_keys=False) elif filename.suffix == ".json": @@ -508,41 +961,141 @@ def load(self, filename: pathlib.Path): if not isinstance(filename, pathlib.Path): filename = pathlib.Path(filename) - with open(filename, "r") as f: - if filename.suffix == ".json": - data = json.load(f) - if filename.suffix == ".yml": - data = yaml.load(f, Loader=yaml.SafeLoader) - settings = data["Simulation Settings"] + # Load file data with proper error handling + try: + with open(filename, "r", encoding="utf-8") as f: + if filename.suffix == ".json": + data = json.load(f) + elif filename.suffix == ".yml": + data = yaml.load(f, Loader=yaml.SafeLoader) + else: + raise ValueError(f"Unsupported file format: {filename.suffix}") + except FileNotFoundError as e: + LOGGER.error(f"Settings file not found: {filename}") + raise FileNotFoundError(f"Settings file not found: {filename}") from e + except (json.JSONDecodeError, yaml.YAMLError) as e: + LOGGER.error(f"Failed to parse settings file {filename}: {e}") + raise ValueError(f"Invalid file format in {filename}: {e}") from e + + settings_data = data.get("Simulation Settings", {}) + if not settings_data: + LOGGER.warning("No 'Simulation Settings' found in file") + return + + # Unit registry kept for backward compatibility with external code + # ureg = UnitRegistry() # Commented out - no longer needed + + try: + # Use streamlined approach - Pydantic handles all validation automatically + 
self._load_settings_section("mechanics", settings_data, Mechanics) + self._load_settings_section("stress_free", settings_data, ZeroPressure) + self._load_settings_section("electrophysiology", settings_data, Electrophysiology) + self._load_settings_section("fibers", settings_data, Fibers) + self._load_settings_section("atrial_fibers", settings_data, AtrialFiber) + self._load_settings_section("purkinje", settings_data, Purkinje) + + except ValidationError as e: + LOGGER.error(f"Validation error while loading settings: {e}") + raise ValueError(f"Invalid settings data: {e}") from e + except Exception as e: + LOGGER.error(f"Unexpected error loading settings: {e}") + raise RuntimeError(f"Failed to load settings: {e}") from e + + def _load_settings_section( + self, section_name: str, settings_data: dict[str, Any], model_class: type[BaseSettings] + ) -> None: + """Load a specific settings section using Pydantic v2 validation. + + This helper method streamlines the loading process by using Pydantic's + automatic validation and type conversion. It pre-processes nested data + to convert string quantities to Quantity objects before validation. + + Parameters + ---------- + section_name : str + Name of the settings section to load. + settings_data : dict[str, Any] + Complete settings data dictionary. + model_class : type[BaseSettings] + Pydantic model class to validate against. + """ + if section_name in settings_data and hasattr(self, section_name): + section_data = settings_data[section_name].copy() + + # Pre-process nested data to convert string quantities + section_data = self._convert_quantities_recursive(section_data) + + # Let Pydantic handle all validation and type conversion automatically + validated_model = model_class.model_validate(section_data) + setattr(self, section_name, validated_model) + + def _convert_quantities_recursive(self, data: Any) -> Any: + """Recursively convert string quantities to Quantity objects in nested data. + + This helper method processes nested dictionaries and lists to convert + string representations of quantities to actual Quantity objects before + Pydantic validation. This ensures proper handling of nested models. + + Parameters + ---------- + data : Any + Data structure to process (dict, list, or primitive value). + + Returns + ------- + Any + Processed data with string quantities converted to Quantity objects. + """ + if isinstance(data, dict): + return {key: self._convert_quantities_recursive(value) for key, value in data.items()} + elif isinstance(data, list): + return [self._convert_quantities_recursive(item) for item in data] + elif isinstance(data, str): + # Try to parse as a quantity if it looks like one + if self._looks_like_quantity(data): + try: + return ureg(data) + except Exception: + # If parsing fails, return the original string + return data + return data + else: + # Return primitive values unchanged + return data + + def _looks_like_quantity(self, value: str) -> bool: + """Check if a string looks like a quantity that can be parsed. + + Parameters + ---------- + value : str + String value to check. - # unit registry to convert back to Quantity object - ureg = UnitRegistry() + Returns + ------- + bool + True if the string appears to be a quantity representation. 
+ """ + # Simple heuristic: contains a space and has numeric part + if " " not in value: + return False + + parts = value.split() + if len(parts) < 2: + return False + # Check if first part is numeric try: - attribute_name = "mechanics" - _deserialize_quantity(settings[attribute_name], ureg) - # assign values to each respective attribute - analysis = Analysis() - analysis.set_values(settings[attribute_name]["analysis"]) - boundary_conditions = BoundaryConditions() - boundary_conditions.set_values(settings[attribute_name]["boundary_conditions"]) - system_model = SystemModel() - system_model.set_values(settings[attribute_name]["system"]) - self.mechanics.analysis = analysis - self.mechanics.boundary_conditions = boundary_conditions - self.mechanics.system = system_model - - attribute_name = "stress_free" - _deserialize_quantity(settings[attribute_name], ureg) - analysis = AnalysisZeroPressure() - analysis.set_values(settings[attribute_name]["analysis"]) - self.stress_free.analysis = analysis - - except KeyError: - LOGGER.error("Failed to load mechanics settings.") + float(parts[0]) + return True + except ValueError: + return False def load_defaults(self): - """Load default simulation settings. + """Load default simulation settings using Pydantic model initialization. + + This method properly initializes all settings with default values using + Pydantic's built-in validation and type conversion capabilities. Examples -------- @@ -557,51 +1110,66 @@ def load_defaults(self): >>> settings.load_defaults() >>> settings.mechanics.analysis Analysis: - end_time: 3000.0 millisecond - dtmin: 10.0 millisecond - dtmax: 10.0 millisecond - dt_d3plot: 50.0 millisecond - dt_icvout: 1.0 millisecond - global_damping: 0.5 / millisecond + end_time: 800.0 millisecond + dtmin: 5.0 millisecond + dtmax: 5.0 millisecond + dt_d3plot: 20.0 millisecond + dt_icvout: 5.0 millisecond + global_damping: 0.1 / millisecond """ - # TODO: move to Settings class - for attr in self.__dict__: - if isinstance(getattr(self, attr), Mechanics): - analysis = Analysis() - analysis.set_values(mech_defaults.analysis) - boundary_conditions = BoundaryConditions() - boundary_conditions.set_values(mech_defaults.boundary_conditions) - system_model = SystemModel() - system_model.set_values(mech_defaults.system_model) - - self.mechanics.analysis = analysis - self.mechanics.boundary_conditions = boundary_conditions - self.mechanics.system = system_model - - if isinstance(getattr(self, attr), ZeroPressure): - analysis = AnalysisZeroPressure() - analysis.set_values(zero_pressure_defaults.analysis) - self.stress_free.analysis = analysis - - if isinstance(getattr(self, attr), Electrophysiology): - analysis = EPAnalysis() - analysis.set_values(ep_defaults.analysis) - self.electrophysiology.analysis = analysis - self.electrophysiology.stimulation: AttrDict[str, Stimulation] = AttrDict() - for key in ep_defaults.stimulation.keys(): - system_model = Stimulation() - system_model.set_values(ep_defaults.stimulation[key]) - self.electrophysiology.stimulation[key] = system_model - # TODO: add stim params, monodomain/bidomain/eikonal,cellmodel - # TODO: add settings for purkinje fibers and epmecha - if isinstance(getattr(self, attr), Fibers): - self.fibers.set_values(fibers_defaults.angles) - if isinstance(getattr(self, attr), Purkinje): - self.purkinje.set_values(purkinje_defaults.build) - if isinstance(getattr(self, attr), AtrialFiber): - self.atrial_fibers.set_values(fibers_defaults.la_bundle) - self.atrial_fibers.set_values(fibers_defaults.ra_bundle) + try: 
+ # Load mechanics defaults using Pydantic model initialization + if hasattr(self, "mechanics") and isinstance(self.mechanics, Mechanics): + self.mechanics.analysis = Analysis(**mech_defaults.analysis) + self.mechanics.boundary_conditions = BoundaryConditions( + **mech_defaults.boundary_conditions + ) + self.mechanics.system = SystemModel(**mech_defaults.system_model) + + # Load zero pressure defaults + if hasattr(self, "stress_free") and isinstance(self.stress_free, ZeroPressure): + self.stress_free.analysis = AnalysisZeroPressure(**zero_pressure_defaults.analysis) + + # Load electrophysiology defaults + if hasattr(self, "electrophysiology") and isinstance( + self.electrophysiology, Electrophysiology + ): + self.electrophysiology.analysis = EPAnalysis(**ep_defaults.analysis) + + # Create stimulation dictionary with Pydantic validation + stimulation_dict = {} + for key, stim_data in ep_defaults.stimulation.items(): + stimulation_dict[key] = Stimulation(**stim_data) + self.electrophysiology.stimulation = stimulation_dict + + # Load fiber defaults + if hasattr(self, "fibers") and isinstance(self.fibers, Fibers): + # Update fibers with defaults - handle properly based on Fibers model structure + for field_name, value in fibers_defaults.angles.items(): + if hasattr(self.fibers, field_name): + setattr(self.fibers, field_name, value) + + # Load Purkinje defaults + if hasattr(self, "purkinje") and isinstance(self.purkinje, Purkinje): + # Update Purkinje with defaults - handle properly based on Purkinje model structure + for field_name, value in purkinje_defaults.build.items(): + if hasattr(self.purkinje, field_name): + setattr(self.purkinje, field_name, value) + + # Load atrial fiber defaults + if hasattr(self, "atrial_fibers") and isinstance(self.atrial_fibers, AtrialFiber): + # Update atrial fibers with defaults - handle both la_bundle and ra_bundle + for field_name, value in fibers_defaults.la_bundle.items(): + if hasattr(self.atrial_fibers, field_name): + setattr(self.atrial_fibers, field_name, value) + for field_name, value in fibers_defaults.ra_bundle.items(): + if hasattr(self.atrial_fibers, field_name): + setattr(self.atrial_fibers, field_name, value) + + except Exception as e: + LOGGER.error(f"Failed to load default settings: {e}") + raise RuntimeError(f"Failed to initialize settings with defaults: {e}") from e def to_consistent_unit_system(self): """Convert all settings to consistent unit-system ["MPa", "mm", "N", "ms", "g"]. @@ -625,26 +1193,43 @@ def to_consistent_unit_system(self): attributes = [ getattr(self, attr) for attr in self.__dict__ - if isinstance(getattr(self, attr), Settings) + if isinstance(getattr(self, attr), BaseSettings) ] for attr in attributes: - if isinstance(attr, Settings): + if isinstance(attr, BaseSettings): attr.to_consistent_unit_system() return def get_ventricle_fiber_rotation(self, method: Literal["LSDYNA", "D-RBM"]) -> dict: - """Get rotation angles from settings. + """Get rotation angles from fiber settings. + + Extracts fiber orientation angles from the configured fiber settings + and formats them according to the specified fiber generation method. Parameters ---------- - method : Literal["LSDYNA", "D - Fiber rule based methods + method : Literal["LSDYNA", "D-RBM"] + Fiber rule-based method for extracting rotation angles. 
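End to end, the call sequence the docstrings above describe looks like this (illustrative sketch, not part of this change set; the printed values are the package defaults quoted in the docstrings):

```python
# Populate defaults, normalize units to ["MPa", "mm", "N", "ms", "g"], then query.
from ansys.health.heart.settings.settings import SimulationSettings

settings = SimulationSettings()
settings.load_defaults()
settings.to_consistent_unit_system()
print(settings.mechanics.analysis.end_time)         # e.g. 800.0 millisecond
rotation = settings.get_ventricle_fiber_rotation("LSDYNA")
print(rotation["alpha"])                            # e.g. [-60.0, 60.0]
```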
+ - "LSDYNA": LS-DYNA fiber generation format + - "D-RBM": Discrete Rule-Based Method format Returns ------- dict - rotation angles alpha and beta + Dictionary containing rotation angles (alpha and beta) formatted + for the specified method. Keys and structure depend on the method: + - LSDYNA: "alpha", "beta", "beta_septum" keys + - D-RBM: "alpha_left", "alpha_right", "alpha_ot", "beta_left", + "beta_right", "beta_ot" keys + + Examples + -------- + >>> settings = SimulationSettings() + >>> settings.load_defaults() + >>> rotation = settings.get_ventricle_fiber_rotation("LSDYNA") + >>> print(rotation["alpha"]) + [-60.0, 60.0] """ if method == "LSDYNA": rotation = { @@ -685,48 +1270,6 @@ def get_ventricle_fiber_rotation(self, method: Literal["LSDYNA", "D-RBM"]) -> di return rotation -def _remove_units_from_dictionary(d: dict): - """Replace Quantity with value in a nested dictionary (removes units).""" - for k, v in d.items(): - if isinstance(v, (dict, AttrDict)): - _remove_units_from_dictionary(v) - if isinstance(v, Quantity): - d[k] = d[k].m - return d - - -def _serialize_quantity(d: dict, remove_units: bool = False): - """Serialize Quantity such that Quantity objects are replaced by string.""" - for k, v in d.items(): - # if isinstance(v, AttrDict): - # v = dict(v) # cast to regular dict - if isinstance(v, (dict, AttrDict)): - _serialize_quantity(v, remove_units=remove_units) - if isinstance(v, Quantity): - if remove_units: - d[k] = str(d[k].m) - else: - d[k] = str(d[k]) - return d - - -def _deserialize_quantity(d: dict, ureg: UnitRegistry): - """Deserialize string such that " " is replaced by Quantity(value, units).""" - for k, v in d.items(): - if isinstance(v, dict): - _deserialize_quantity(v, ureg) - if isinstance(v, str): - if isinstance(d[k], str): - try: - float(d[k].split()[0]) - q = ureg(d[k]) - except ValueError: - # failed to convert to quantity - continue - d[k] = q - return d - - # desired consistent unit system is: # ["MPa", "mm", "N", "ms", "g"] # Time: ms @@ -759,7 +1302,21 @@ def _deserialize_quantity(d: dict, ureg: UnitRegistry): def _get_consistent_units_str(dimensions: set): - """Get consistent units formatted as string.""" + """Get consistent units formatted as string. + + Converts dimensionality to the PyAnsys Heart consistent unit system string + representation based on the defined base quantities and derived units. + + Parameters + ---------- + dimensions : set + Set of dimensions from a Quantity object. + + Returns + ------- + str + String representation of consistent units for the given dimensions. + """ if dimensions in _derived[0]: _to_units = _derived[1][_derived[0].index(dimensions)] return _to_units @@ -773,7 +1330,28 @@ def _get_consistent_units_str(dimensions: set): def _windows_to_wsl_path(windows_path: str): - """Convert Windows to WSL path.""" + r"""Convert Windows path to WSL-compatible path format. + + Handles conversion from Windows drive paths and WSL localhost paths + to proper Unix-style paths for use within Windows Subsystem for Linux. + + Parameters + ---------- + windows_path : str + Windows path to convert. + + Returns + ------- + str | None + WSL-compatible path string, or None if conversion not applicable. 
+ + Examples + -------- + >>> _windows_to_wsl_path(r"C:\Users\example") + '/mnt/c/Users/example' + >>> _windows_to_wsl_path(r"\\wsl.localhost\Ubuntu\home") + '/Ubuntu/home' + """ win_path = Path(windows_path) if isinstance(win_path, pathlib.PosixPath): return None @@ -791,11 +1369,51 @@ def _windows_to_wsl_path(windows_path: str): class DynaSettings: - """Class for collecting, managing, and validating LS-DYNA settings.""" + """Class for collecting, managing, and validating LS-DYNA settings. + + This class provides configuration management for LS-DYNA simulations, + including executable paths, parallelization settings, platform-specific + configurations, and command-line argument generation. + + Attributes + ---------- + lsdyna_path : pathlib.Path + Path to LS-DYNA executable. + dynatype : str + Type of LS-DYNA executable (smp, intelmpi, platformmpi, msmpi). + num_cpus : int + Number of CPUs requested for parallel execution. + platform : str + Platform for LS-DYNA execution (windows, wsl, linux). + dyna_options : str + Additional command line options for LS-DYNA. + mpi_options : str + Additional MPI options for parallel execution. + + Examples + -------- + >>> dyna_settings = DynaSettings( + ... lsdyna_path="lsdyna.exe", dynatype="intelmpi", num_cpus=4, platform="windows" + ... ) + >>> commands = dyna_settings.get_commands("input.k") + """ @staticmethod def _get_available_mpi_exe(): - """Find whether mpiexec or mpirun are available.""" + """Find whether mpiexec or mpirun are available. + + Searches for MPI executables in PATH, preferring mpirun over mpiexec. + + Returns + ------- + str + Path to available MPI executable. + + Raises + ------ + MPIProgamNotFoundError + If neither mpirun nor mpiexec are found in PATH. + """ # preference for mpirun if it is added to PATH. mpiexec is the fallback option. if shutil.which("mpirun"): return shutil.which("mpirun") @@ -873,9 +1491,12 @@ def __init__( return - def get_commands(self, path_to_input: pathlib.Path) -> List[str]: + def get_commands(self, path_to_input: pathlib.Path) -> list[str]: """Get command line arguments from the defined settings. + Builds platform-specific command line arguments for running LS-DYNA + with the configured settings including MPI and parallelization options. + Parameters ---------- path_to_input : pathlib.Path @@ -883,8 +1504,19 @@ def get_commands(self, path_to_input: pathlib.Path) -> List[str]: Returns ------- - List[str] - List of strings of each of the commands. + list[str] + List of command line arguments for executing LS-DYNA. + + Raises + ------ + WSLNotFoundError + If WSL platform is specified but wsl.exe is not found. + + Examples + -------- + >>> dyna_settings = DynaSettings(dynatype="smp", num_cpus=4) + >>> commands = dyna_settings.get_commands(Path("input.k")) + >>> print(commands[0]) # LS-DYNA executable path """ if self.platform == "wsl": mpi_exe = "mpirun" @@ -974,7 +1606,15 @@ def get_commands(self, path_to_input: pathlib.Path) -> List[str]: return commands def _modify_from_global_settings(self): - """Set DynaSettings based on globally defined settings for PyAnsys-Heart.""" + """Set DynaSettings based on globally defined settings for PyAnsys-Heart. + + Checks for PYANSYS_HEART environment variables and updates settings + accordingly. 
Supported environment variables: + - PYANSYS_HEART_LSDYNA_PATH: Path to LS-DYNA executable + - PYANSYS_HEART_LSDYNA_PLATFORM: Execution platform + - PYANSYS_HEART_LSDYNA_TYPE: LS-DYNA executable type + - PYANSYS_HEART_NUM_CPU: Number of CPUs for parallel execution + """ keys = [key for key in os.environ.keys() if "PYANSYS_HEART" in key] LOGGER.debug(f"PYANSYS_HEART Environment variables: {keys}") self.lsdyna_path = os.getenv("PYANSYS_HEART_LSDYNA_PATH", self.lsdyna_path) @@ -984,5 +1624,11 @@ def _modify_from_global_settings(self): return def __repr__(self): - """Represent self as string.""" + """Represent self as YAML-formatted string. + + Returns + ------- + str + YAML representation of the DynaSettings object attributes. + """ return yaml.dump(vars(self), allow_unicode=True, default_flow_style=False) diff --git a/src/ansys/health/heart/writer/_control_volume.py b/src/ansys/health/heart/writer/_control_volume.py index cc9391695..f438351ac 100644 --- a/src/ansys/health/heart/writer/_control_volume.py +++ b/src/ansys/health/heart/writer/_control_volume.py @@ -22,12 +22,38 @@ """Module to system model.""" from dataclasses import dataclass +from typing import Any from ansys.health.heart.models import BiVentricle, FourChamber, HeartModel, LeftVentricle from ansys.health.heart.parts import Chamber from ansys.health.heart.writer.define_function_templates import _define_function_0d_system +def _convert_quantities_to_magnitudes(obj: Any) -> Any: + """Recursively convert Quantity objects to their magnitudes in nested dictionaries. + + Parameters + ---------- + obj : Any + Object that may contain Quantity objects. + + Returns + ------- + Any + Object with Quantity values converted to magnitudes. + """ + from pint import Quantity + + if isinstance(obj, dict): + return {key: _convert_quantities_to_magnitudes(value) for key, value in obj.items()} + elif isinstance(obj, list): + return [_convert_quantities_to_magnitudes(item) for item in obj] + elif isinstance(obj, Quantity): + return obj.magnitude + else: + return obj + + @dataclass class CVInteraction: """Template to define control volume interaction.""" @@ -43,7 +69,9 @@ def _define_function_keyword(self): if self.flow_name == "closed-loop": return "" else: - return _define_function_0d_system(self.lcid, self.flow_name, self.parameters) + # Convert any Quantity objects to magnitudes before passing to template + parameters_magnitudes = _convert_quantities_to_magnitudes(self.parameters) + return _define_function_0d_system(self.lcid, self.flow_name, parameters_magnitudes) @dataclass diff --git a/src/ansys/health/heart/writer/base_writer.py b/src/ansys/health/heart/writer/base_writer.py index 15507414e..45a800ae5 100644 --- a/src/ansys/health/heart/writer/base_writer.py +++ b/src/ansys/health/heart/writer/base_writer.py @@ -127,14 +127,14 @@ def __init__(self, model: HeartModel, settings: SimulationSettings = None) -> No return - def _get_subsettings(self) -> list[sett.Settings]: + def _get_subsettings(self) -> list[sett.BaseSettings]: """Get subsettings from the settings object.""" import ansys.health.heart.settings.settings as sett subsettings_classes = [ getattr(self.settings, attr).__class__ for attr in self.settings.__dict__ - if isinstance(getattr(self.settings, attr), sett.Settings) + if isinstance(getattr(self.settings, attr), sett.BaseSettings) ] return subsettings_classes diff --git a/src/ansys/health/heart/writer/ep_writer.py b/src/ansys/health/heart/writer/ep_writer.py index 42d03fe3c..d535e6583 100644 --- 
a/src/ansys/health/heart/writer/ep_writer.py +++ b/src/ansys/health/heart/writer/ep_writer.py @@ -425,8 +425,8 @@ def _update_parts_cellmodels(self) -> None: def _create_myocardial_nodeset_layers(self) -> tuple[int, int, int]: """Create myocardial node set layers.""" - percent_endo = self.settings.electrophysiology._layers["percent_endo"].m - percent_mid = self.settings.electrophysiology._layers["percent_mid"].m + percent_endo = self.settings.electrophysiology.layers["percent_endo"].m + percent_mid = self.settings.electrophysiology.layers["percent_mid"].m values = self.model.mesh.point_data["transmural"] # Values from experimental data. See: # https://www.frontiersin.org/articles/10.3389/fphys.2019.00580/full diff --git a/src/ansys/health/heart/writer/mechanics_writer.py b/src/ansys/health/heart/writer/mechanics_writer.py index 8f6eee175..3b7864e3a 100644 --- a/src/ansys/health/heart/writer/mechanics_writer.py +++ b/src/ansys/health/heart/writer/mechanics_writer.py @@ -23,10 +23,11 @@ import copy from enum import Enum -from typing import Callable, Literal, Optional +from typing import Any, Callable, Literal, Optional import numpy as np import pandas as pd +from pint import Quantity import pyvista as pv from ansys.dyna.core.keywords import keywords @@ -128,8 +129,7 @@ def update(self, dynain_name: Optional[str] = None, robin_bcs: list[Callable] = self._add_pericardium_bc() # for control volume - system_settings = copy.deepcopy(self.settings.mechanics.system) - system_settings._remove_units() + system_settings = self.settings.mechanics.system if system_settings.name == "open-loop": lcid = self.get_unique_curve_id() @@ -1075,12 +1075,10 @@ def _add_export_controls(self, dt_output_d3plot: float = 0.5) -> None: # self.kw_database.main.append(keywords.DatabaseExtentBinary(neiph=27, strflg=1, maxint=0)) # add binout for post-process - settings = copy.deepcopy(self.settings.stress_free) - settings._remove_units() + stress_free_settings = self.settings.stress_free + dt_nodout = _get_magnitude(stress_free_settings.analysis.dt_nodout) - self.kw_database.main.append( - keywords.DatabaseNodout(dt=settings.analysis.dt_nodout, binary=2) - ) + self.kw_database.main.append(keywords.DatabaseNodout(dt=dt_nodout, binary=2)) # write for all nodes in nodout nodeset_id = self.get_unique_nodeset_id() @@ -1094,24 +1092,24 @@ def _add_export_controls(self, dt_output_d3plot: float = 0.5) -> None: def _add_solution_controls(self) -> None: """Rewrite the method for the zerop simulation.""" - settings = copy.deepcopy(self.settings.stress_free) - settings._remove_units() + stress_free_settings = self.settings.stress_free + + # Extract magnitude values for LS-DYNA keywords + end_time = _get_magnitude(stress_free_settings.analysis.end_time) + dtmin = _get_magnitude(stress_free_settings.analysis.dtmin) + dtmax = _get_magnitude(stress_free_settings.analysis.dtmax) - self.kw_database.main.append(keywords.ControlTermination(endtim=settings.analysis.end_time)) + self.kw_database.main.append(keywords.ControlTermination(endtim=end_time)) self.kw_database.main.append(keywords.ControlImplicitDynamics(imass=0)) # add auto step controls self.kw_database.main.append( - keywords.ControlImplicitAuto( - iauto=1, dtmin=settings.analysis.dtmin, dtmax=settings.analysis.dtmax - ) + keywords.ControlImplicitAuto(iauto=1, dtmin=dtmin, dtmax=dtmax) ) # add general implicit controls - self.kw_database.main.append( - keywords.ControlImplicitGeneral(imflag=1, dt0=settings.analysis.dtmax) - ) + 
self.kw_database.main.append(keywords.ControlImplicitGeneral(imflag=1, dt0=dtmax)) # add implicit solution controls self.kw_database.main.append( @@ -1246,3 +1244,19 @@ def _add_enddiastolic_pressure_bc(self) -> None: continue return + + +def _get_magnitude(value: Any) -> Any: + """Extract magnitude from Quantity objects, return other values unchanged. + + Parameters + ---------- + value : Any + Value to extract magnitude from. + + Returns + ------- + Any + Magnitude if value is a Quantity, otherwise the original value. + """ + return value.magnitude if isinstance(value, Quantity) else value diff --git a/tests/heart/assets/post/main/simulation_settings.yml b/tests/heart/assets/post/main/simulation_settings.yml index 6d84f8ea3..686fca44d 100644 --- a/tests/heart/assets/post/main/simulation_settings.yml +++ b/tests/heart/assets/post/main/simulation_settings.yml @@ -7,28 +7,6 @@ Simulation Settings: dt_d3plot: 20.0 millisecond dt_icvout: 5.0 millisecond global_damping: 0.1 / millisecond - material: - myocardium: - isotropic: - rho: 0.001 gram / millimeter ** 3 - nu: 0.499 - k1: 0.0023599999999999997 megapascal - k2: 1.75 dimensionless - anisotropic: - k1f: 0.00049 megapascal - k2f: 9.01 dimensionless - active: - beat_time: 800 millisecond - taumax: 0.125 megapascal - ss: 0.0 - sn: 0.0 - passive: - type: NeoHook - rho: 0.001 gram / millimeter ** 3 - itype: -1 - mu1: 0.1 megapascal - alpha1: 2 - cap: null boundary_conditions: robin: ventricle: @@ -71,30 +49,6 @@ Simulation Settings: initial_value: part: 0.0019998358112249997 megapascal electrophysiology: - material: - myocardium: - velocity_fiber: 0.7 millimeter / millisecond - velocity_sheet: 0.2 millimeter / millisecond - velocity_sheet_normal: 0.2 millimeter / millisecond - sigma_fiber: 0.5 millisiemens / millimeter - sigma_sheet: 0.1 millisiemens / millimeter - sigma_sheet_normal: 0.1 millisiemens / millimeter - sigma_passive: 1.0 millisiemens / millimeter - beta: 140 / millimeter - cm: 0.01 microfarad / millimeter ** 2 - lambda: 0.2 dimensionless - percent_endo: 0.17 dimensionless - percent_mid: 0.41 dimensionless - atrium: null - cap: null - beam: - velocity: 1 millimeter / millisecond - sigma: 1 millisiemens / millimeter - beta: 140 / millimeter - cm: 0.01 microfarad / millimeter ** 2 - lambda: 0.2 dimensionless - pmjrestype: 1 dimensionless - pmjres: 0.001 / millisiemens analysis: end_time: 800 millisecond dtmin: 0.0 millisecond diff --git a/tests/heart/assets/post/zerop/simulation_settings.yml b/tests/heart/assets/post/zerop/simulation_settings.yml index 6d84f8ea3..686fca44d 100644 --- a/tests/heart/assets/post/zerop/simulation_settings.yml +++ b/tests/heart/assets/post/zerop/simulation_settings.yml @@ -7,28 +7,6 @@ Simulation Settings: dt_d3plot: 20.0 millisecond dt_icvout: 5.0 millisecond global_damping: 0.1 / millisecond - material: - myocardium: - isotropic: - rho: 0.001 gram / millimeter ** 3 - nu: 0.499 - k1: 0.0023599999999999997 megapascal - k2: 1.75 dimensionless - anisotropic: - k1f: 0.00049 megapascal - k2f: 9.01 dimensionless - active: - beat_time: 800 millisecond - taumax: 0.125 megapascal - ss: 0.0 - sn: 0.0 - passive: - type: NeoHook - rho: 0.001 gram / millimeter ** 3 - itype: -1 - mu1: 0.1 megapascal - alpha1: 2 - cap: null boundary_conditions: robin: ventricle: @@ -71,30 +49,6 @@ Simulation Settings: initial_value: part: 0.0019998358112249997 megapascal electrophysiology: - material: - myocardium: - velocity_fiber: 0.7 millimeter / millisecond - velocity_sheet: 0.2 millimeter / millisecond - velocity_sheet_normal: 0.2 
millimeter / millisecond - sigma_fiber: 0.5 millisiemens / millimeter - sigma_sheet: 0.1 millisiemens / millimeter - sigma_sheet_normal: 0.1 millisiemens / millimeter - sigma_passive: 1.0 millisiemens / millimeter - beta: 140 / millimeter - cm: 0.01 microfarad / millimeter ** 2 - lambda: 0.2 dimensionless - percent_endo: 0.17 dimensionless - percent_mid: 0.41 dimensionless - atrium: null - cap: null - beam: - velocity: 1 millimeter / millisecond - sigma: 1 millisiemens / millimeter - beta: 140 / millimeter - cm: 0.01 microfarad / millimeter ** 2 - lambda: 0.2 dimensionless - pmjrestype: 1 dimensionless - pmjres: 0.001 / millisiemens analysis: end_time: 800 millisecond dtmin: 0.0 millisecond diff --git a/tests/heart/post/test_laplace.py b/tests/heart/post/test_laplace.py index 818247027..0d6c1bbf0 100644 --- a/tests/heart/post/test_laplace.py +++ b/tests/heart/post/test_laplace.py @@ -47,14 +47,14 @@ def _set_env_vars(monkeypatch): def test_compute_la_fiber_cs(_set_env_vars): dir = os.path.join(get_assets_folder(), "post", "la_fiber") - setting = AtrialFiber() - setting.set_values( - { + setting = AtrialFiber( + **{ "tau_mv": 0.65, "tau_lpv": 0.1, "tau_rpv": 0.65, } ) + input_grid = pv.read(os.path.join(dir, "la_input.vtu")) la_endo = pv.read(os.path.join(dir, "la_endo.vtk")) @@ -72,9 +72,8 @@ def test_compute_la_fiber_cs(_set_env_vars): def test_compute_ra_fiber_cs(_set_env_vars): dir = os.path.join(get_assets_folder(), "post", "ra_fiber") - setting = AtrialFiber() - setting.set_values( - { + setting = AtrialFiber( + **{ "tau_tv": 0.9, "tau_raw": 0.55, "tau_ct_minus": -0.18, diff --git a/tests/heart/post/test_postprocess.py b/tests/heart/post/test_postprocess.py index a63070252..65d20dace 100644 --- a/tests/heart/post/test_postprocess.py +++ b/tests/heart/post/test_postprocess.py @@ -49,11 +49,12 @@ def get_left_ventricle(): @pytest.mark.requires_dpf -def test_compute_thickness(get_left_ventricle): +def test_compute_thickness(get_left_ventricle, monkeypatch): test_dir = get_left_ventricle[0] model = get_left_ventricle[1] d3plot = os.path.join(os.path.join(test_dir, "main", "d3plot")) + monkeypatch.setenv("ANSYS_DPF_ACCEPT_LA", "Y") s = AhaStrainCalculator(model, d3plot) lines = s._compute_thickness_lines() assert len(lines) == 2 @@ -63,7 +64,8 @@ def test_compute_thickness(get_left_ventricle): @pytest.mark.requires_dpf -def test_compute_myocardial_strain(get_left_ventricle): +def test_compute_myocardial_strain(get_left_ventricle, monkeypatch): + monkeypatch.setenv("ANSYS_DPF_ACCEPT_LA", "Y") test_dir = get_left_ventricle[0] model = get_left_ventricle[1] d3plot = os.path.join(os.path.join(test_dir, "main", "d3plot")) @@ -74,7 +76,8 @@ def test_compute_myocardial_strain(get_left_ventricle): @pytest.mark.requires_dpf -def test_compute_aha_strain(get_left_ventricle): +def test_compute_aha_strain(get_left_ventricle, monkeypatch): + monkeypatch.setenv("ANSYS_DPF_ACCEPT_LA", "Y") test_dir = get_left_ventricle[0] model = get_left_ventricle[1] d3plot = os.path.join(os.path.join(test_dir, "main", "d3plot")) @@ -86,8 +89,9 @@ def test_compute_aha_strain(get_left_ventricle): @pytest.mark.requires_dpf -def test_plot_aha_bullseye(): +def test_plot_aha_bullseye(monkeypatch): """Test plotting AHA bullseye plot.""" + monkeypatch.setenv("ANSYS_DPF_ACCEPT_LA", "Y") # Create the fake data data = np.arange(17) + 1 @@ -105,7 +109,8 @@ def test_plot_aha_bullseye(): @pytest.mark.requires_dpf -def test_zerop_post(get_left_ventricle): +def test_zerop_post(get_left_ventricle, monkeypatch): + 
monkeypatch.setenv("ANSYS_DPF_ACCEPT_LA", "Y") test_dir = get_left_ventricle[0] model = get_left_ventricle[1] dct = zerop_post(os.path.join(test_dir, "zerop"), model) @@ -124,11 +129,13 @@ def test_zerop_post(get_left_ventricle): @pytest.mark.requires_dpf class TestSystemModelPost: @pytest.fixture - def system_model(self, get_left_ventricle): + def system_model(self, get_left_ventricle, monkeypatch): + monkeypatch.setenv("ANSYS_DPF_ACCEPT_LA", "Y") test_dir = get_left_ventricle[0] return SystemModelPost(os.path.join(test_dir, "main")) - def test_plot_pv_loop(self, system_model): + def test_plot_pv_loop(self, system_model, monkeypatch): + monkeypatch.setenv("ANSYS_DPF_ACCEPT_LA", "Y") ef = system_model.get_ejection_fraction() fig = system_model.plot_pv_loop(ef=ef) fig.savefig("pv_{0:d}.png".format(0)) diff --git a/tests/heart/settings/test_quantity_serialization.py b/tests/heart/settings/test_quantity_serialization.py new file mode 100644 index 000000000..7f87c11e5 --- /dev/null +++ b/tests/heart/settings/test_quantity_serialization.py @@ -0,0 +1,330 @@ +# Copyright (C) 2023 - 2025 ANSYS, Inc. and/or its affiliates. +# SPDX-License-Identifier: MIT +# +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +"""Test Quantity serialization and validation with Pydantic v2.""" + +import json +from pathlib import Path +import tempfile + +from pint import Quantity +import pytest +import yaml + +from ansys.health.heart.settings.settings import ( + Analysis, + Mechanics, + SimulationSettings, + Stimulation, +) + + +class TestQuantitySerializationPydantic: + """Test Pydantic v2 based Quantity serialization and validation.""" + + def test_quantity_field_serializer_json(self): + """Test that Quantity fields are properly serialized to JSON strings.""" + analysis = Analysis( + end_time=Quantity(1000, "ms"), + dtmin=Quantity(0.1, "ms"), + dtmax=Quantity(10, "ms"), + global_damping=Quantity(0.5, "1/s"), + ) + + # Test JSON serialization + json_data = analysis.model_dump_json() + parsed_data = json.loads(json_data) + + # Verify Quantity objects are serialized as strings with units + assert parsed_data["end_time"] == "1000 millisecond" + assert parsed_data["dtmin"] == "0.1 millisecond" + assert parsed_data["dtmax"] == "10 millisecond" + assert parsed_data["global_damping"] == "0.5 / second" + + def test_quantity_field_serializer_dict(self): + """Test that Quantity fields are properly serialized to dict format.""" + analysis = Analysis( + end_time=Quantity(1000, "ms"), + dtmin=Quantity(0.1, "ms"), + global_damping=Quantity(0.5, "1/s"), + ) + + # Test dict serialization for YAML + dict_data = analysis.model_dump() + + # Verify Quantity objects are serialized as strings + assert dict_data["end_time"] == Quantity(1000, "ms") + assert dict_data["dtmin"] == Quantity(0.1, "ms") + assert dict_data["global_damping"] == Quantity(0.5, "1/s") + + def test_quantity_validator_from_string(self): + """Test that Quantity validators properly deserialize from strings.""" + # Test creation from string representation + analysis_data = { + "end_time": "1000.0 millisecond", + "dtmin": "0.1 millisecond", + "dtmax": "10.0 millisecond", + "global_damping": "0.5 / second", + } + + analysis = Analysis(**analysis_data) + + # Verify proper Quantity objects are created + assert isinstance(analysis.end_time, Quantity) + assert analysis.end_time.magnitude == 1000.0 + assert str(analysis.end_time.units) == "millisecond" + + assert isinstance(analysis.global_damping, Quantity) + assert analysis.global_damping.magnitude == 0.5 + assert str(analysis.global_damping.units) == "1 / second" + + def test_quantity_validator_from_string_invalid_value(self): + """Test that the validator raises an error when a non-quantity string is received.""" + with pytest.raises(ValueError): + Analysis(end_time="invalid unit string") + + def test_quantity_validator_from_quantity(self): + """Test that Quantity validators pass through existing Quantity objects.""" + # Test creation from Quantity objects directly + analysis = Analysis( + end_time=Quantity(1000, "ms"), + dtmin=Quantity(0.1, "ms"), + global_damping=Quantity(0.5, "1/s"), + ) + + # Verify Quantity objects are preserved + assert isinstance(analysis.end_time, Quantity) + assert analysis.end_time.magnitude == 1000.0 + assert str(analysis.end_time.units) == "millisecond" + + @pytest.mark.xfail(reason="Default units not yet implemented in Pydantic validators") + def test_quantity_validator_from_numeric_with_default_units(self): + """Test that numeric values get default units when specified.""" + # Create stimulation with numeric values - should get default units + stim_data = { + "t_start": 10.0, # Should become Quantity(10.0, "ms") + "period": 800.0, # Should become Quantity(800.0, "ms") + "duration": 2.0, # 
Should become Quantity(2.0, "ms") + } + + stim = Stimulation(**stim_data) + + # Verify proper conversion to Quantity with default units + assert isinstance(stim.t_start, Quantity) + assert stim.t_start.magnitude == 10.0 + assert str(stim.t_start.units) == "millisecond" + + @pytest.mark.xfail(reason="Targets units not yet implemented in Pydantic validators") + def test_quantity_validation_error_invalid_units(self): + """Test that invalid unit strings raise proper validation errors.""" + with pytest.raises(ValueError, match="Unable to parse quantity"): + Analysis(end_time="invalid_unit_string") + + @pytest.mark.xfail(reason="Targets dimensions not yet implemented in Pydantic validators") + def test_quantity_validation_error_incompatible_dimensions(self): + """Test validation of incompatible dimensions.""" + with pytest.raises(ValueError, match="incompatible dimensions"): + # Trying to assign length unit to time field + Analysis(end_time="100.0 meter") + + def test_nested_quantity_serialization(self): + """Test serialization of nested models with Quantity fields.""" + mechanics = Mechanics() + mechanics.analysis.end_time = Quantity(1000, "ms") + mechanics.analysis.dtmin = Quantity(0.1, "ms") + + # Test JSON serialization of nested model + json_data = mechanics.model_dump_json() + parsed_data = json.loads(json_data) + + # Verify nested Quantity serialization + assert parsed_data["analysis"]["end_time"] == "1000 millisecond" + assert parsed_data["analysis"]["dtmin"] == "0.1 millisecond" + + def test_stimulation_node_ids_validation(self): + """Test that Stimulation node_ids field is properly validated.""" + # Test with list of integers + stim = Stimulation(node_ids=[1, 2, 3]) + assert stim.node_ids == [1, 2, 3] + + # Test with None + stim_none = Stimulation(node_ids=None) + assert stim_none.node_ids is None + + # Test with invalid type + with pytest.raises(ValueError): + Stimulation(node_ids="invalid_node_ids") + + def test_complex_nested_serialization_deserialization(self): + """Test round-trip serialization/deserialization of complex nested structures.""" + # Create complex settings + settings = SimulationSettings( + mechanics=True, + electrophysiology=True, + fiber=False, + purkinje=False, + stress_free=False, + ) + + # Set some values + settings.mechanics.analysis.end_time = Quantity(1000, "ms") + settings.electrophysiology.analysis.end_time = Quantity(800, "ms") + + stim = Stimulation( + node_ids=[1, 2, 3], + t_start=Quantity(10, "ms"), + amplitude=Quantity(50, "uF/mm^3"), + ) + settings.electrophysiology.stimulation = {"apex": stim} + + # Serialize to JSON + json_data = settings.mechanics.model_dump_json() + + # Deserialize back + mechanics_dict = json.loads(json_data) + reconstructed_mechanics = Mechanics(**mechanics_dict) + + # Verify values are preserved + assert reconstructed_mechanics.analysis.end_time == Quantity(1000, "ms") + assert isinstance(reconstructed_mechanics.analysis.end_time, Quantity) + + def test_yaml_round_trip_serialization(self): + """Test YAML serialization/deserialization round trip.""" + analysis = Analysis( + end_time=Quantity(1000, "ms"), + dtmin=Quantity(0.1, "ms"), + global_damping=Quantity(0.5, "1/s"), + ) + + # Serialize to dict for YAML + data_dict = analysis.model_dump(mode="json") + + # Convert to YAML and back + with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f: + yaml.dump(data_dict, f) + temp_path = f.name + + try: + with open(temp_path, "r") as f: + loaded_dict = yaml.load(f, Loader=yaml.SafeLoader) + + # Reconstruct object + 
reconstructed = Analysis(**loaded_dict) + + # Verify values match + assert reconstructed.end_time == analysis.end_time + assert reconstructed.dtmin == analysis.dtmin + assert reconstructed.global_damping == analysis.global_damping + + finally: + Path(temp_path).unlink() + + def test_unit_conversion_during_validation(self): + """Test that units are properly converted during validation.""" + # Create analysis with different time units + analysis = Analysis( + end_time="1.0 second", # Should be converted to milliseconds + dtmin="0.1 second", + ) + + # Convert to consistent unit system + analysis.to_consistent_unit_system() + + # Verify conversion to consistent units + assert analysis.end_time.magnitude == 1000.0 + assert str(analysis.end_time.units) == "millisecond" + assert analysis.dtmin.magnitude == 100.0 + assert str(analysis.dtmin.units) == "millisecond" + + def test_serialize_consistency(self): + """Test the serialize method for consistency with modern serialization.""" + analysis = Analysis( + end_time=Quantity(1000, "ms"), + dtmin=Quantity(0.1, "ms"), + global_damping=Quantity(0.5, "1/s"), + ) + + # Test modern serialization method + data_modern = analysis.model_dump(mode="json") + + # Results should be identical (both convert to strings) + assert data_modern["end_time"] == "1000 millisecond" + assert data_modern["global_damping"] == "0.5 / second" + + def test_dimensionless_quantity_handling(self): + """Test handling of dimensionless quantities.""" + from ansys.health.heart.settings.settings import Electrophysiology + + ep = Electrophysiology() + ep.lambda_ratio = Quantity(0.2, "dimensionless") + + # Test serialization + data = ep.model_dump() + assert data["lambda_ratio"] == Quantity(0.2, "dimensionless") + + # Test deserialization + reconstructed = Electrophysiology(**data) + assert reconstructed.lambda_ratio == Quantity(0.2, "dimensionless") + + def test_none_quantity_fields(self): + """Test handling of None values in optional Quantity fields.""" + # Create model with None values (if any fields allow it) + stim = Stimulation(node_ids=None) # node_ids can be None + + # Serialize and verify None is preserved + data = stim.model_dump(mode="json") + assert data["node_ids"] is None + + # Deserialize and verify None is preserved + reconstructed = Stimulation(**data) + assert reconstructed.node_ids is None + + def test_quantity_field_validation_edge_cases(self): + """Test edge cases in Quantity field validation.""" + # Test zero values + analysis = Analysis(end_time="0.0 millisecond") + assert analysis.end_time.magnitude == 0.0 + + # Test negative values where appropriate + analysis = Analysis(global_damping="-0.1 / second") + assert analysis.global_damping.magnitude == -0.1 + + # Test very small values + analysis = Analysis(dtmin="1e-6 millisecond") + assert analysis.dtmin.magnitude == 1e-6 + + def test_pydantic_validation_assignment(self): + """Test that assignment validation works with Quantity fields.""" + analysis = Analysis() + + # Test assignment of string + analysis.end_time = "500.0 millisecond" + assert isinstance(analysis.end_time, Quantity) + assert analysis.end_time.magnitude == 500.0 + + # Test assignment of Quantity + analysis.end_time = Quantity(1000, "ms") + assert analysis.end_time.magnitude == 1000.0 + + # Test invalid assignment + with pytest.raises(ValueError): + analysis.end_time = "invalid_quantity" diff --git a/tests/heart/settings/test_settings.py b/tests/heart/settings/test_settings.py index aac445c97..fe9f5196b 100644 --- a/tests/heart/settings/test_settings.py 
+++ b/tests/heart/settings/test_settings.py @@ -27,6 +27,7 @@ import numpy as np from pint import Quantity +from pydantic import ValidationError import pytest from ansys.health.heart.settings.defaults import fibers as fibers_defaults @@ -35,6 +36,7 @@ Fibers, SimulationSettings, Stimulation, + ZeroPressure, _get_consistent_units_str, _windows_to_wsl_path, ) @@ -80,10 +82,10 @@ " period: 800 millisecond\n" " duration: 20 millisecond\n" " amplitude: 50 microfarad / millimeter ** 3\n" - " _layers:\n" + " layers:\n" " percent_endo: 0.17 dimensionless\n" " percent_mid: 0.41 dimensionless\n" - " _lambda: 0.2 dimensionless\n" + " lambda_ratio: 0.2 dimensionless\n" ) @@ -124,7 +126,12 @@ def test_settings_save_002(): purkinje=False, stress_free=False, ) - stim = Stimulation(t_start=0, period=800, duration=20, amplitude=50) + stim = Stimulation( + t_start=Quantity(0, "ms"), + period=Quantity(800, "ms"), + duration=Quantity(20, "ms"), + amplitude=Quantity(50, "uF/mm^3"), + ) settings.electrophysiology.stimulation = {"stimdefaults": stim} # fill some dummy data @@ -142,7 +149,13 @@ def test_settings_save_002(): compare_string_with_file(REF_STRING_SETTINGS_YML_EP, file_path) settings.load_defaults() - stim2 = Stimulation(node_ids=[1, 2, 3], t_start=10, period=100, duration=30, amplitude=40) + stim2 = Stimulation( + node_ids=[1, 2, 3], + t_start=Quantity(10, "ms"), + period=Quantity(100, "ms"), + duration=Quantity(30, "ms"), + amplitude=Quantity(40, "uF/mm^3"), + ) settings.electrophysiology.stimulation["stim2"] = stim2 stim: Stimulation = settings.electrophysiology.stimulation["stim2"] @@ -239,9 +252,10 @@ def test_convert_units_002(): def test_settings_set_defaults(): - """Check if defaults properly set.""" - settings = Fibers() - settings.set_values(fibers_defaults.angles) + """Check if defaults properly set using Pydantic model initialization.""" + # Create Fibers instance with defaults applied directly + fibers_data = fibers_defaults.angles + settings = Fibers(**fibers_data) assert settings.alpha_endo.m == -60 @@ -309,3 +323,257 @@ def test_windows_path_to_wsl_path(): _windows_to_wsl_path("\\\\wsl.localhost\\Ubuntu\\home\\user\\project") == "/home/user/project" ) + + +# ZeroPressure test reference strings +REF_STRING_ZERO_PRESSURE_YML = ( + "Simulation Settings:\n" + " stress_free:\n" + " analysis:\n" + " end_time: 500 millisecond\n" + " dtmin: 5 millisecond\n" + " dtmax: 50 millisecond\n" + " dt_d3plot: 25 millisecond\n" + " dt_icvout: 10 millisecond\n" + " global_damping: 0.1 / second\n" + " stiffness_damping: 0.05 second\n" + " dt_nodout: 15 millisecond\n" + " max_iters: 5\n" + " method: 1\n" + " tolerance: 1.0\n" +) + + +def test_zero_pressure_serialization_yaml(): + """Test YAML serialization of ZeroPressure settings.""" + settings = SimulationSettings( + mechanics=False, + electrophysiology=False, + fiber=False, + purkinje=False, + stress_free=True, + ) + + # Set custom values + settings.stress_free.analysis.end_time = Quantity(500, "ms") + settings.stress_free.analysis.dtmin = Quantity(5, "ms") + settings.stress_free.analysis.dtmax = Quantity(50, "ms") + settings.stress_free.analysis.dt_d3plot = Quantity(25, "ms") + settings.stress_free.analysis.dt_icvout = Quantity(10, "ms") + settings.stress_free.analysis.dt_nodout = Quantity(15, "ms") + settings.stress_free.analysis.global_damping = Quantity(0.1, "1/s") + settings.stress_free.analysis.stiffness_damping = Quantity(0.05, "s") + settings.stress_free.analysis.max_iters = 5 + settings.stress_free.analysis.method = 1 + 
settings.stress_free.analysis.tolerance = 1.0 + + with tempfile.TemporaryDirectory(prefix=".pyansys-heart") as tempdir: + file_path = os.path.join(tempdir, "zero_pressure.yml") + settings.save(file_path) + + # Read file contents and compare manually + with open(file_path, "r") as f: + content = f.read() + + # Verify key content exists + assert "stress_free:" in content + assert "analysis:" in content + assert "end_time: 500 millisecond" in content + assert "max_iters: 5" in content + assert "method: 1" in content + assert "tolerance: 1.0" in content + + +def test_zero_pressure_deserialization_yaml(): + """Test YAML deserialization of ZeroPressure settings.""" + with tempfile.TemporaryDirectory(prefix=".pyansys-heart") as tempdir: + file_path = os.path.join(tempdir, "zero_pressure_load.yml") + + # Write reference string to file + with open(file_path, "w") as f: + f.write(REF_STRING_ZERO_PRESSURE_YML) + + # Load settings + settings = SimulationSettings( + mechanics=False, + electrophysiology=False, + fiber=False, + purkinje=False, + stress_free=True, + ) + settings.load(file_path) + + # Verify loaded values + assert settings.stress_free.analysis.end_time == Quantity(500, "ms") + assert settings.stress_free.analysis.dtmin == Quantity(5, "ms") + assert settings.stress_free.analysis.dtmax == Quantity(50, "ms") + assert settings.stress_free.analysis.dt_d3plot == Quantity(25, "ms") + assert settings.stress_free.analysis.dt_icvout == Quantity(10, "ms") + assert settings.stress_free.analysis.dt_nodout == Quantity(15, "ms") + assert settings.stress_free.analysis.global_damping == Quantity(0.1, "1/s") + assert settings.stress_free.analysis.stiffness_damping == Quantity(0.05, "s") + assert settings.stress_free.analysis.max_iters == 5 + assert settings.stress_free.analysis.method == 1 + assert settings.stress_free.analysis.tolerance == 1.0 + + +def test_zero_pressure_serialization_json(): + """Test JSON serialization of ZeroPressure settings.""" + settings = SimulationSettings( + mechanics=False, + electrophysiology=False, + fiber=False, + purkinje=False, + stress_free=True, + ) + + # Load defaults + settings.load_defaults() + + with tempfile.TemporaryDirectory(prefix=".pyansys-heart") as tempdir: + file_path = os.path.join(tempdir, "zero_pressure.json") + settings.save(file_path) + + # Read and parse JSON + with open(file_path, "r") as f: + content = f.read() + import json + + data = json.loads(content) + + # Verify structure + assert "Simulation Settings" in data + assert "stress_free" in data["Simulation Settings"] + assert "analysis" in data["Simulation Settings"]["stress_free"] + + analysis = data["Simulation Settings"]["stress_free"]["analysis"] + assert analysis["end_time"] == "1000.0 millisecond" + assert analysis["max_iters"] == 3 + assert analysis["method"] == 2 + assert analysis["tolerance"] == 5.0 + + +def test_zero_pressure_roundtrip(): + """Test roundtrip serialization/deserialization of ZeroPressure.""" + # Create settings with custom values + original_settings = SimulationSettings( + mechanics=False, + electrophysiology=False, + fiber=False, + purkinje=False, + stress_free=True, + ) + + # Set specific values + original_settings.stress_free.analysis.end_time = Quantity(750, "ms") + original_settings.stress_free.analysis.max_iters = 7 + original_settings.stress_free.analysis.method = 3 + original_settings.stress_free.analysis.tolerance = 2.5 + original_settings.stress_free.analysis.dt_nodout = Quantity(100, "ms") + + with tempfile.TemporaryDirectory(prefix=".pyansys-heart") as tempdir: + 
file_path = os.path.join(tempdir, "roundtrip.yml") + + # Save settings + original_settings.save(file_path) + + # Load settings into new object + loaded_settings = SimulationSettings( + mechanics=False, + electrophysiology=False, + fiber=False, + purkinje=False, + stress_free=True, + ) + loaded_settings.load(file_path) + + # Verify roundtrip consistency + assert ( + loaded_settings.stress_free.analysis.end_time + == original_settings.stress_free.analysis.end_time + ) + assert ( + loaded_settings.stress_free.analysis.max_iters + == original_settings.stress_free.analysis.max_iters + ) + assert ( + loaded_settings.stress_free.analysis.method + == original_settings.stress_free.analysis.method + ) + assert ( + loaded_settings.stress_free.analysis.tolerance + == original_settings.stress_free.analysis.tolerance + ) + assert ( + loaded_settings.stress_free.analysis.dt_nodout + == original_settings.stress_free.analysis.dt_nodout + ) + + +def test_zero_pressure_unit_conversion(): + """Test unit conversion for ZeroPressure settings.""" + zero_pressure = ZeroPressure() + + # Set values with different units + zero_pressure.analysis.end_time = Quantity(2, "s") # Will convert to ms + zero_pressure.analysis.dtmin = Quantity(0.01, "s") # Will convert to ms + zero_pressure.analysis.global_damping = Quantity(2, "1/s") # Should stay 1/s + + # Apply unit conversion + zero_pressure.to_consistent_unit_system() + + # Verify conversions + assert zero_pressure.analysis.end_time.magnitude == 2000.0 + assert str(zero_pressure.analysis.end_time.units) == "millisecond" + assert zero_pressure.analysis.dtmin.magnitude == 10.0 + assert str(zero_pressure.analysis.dtmin.units) == "millisecond" + assert zero_pressure.analysis.global_damping.magnitude == 0.002 + assert str(zero_pressure.analysis.global_damping.units) == "1 / millisecond" + + +def test_zero_pressure_validation(): + """Test Pydantic validation for ZeroPressure fields.""" + zero_pressure = ZeroPressure() + + # Test invalid max_iters (should be int) + with pytest.raises(ValidationError): + zero_pressure.analysis.max_iters = "invalid" + + # Test invalid tolerance (should be float) + with pytest.raises(ValidationError): + zero_pressure.analysis.tolerance = "invalid" + + # Test valid updates + zero_pressure.analysis.max_iters = 15 + zero_pressure.analysis.tolerance = 3.14 + zero_pressure.analysis.method = 5 + + assert zero_pressure.analysis.max_iters == 15 + assert zero_pressure.analysis.tolerance == 3.14 + assert zero_pressure.analysis.method == 5 + + +def test_zero_pressure_defaults_loading(): + """Test loading default values for ZeroPressure from defaults module.""" + settings = SimulationSettings( + mechanics=False, + electrophysiology=False, + fiber=False, + purkinje=False, + stress_free=True, + ) + + # Load defaults + settings.load_defaults() + + # Verify default values from zeropressure defaults module + assert settings.stress_free.analysis.end_time == Quantity(1000, "ms") + assert settings.stress_free.analysis.dtmin == Quantity(10, "ms") + assert settings.stress_free.analysis.dtmax == Quantity(100, "ms") + assert settings.stress_free.analysis.dt_d3plot == Quantity(100, "ms") + assert settings.stress_free.analysis.dt_nodout == Quantity(200, "ms") + + # Verify base class defaults are preserved + assert settings.stress_free.analysis.max_iters == 3 + assert settings.stress_free.analysis.method == 2 + assert settings.stress_free.analysis.tolerance == 5.0 diff --git a/tests/heart/writer/test_dynawriter.py b/tests/heart/writer/test_dynawriter.py index 
33413e105..5e00f6583 100644 --- a/tests/heart/writer/test_dynawriter.py +++ b/tests/heart/writer/test_dynawriter.py @@ -129,7 +129,7 @@ def test_add_stimulation_keyword(_mock_model, solvertype, expected_kw): settings.load_defaults() settings.electrophysiology.analysis.solvertype = solvertype # set up stimulation - stimulation = Stimulation([1, 2]) + stimulation = Stimulation(node_ids=[1, 2]) writer = writers.ElectroMechanicsDynaWriter(model, settings)
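
As a quick illustration of the behavior this patch pins down, here is a minimal interactive sketch; field names, constructors, and expected outputs are taken from the tests added above, and anything beyond that should be treated as an assumption rather than a guaranteed interface.

>>> # Sketch only: mirrors assertions in test_quantity_serialization.py and test_settings.py above.
>>> from pint import Quantity
>>> from ansys.health.heart.settings.settings import Analysis, Stimulation
>>> stim = Stimulation(
...     node_ids=[1, 2],
...     t_start=Quantity(0, "ms"),
...     period=Quantity(800, "ms"),
...     duration=Quantity(20, "ms"),
...     amplitude=Quantity(50, "uF/mm^3"),
... )  # keyword-only construction, as in the updated test_dynawriter.py
>>> analysis = Analysis(end_time="500.0 millisecond")  # unit strings are validated into pint Quantities
>>> analysis.end_time.magnitude
500.0
>>> analysis.model_dump(mode="json")["end_time"]  # Quantities serialize back to unit strings in JSON mode
'500.0 millisecond'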