@@ -523,8 +523,7 @@ def _sigmays(
523523 in_axes = (0 , 0 , None , None , 0 , 0 , 0 , 0 ),
524524 )(ts , xs , p , tcl , hs , iys , ops , nps )
525525
526- @eqx .filter_jit
527- def simulate_condition (
526+ def simulate_condition_unjitted (
528527 self ,
529528 p : jt .Float [jt .Array , "np" ] | None ,
530529 ts_dyn : jt .Float [jt .Array , "nt_dyn" ],
@@ -550,57 +549,10 @@ def simulate_condition(
550549 ts_mask : jt .Bool [jt .Array , "nt" ] = jnp .array ([]),
551550 ret : ReturnValue = ReturnValue .llh ,
552551 ) -> tuple [jt .Float [jt .Array , "nt *nx" ] | jnp .float_ , dict ]:
553- r"""
554- Simulate a condition.
555-
556- :param p:
557- parameters for simulation ordered according to ids in :ivar parameter_ids:. If ``None``,
558- the values stored in :attr:`parameters` are used.
559- :param ts_dyn:
560- time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are
561- allowed to facilitate the evaluation of multiple observables at specific time points.
562- :param ts_posteq:
563- time points for post-equilibration. Usually valued \Infty, but needs to be shaped according to
564- the number of observables that are evaluated after post-equilibration.
565- :param my:
566- observed data
567- :param iys:
568- indices of the observables according to ordering in :ivar observable_ids:
569- :param iy_trafos:
570- indices of transformations for observables
571- :param ops:
572- observable parameters
573- :param nps:
574- noise parameters
575- :param solver:
576- ODE solver
577- :param controller:
578- step size controller
579- :param adjoint:
580- adjoint method. Recommended values are `diffrax.DirectAdjoint()` for jax.jacfwd (with vector-valued
581- outputs) and `diffrax.RecursiveCheckpointAdjoint()` for jax.grad (for scalar-valued outputs).
582- :param steady_state_event:
583- event function for steady state. See :func:`diffrax.steady_state_event` for details.
584- :param max_steps:
585- maximum number of solver steps
586- :param x_preeq:
587- initial state vector for pre-equilibration. If not provided, the initial state vector is computed using
588- :meth:`_x0`.
589- :param mask_reinit:
590- mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
591- :param x_reinit:
592- re-initialized state vector. If not provided, the state vector is not re-initialized.
593- :param init_override:
594- override model input e.g. with neural net outputs. If not provided, the inputs are not overridden.
595- :param init_override_mask:
596- mask for input override. If `True`, the corresponding input is replaced with the corresponding value from `init_override`.
597- :param ts_mask:
598- mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of
599- the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2.
600- :param ret:
601- which output to return. See :class:`ReturnValue` for available options.
602- :return:
603- output according to `ret` and general results/statistics
552+ """
553+ Unjitted version of simulate_condition for type checking with beartype.
554+
555+ See :meth:`simulate_condition` for full documentation.
604556 """
605557 t0 = 0.0
606558 if p is None :
@@ -736,6 +688,112 @@ def simulate_condition(
736688
737689 return output , stats
738690
691+ @eqx .filter_jit
692+ def simulate_condition (
693+ self ,
694+ p : jt .Float [jt .Array , "np" ] | None ,
695+ ts_dyn : jt .Float [jt .Array , "nt_dyn" ],
696+ ts_posteq : jt .Float [jt .Array , "nt_posteq" ],
697+ my : jt .Float [jt .Array , "nt" ],
698+ iys : jt .Int [jt .Array , "nt" ],
699+ iy_trafos : jt .Int [jt .Array , "nt" ],
700+ ops : jt .Float [jt .Array , "nt *nop" ],
701+ nps : jt .Float [jt .Array , "nt *nnp" ],
702+ solver : diffrax .AbstractSolver ,
703+ controller : diffrax .AbstractStepSizeController ,
704+ root_finder : AbstractRootFinder ,
705+ adjoint : diffrax .AbstractAdjoint ,
706+ steady_state_event : Callable [
707+ ..., diffrax ._custom_types .BoolScalarLike
708+ ],
709+ max_steps : int | jnp .int_ ,
710+ x_preeq : jt .Float [jt .Array , "*nx" ] = jnp .array ([]),
711+ mask_reinit : jt .Bool [jt .Array , "*nx" ] = jnp .array ([]),
712+ x_reinit : jt .Float [jt .Array , "*nx" ] = jnp .array ([]),
713+ init_override : jt .Float [jt .Array , "*nx" ] = jnp .array ([]),
714+ init_override_mask : jt .Bool [jt .Array , "*nx" ] = jnp .array ([]),
715+ ts_mask : jt .Bool [jt .Array , "nt" ] = jnp .array ([]),
716+ ret : ReturnValue = ReturnValue .llh ,
717+ ) -> tuple [jt .Float [jt .Array , "nt *nx" ] | jnp .float_ , dict ]:
718+ r"""
719+ Simulate a condition (JIT-compiled version).
720+
721+ This is the JIT-compiled version for optimal performance. For runtime type checking
722+ with beartype, use :meth:`simulate_condition_unjitted` instead.
723+
724+ :param p:
725+ parameters for simulation ordered according to ids in :ivar parameter_ids:. If ``None``,
726+ the values stored in :attr:`parameters` are used.
727+ :param ts_dyn:
728+ time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are
729+ allowed to facilitate the evaluation of multiple observables at specific time points.
730+ :param ts_posteq:
731+ time points for post-equilibration. Usually valued \Infty, but needs to be shaped according to
732+ the number of observables that are evaluated after post-equilibration.
733+ :param my:
734+ observed data
735+ :param iys:
736+ indices of the observables according to ordering in :ivar observable_ids:
737+ :param iy_trafos:
738+ indices of transformations for observables
739+ :param ops:
740+ observable parameters
741+ :param nps:
742+ noise parameters
743+ :param solver:
744+ ODE solver
745+ :param controller:
746+ step size controller
747+ :param adjoint:
748+ adjoint method. Recommended values are `diffrax.DirectAdjoint()` for jax.jacfwd (with vector-valued
749+ outputs) and `diffrax.RecursiveCheckpointAdjoint()` for jax.grad (for scalar-valued outputs).
750+ :param steady_state_event:
751+ event function for steady state. See :func:`diffrax.steady_state_event` for details.
752+ :param max_steps:
753+ maximum number of solver steps
754+ :param x_preeq:
755+ initial state vector for pre-equilibration. If not provided, the initial state vector is computed using
756+ :meth:`_x0`.
757+ :param mask_reinit:
758+ mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
759+ :param x_reinit:
760+ re-initialized state vector. If not provided, the state vector is not re-initialized.
761+ :param init_override:
762+ override model input e.g. with neural net outputs. If not provided, the inputs are not overridden.
763+ :param init_override_mask:
764+ mask for input override. If `True`, the corresponding input is replaced with the corresponding value from `init_override`.
765+ :param ts_mask:
766+ mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of
767+ the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2.
768+ :param ret:
769+ which output to return. See :class:`ReturnValue` for available options.
770+ :return:
771+ output according to `ret` and general results/statistics
772+ """
773+ return self .simulate_condition_unjitted (
774+ p ,
775+ ts_dyn ,
776+ ts_posteq ,
777+ my ,
778+ iys ,
779+ iy_trafos ,
780+ ops ,
781+ nps ,
782+ solver ,
783+ controller ,
784+ root_finder ,
785+ adjoint ,
786+ steady_state_event ,
787+ max_steps ,
788+ x_preeq ,
789+ mask_reinit ,
790+ x_reinit ,
791+ init_override ,
792+ init_override_mask ,
793+ ts_mask ,
794+ ret ,
795+ )
796+
739797 @eqx .filter_jit
740798 def preequilibrate_condition (
741799 self ,
0 commit comments