55from abc import abstractmethod
66from pathlib import Path
77import enum
8+ from dataclasses import field
89
910import diffrax
1011import equinox as eqx
@@ -43,9 +44,10 @@ class JAXModel(eqx.Module):
4344 Path to the JAX model file.
4445 """
4546
46- MODEL_API_VERSION = "0.0.3 "
47+ MODEL_API_VERSION = "0.0.4 "
4748 api_version : str
4849 jax_py_file : Path
50+ parameters : jnp .ndarray = field (default_factory = lambda : jnp .array ([]))
4951
5052 def __init__ (self ):
5153 if self .api_version != self .MODEL_API_VERSION :
@@ -93,11 +95,16 @@ def _w(
9395 ...
9496
9597 @abstractmethod
96- def _x0 (self , p : jt .Float [jt .Array , "np" ]) -> jt .Float [jt .Array , "nx" ]:
98+ def _x0 (
99+ self , t : jnp .float_ , p : jt .Float [jt .Array , "np" ]
100+ ) -> jt .Float [jt .Array , "nx" ]:
97101 """
98102 Compute the initial state vector.
99103
104+ :param t: initial time point
100105 :param p: parameters
106+ :return:
107+ Initial state vector.
101108 """
102109 ...
103110
@@ -264,6 +271,17 @@ def parameter_ids(self) -> list[str]:
264271 """
265272 ...
266273
274+ @property
275+ @abstractmethod
276+ def expression_ids (self ) -> list [str ]:
277+ """
278+ Get the expression ids of the model.
279+
280+ :return:
281+ Expression ids
282+ """
283+ ...
284+
267285 def _eq (
268286 self ,
269287 p : jt .Float [jt .Array , "np" ],
@@ -496,7 +514,7 @@ def _sigmays(
496514 @eqx .filter_jit
497515 def simulate_condition (
498516 self ,
499- p : jt .Float [jt .Array , "np" ],
517+ p : jt .Float [jt .Array , "np" ] | None ,
500518 ts_dyn : jt .Float [jt .Array , "nt_dyn" ],
501519 ts_posteq : jt .Float [jt .Array , "nt_posteq" ],
502520 my : jt .Float [jt .Array , "nt" ],
@@ -521,7 +539,8 @@ def simulate_condition(
521539 Simulate a condition.
522540
523541 :param p:
524- parameters for simulation ordered according to ids in :ivar parameter_ids:
542+ parameters for simulation ordered according to ids in :ivar parameter_ids:. If ``None``,
543+ the values stored in :attr:`parameters` are used.
525544 :param ts_dyn:
526545 time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are
527546 allowed to facilitate the evaluation of multiple observables at specific time points.
@@ -564,10 +583,13 @@ def simulate_condition(
564583 :return:
565584 output according to `ret` and general results/statistics
566585 """
586+ if p is None :
587+ p = self .parameters
588+
567589 if x_preeq .shape [0 ]:
568590 x = x_preeq
569591 else :
570- x = self ._x0 (p )
592+ x = self ._x0 (0.0 , p )
571593
572594 if not ts_mask .shape [0 ]:
573595 ts_mask = jnp .ones_like (my , dtype = jnp .bool_ )
@@ -675,7 +697,7 @@ def simulate_condition(
675697 @eqx .filter_jit
676698 def preequilibrate_condition (
677699 self ,
678- p : jt .Float [jt .Array , "np" ],
700+ p : jt .Float [jt .Array , "np" ] | None ,
679701 x_reinit : jt .Float [jt .Array , "*nx" ],
680702 mask_reinit : jt .Bool [jt .Array , "*nx" ],
681703 solver : diffrax .AbstractSolver ,
@@ -689,7 +711,8 @@ def preequilibrate_condition(
689711 Simulate a condition.
690712
691713 :param p:
692- parameters for simulation ordered according to ids in :ivar parameter_ids:
714+ parameters for simulation ordered according to ids in :ivar parameter_ids:. If ``None``,
715+ the values stored in :attr:`parameters` are used.
693716 :param x_reinit:
694717 re-initialized state vector. If not provided, the state vector is not re-initialized.
695718 :param mask_reinit:
@@ -704,7 +727,10 @@ def preequilibrate_condition(
704727 pre-equilibrated state variables and statistics
705728 """
706729 # Pre-equilibration
707- x0 = self ._x0 (p )
730+ if p is None :
731+ p = self .parameters
732+
733+ x0 = self ._x0 (0.0 , p )
708734 if x_reinit .shape [0 ]:
709735 x0 = jnp .where (mask_reinit , x_reinit , x0 )
710736 tcl = self ._tcl (x0 , p )
0 commit comments