|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -import contextlib |
6 | 5 | import copy |
7 | 6 | import itertools |
8 | 7 | import logging |
|
47 | 46 | from .exporters.sundials.cxxcodeprinter import csc_matrix |
48 | 47 | from .importers.utils import ( |
49 | 48 | ObservableTransformation, |
50 | | - SBMLException, |
51 | 49 | _default_simplify, |
52 | 50 | amici_time_symbol, |
53 | 51 | smart_subs_dict, |
@@ -404,165 +402,6 @@ def states(self) -> list[State]: |
404 | 402 | """Get all states.""" |
405 | 403 | return self._differential_states + self._algebraic_states |
406 | 404 |
|
407 | | - def _process_sbml_rate_of(self) -> None: |
408 | | - """Substitute any SBML-rateOf constructs in the model equations""" |
409 | | - from sbmlmath import rate_of as rate_of_func |
410 | | - |
411 | | - species_sym_to_xdot = dict( |
412 | | - zip(self.sym("x"), self.sym("xdot"), strict=True) |
413 | | - ) |
414 | | - species_sym_to_idx = {x: i for i, x in enumerate(self.sym("x"))} |
415 | | - |
416 | | - def get_rate(symbol: sp.Symbol): |
417 | | - """Get rate of change of the given symbol""" |
418 | | - if symbol.find(rate_of_func): |
419 | | - raise SBMLException("Nesting rateOf() is not allowed.") |
420 | | - |
421 | | - # Replace all rateOf(some_species) by their respective xdot equation |
422 | | - with contextlib.suppress(KeyError): |
423 | | - return self._eqs["xdot"][species_sym_to_idx[symbol]] |
424 | | - |
425 | | - # For anything other than a state, rateOf(.) is 0 or invalid |
426 | | - return 0 |
427 | | - |
428 | | - # replace rateOf-instances in xdot by xdot symbols |
429 | | - made_substitutions = False |
430 | | - for i_state in range(len(self.eq("xdot"))): |
431 | | - if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func): |
432 | | - self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs( |
433 | | - { |
434 | | - # either the rateOf argument is a state, or it's 0 |
435 | | - rate_of: species_sym_to_xdot.get(rate_of.args[0], 0) |
436 | | - for rate_of in rate_ofs |
437 | | - } |
438 | | - ) |
439 | | - made_substitutions = True |
440 | | - |
441 | | - if made_substitutions: |
442 | | - # substitute in topological order |
443 | | - subs = toposort_symbols( |
444 | | - dict(zip(self.sym("xdot"), self.eq("xdot"), strict=True)) |
445 | | - ) |
446 | | - self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs) |
447 | | - |
448 | | - # replace rateOf-instances in w by xdot equation |
449 | | - # here we may need toposort, as xdot may depend on w |
450 | | - made_substitutions = False |
451 | | - for i_expr in range(len(self.eq("w"))): |
452 | | - new, replacement = self._eqs["w"][i_expr].replace( |
453 | | - rate_of_func, get_rate, map=True |
454 | | - ) |
455 | | - if replacement: |
456 | | - self._eqs["w"][i_expr] = new |
457 | | - made_substitutions = True |
458 | | - |
459 | | - if made_substitutions: |
460 | | - # Sort expressions in self._expressions, w symbols, and w equations |
461 | | - # in topological order. Ideally, this would already happen before |
462 | | - # adding the expressions to the model, but at that point, we don't |
463 | | - # have access to xdot yet. |
464 | | - # NOTE: elsewhere, conservations law expressions are expected to |
465 | | - # occur before any other w expressions, so we must maintain their |
466 | | - # position |
467 | | - # toposort everything but conservation law expressions, |
468 | | - # then prepend conservation laws |
469 | | - w_sorted = toposort_symbols( |
470 | | - dict( |
471 | | - zip( |
472 | | - self.sym("w")[self.num_cons_law() :, :], |
473 | | - self.eq("w")[self.num_cons_law() :, :], |
474 | | - strict=True, |
475 | | - ) |
476 | | - ) |
477 | | - ) |
478 | | - w_sorted = ( |
479 | | - dict( |
480 | | - zip( |
481 | | - self.sym("w")[: self.num_cons_law(), :], |
482 | | - self.eq("w")[: self.num_cons_law(), :], |
483 | | - strict=True, |
484 | | - ) |
485 | | - ) |
486 | | - | w_sorted |
487 | | - ) |
488 | | - old_syms = tuple(self._syms["w"]) |
489 | | - topo_expr_syms = tuple(w_sorted.keys()) |
490 | | - new_order = [old_syms.index(s) for s in topo_expr_syms] |
491 | | - self._expressions = [self._expressions[i] for i in new_order] |
492 | | - self._syms["w"] = sp.Matrix(topo_expr_syms) |
493 | | - self._eqs["w"] = sp.Matrix(list(w_sorted.values())) |
494 | | - |
495 | | - # replace rateOf-instances in x0 by xdot equation |
496 | | - # indices of state variables whose x0 was modified |
497 | | - changed_indices = [] |
498 | | - for i_state in range(len(self.eq("x0"))): |
499 | | - new, replacement = self._eqs["x0"][i_state].replace( |
500 | | - rate_of_func, get_rate, map=True |
501 | | - ) |
502 | | - if replacement: |
503 | | - self._eqs["x0"][i_state] = new |
504 | | - changed_indices.append(i_state) |
505 | | - if changed_indices: |
506 | | - # Replace any newly introduced state variables |
507 | | - # by their x0 expressions. |
508 | | - # Also replace any newly introduced `w` symbols by their |
509 | | - # expressions (after `w` was toposorted above). |
510 | | - subs = toposort_symbols( |
511 | | - dict(zip(self.sym("x_rdata"), self.eq("x0"), strict=True)) |
512 | | - ) |
513 | | - subs = dict(zip(self._syms["w"], self.eq("w"), strict=True)) | subs |
514 | | - for i_state in changed_indices: |
515 | | - self._eqs["x0"][i_state] = smart_subs_dict( |
516 | | - self._eqs["x0"][i_state], subs |
517 | | - ) |
518 | | - |
519 | | - for component in chain( |
520 | | - self.observables(), |
521 | | - self.events(), |
522 | | - self._algebraic_equations, |
523 | | - ): |
524 | | - if rate_ofs := component.get_val().find(rate_of_func): |
525 | | - if isinstance(component, Event): |
526 | | - # TODO froot(...) can currently not depend on `w`, so this substitution fails for non-zero rates |
527 | | - # see, e.g., sbml test case 01293 |
528 | | - raise SBMLException( |
529 | | - "AMICI does currently not support rateOf(.) inside event trigger functions." |
530 | | - ) |
531 | | - |
532 | | - if isinstance(component, AlgebraicEquation): |
533 | | - # TODO IDACalcIC fails with |
534 | | - # "The linesearch algorithm failed: step too small or too many backtracks." |
535 | | - # see, e.g., sbml test case 01482 |
536 | | - raise SBMLException( |
537 | | - "AMICI does currently not support rateOf(.) inside AlgebraicRules." |
538 | | - ) |
539 | | - |
540 | | - component.set_val( |
541 | | - component.get_val().subs( |
542 | | - { |
543 | | - rate_of: get_rate(rate_of.args[0]) |
544 | | - for rate_of in rate_ofs |
545 | | - } |
546 | | - ) |
547 | | - ) |
548 | | - |
549 | | - for event in self.events(): |
550 | | - state_update = event.get_state_update( |
551 | | - x=self.sym("x"), x_old=self.sym("x") |
552 | | - ) |
553 | | - if state_update is None: |
554 | | - continue |
555 | | - |
556 | | - for i_state in range(len(state_update)): |
557 | | - if rate_ofs := state_update[i_state].find(rate_of_func): |
558 | | - raise SBMLException( |
559 | | - "AMICI does currently not support rateOf(.) inside event state updates." |
560 | | - ) |
561 | | - # TODO here we need xdot sym, not eqs |
562 | | - # event._state_update[i_state] = event._state_update[i_state].subs( |
563 | | - # {rate_of: get_rate(rate_of.args[0]) for rate_of in rate_ofs} |
564 | | - # ) |
565 | | - |
566 | 405 | def add_component( |
567 | 406 | self, component: ModelQuantity, insert_first: bool | None = False |
568 | 407 | ) -> None: |
@@ -2752,3 +2591,37 @@ def has_event_assignments(self) -> bool: |
2752 | 2591 | boolean indicating if event assignments are present |
2753 | 2592 | """ |
2754 | 2593 | return any(event.updates_state for event in self._events) |
| 2594 | + |
| 2595 | + def toposort_expressions(self) -> dict[sp.Symbol, sp.Expr]: |
| 2596 | + """ |
| 2597 | + Sort expressions in topological order. |
| 2598 | +
|
| 2599 | + :return: |
| 2600 | + dict of expression symbols to expressions in topological order |
| 2601 | + """ |
| 2602 | + # NOTE: elsewhere, conservations law expressions are expected to |
| 2603 | + # occur before any other w expressions, so we must maintain their |
| 2604 | + # position. |
| 2605 | + # toposort everything but conservation law expressions, |
| 2606 | + # then prepend conservation laws |
| 2607 | + if self._syms or self._eqs: |
| 2608 | + raise AssertionError( |
| 2609 | + "This function must be called before generating any symbols " |
| 2610 | + "or equations." |
| 2611 | + ) |
| 2612 | + w_toposorted = toposort_symbols( |
| 2613 | + { |
| 2614 | + e.get_sym(): e.get_val() |
| 2615 | + for e in self.expressions()[self.num_cons_law() :] |
| 2616 | + } |
| 2617 | + ) |
| 2618 | + |
| 2619 | + w_toposorted = { |
| 2620 | + e.get_sym(): e.get_val() |
| 2621 | + for e in self.expressions()[: self.num_cons_law()] |
| 2622 | + } | w_toposorted |
| 2623 | + old_syms = [e.get_sym() for e in self.expressions()] |
| 2624 | + self._expressions = [ |
| 2625 | + self._expressions[old_syms.index(s)] for s in w_toposorted |
| 2626 | + ] |
| 2627 | + return w_toposorted |
0 commit comments