|
25 | 25 | from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter, _jnp_array_str |
26 | 26 | from amici.jax.model import JAXModel |
27 | 27 | from amici.de_model import DEModel |
| 28 | + |
28 | 29 | from amici.de_export import is_valid_identifier |
29 | 30 | from amici.import_utils import ( |
30 | 31 | strip_pysb, |
@@ -142,14 +143,18 @@ def __init__( |
142 | 143 | """ |
143 | 144 | set_log_level(logger, verbose) |
144 | 145 |
|
145 | | - if any(event.updates_state for event in ode_model._events): |
| 146 | + if ode_model.has_event_assignments(): |
146 | 147 | raise NotImplementedError( |
147 | 148 | "The JAX backend does not support models with event assignments." |
148 | 149 | ) |
149 | 150 |
|
150 | | - if ode_model._algebraic_equations: |
| 151 | + if ode_model.has_algebraic_states(): |
| 152 | + raise NotImplementedError( |
| 153 | + "The JAX backend does not support models with algebraic states." |
| 154 | + ) |
| 155 | + if ode_model.has_parameter_dependent_implicit_roots(): |
151 | 156 | raise NotImplementedError( |
152 | | - "The JAX backend does not support models with algebraic equations." |
| 157 | + "The JAX backend does not support models with parameter dependent implicit event triggers." |
153 | 158 | ) |
154 | 159 |
|
155 | 160 | self.verbose: bool = logger.getEffectiveLevel() <= logging.DEBUG |
@@ -243,6 +248,9 @@ def _generate_jax_code(self) -> None: |
243 | 248 | # tuple of variable names (ids as they are unique) |
244 | 249 | **_jax_variable_ids(self.model, ("p", "k", "y", "w", "x_rdata")), |
245 | 250 | "P_VALUES": _jnp_array_str(self.model.val("p")), |
| 251 | + "ROOTS": _jnp_array_str( |
| 252 | + {root for e in self.model._events for root in e.get_trigger_times()} |
| 253 | + ), |
246 | 254 | **{ |
247 | 255 | "MODEL_NAME": self.model_name, |
248 | 256 | # keep track of the API version that the model was generated with so we |
|
0 commit comments