|
1 | 1 | using ModelingToolkit, JumpProcesses, LinearAlgebra, NonlinearSolve, Optimization, |
2 | 2 | OptimizationOptimJL, OrdinaryDiffEq, RecursiveArrayTools, SciMLBase, |
3 | | - SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface, Test |
| 3 | + SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface, |
| 4 | + DiffEqCallbacks, Test |
4 | 5 | using ModelingToolkit: t_nounits as t, D_nounits as D |
5 | 6 |
|
6 | 7 | # Sets rnd number. |
@@ -528,3 +529,338 @@ end |
528 | 529 | @test_throws ErrorException sol(1.0, Val{1}, idxs = [w, w]) |
529 | 530 | @test_throws ErrorException sol(1.0, Val{1}, idxs = [w, y]) |
530 | 531 | end |
| 532 | + |
| 533 | +@testset "Discrete save indexing" begin |
| 534 | + struct NumSymbolCache{S} |
| 535 | + sc::S |
| 536 | + end |
| 537 | + SymbolicIndexingInterface.symbolic_container(s::NumSymbolCache) = s.sc |
| 538 | + function SymbolicIndexingInterface.is_observed(s::NumSymbolCache, x) |
| 539 | + return symbolic_type(x) != NotSymbolic() && !is_variable(s, x) && |
| 540 | + !is_parameter(s, x) && !is_independent_variable(s, x) |
| 541 | + end |
| 542 | + function SymbolicIndexingInterface.observed(s::NumSymbolCache, x) |
| 543 | + res = ModelingToolkit.build_function(x, |
| 544 | + sort(variable_symbols(s); by = Base.Fix1(variable_index, s)), |
| 545 | + sort(parameter_symbols(s), by = Base.Fix1(parameter_index, s)), |
| 546 | + independent_variable_symbols(s)[]; expression = Val(false)) |
| 547 | + if res isa Tuple |
| 548 | + return let oopfn = res[1], iipfn = res[2] |
| 549 | + fn(out, u, p, t) = iipfn(out, u, p, t) |
| 550 | + fn(u, p, t) = oopfn(u, p, t) |
| 551 | + fn |
| 552 | + end |
| 553 | + else |
| 554 | + return res |
| 555 | + end |
| 556 | + end |
| 557 | + function SymbolicIndexingInterface.parameter_observed(s::NumSymbolCache, x) |
| 558 | + res = ModelingToolkit.build_function(x, |
| 559 | + sort(parameter_symbols(s), by = Base.Fix1(parameter_index, s)), |
| 560 | + independent_variable_symbols(s)[]; expression = Val(false)) |
| 561 | + if res isa Tuple |
| 562 | + return let oopfn = res[1], iipfn = res[2] |
| 563 | + fn(out, p, t) = iipfn(out, p, t) |
| 564 | + fn(p, t) = oopfn(p, t) |
| 565 | + fn |
| 566 | + end |
| 567 | + else |
| 568 | + return res |
| 569 | + end |
| 570 | + end |
| 571 | + function SymbolicIndexingInterface.get_all_timeseries_indexes(s::NumSymbolCache, x) |
| 572 | + if symbolic_type(x) == NotSymbolic() |
| 573 | + x = ModelingToolkit.unwrap.(x) |
| 574 | + else |
| 575 | + x = ModelingToolkit.unwrap(x) |
| 576 | + end |
| 577 | + vars = ModelingToolkit.vars(x) |
| 578 | + return mapreduce(union, vars; init = Set()) do sym |
| 579 | + if is_variable(s, sym) |
| 580 | + Set([ContinuousTimeseries()]) |
| 581 | + elseif is_parameter(s, sym) && is_timeseries_parameter(s, sym) |
| 582 | + Set([timeseries_parameter_index(s, sym).timeseries_idx]) |
| 583 | + else |
| 584 | + Set() |
| 585 | + end |
| 586 | + end |
| 587 | + end |
| 588 | + function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( |
| 589 | + ::NumSymbolCache, p::Vector{Float64}, args...) |
| 590 | + for (idx, buf) in args |
| 591 | + if idx == 1 |
| 592 | + p[1:2] .= buf |
| 593 | + else |
| 594 | + p[3:4] .= buf |
| 595 | + end |
| 596 | + end |
| 597 | + |
| 598 | + return p |
| 599 | + end |
| 600 | + function SciMLBase.create_parameter_timeseries_collection(s::NumSymbolCache, ps, tspan) |
| 601 | + trem = rem(tspan[1], 0.1, RoundDown) |
| 602 | + if trem > 0 |
| 603 | + trem = 0.1 - trem |
| 604 | + end |
| 605 | + dea1 = DiffEqArray(Vector{Float64}[], (tspan[1] + trem):0.1:tspan[2]) |
| 606 | + dea2 = DiffEqArray(Vector{Float64}[], Float64[]) |
| 607 | + return ParameterTimeseriesCollection((dea1, dea2), deepcopy(ps)) |
| 608 | + end |
| 609 | + function SciMLBase.get_saveable_values(::NumSymbolCache, p::Vector{Float64}, tsidx) |
| 610 | + if tsidx == 1 |
| 611 | + return p[1:2] |
| 612 | + else |
| 613 | + return p[3:4] |
| 614 | + end |
| 615 | + end |
| 616 | + |
| 617 | + @variables x(t) ud1(t) ud2(t) xd1(t) xd2(t) |
| 618 | + @parameters kp |
| 619 | + sc = SymbolCache([x], |
| 620 | + Dict(ud1 => 1, xd1 => 2, ud2 => 3, xd2 => 4, kp => 5), |
| 621 | + t; |
| 622 | + timeseries_parameters = Dict( |
| 623 | + ud1 => ParameterTimeseriesIndex(1, 1), xd1 => ParameterTimeseriesIndex(1, 2), |
| 624 | + ud2 => ParameterTimeseriesIndex(2, 1), xd2 => ParameterTimeseriesIndex(2, 2))) |
| 625 | + sys = NumSymbolCache(sc) |
| 626 | + |
| 627 | + function f!(du, u, p, t) |
| 628 | + du .= u .* t .+ p[5] * sum(u) |
| 629 | + end |
| 630 | + fn = ODEFunction(f!; sys = sys) |
| 631 | + prob = ODEProblem(fn, [1.0], (0.0, 1.0), [1.0, 2.0, 3.0, 4.0, 5.0]) |
| 632 | + cb1 = PeriodicCallback(0.1; initial_affect = true, final_affect = true, |
| 633 | + save_positions = (false, false)) do integ |
| 634 | + integ.p[1:2] .+= exp(-integ.t) |
| 635 | + SciMLBase.save_discretes!(integ, 1) |
| 636 | + end |
| 637 | + function affect2!(integ) |
| 638 | + integ.p[3:4] .+= only(integ.u) |
| 639 | + SciMLBase.save_discretes!(integ, 2) |
| 640 | + end |
| 641 | + cb2 = DiscreteCallback((args...) -> true, affect2!, save_positions = (false, false), |
| 642 | + initialize = (c, u, t, integ) -> affect2!(integ)) |
| 643 | + sol = solve(deepcopy(prob), Tsit5(); callback = CallbackSet(cb1, cb2)) |
| 644 | + |
| 645 | + ud1val = getindex.(sol.discretes.collection[1].u, 1) |
| 646 | + xd1val = getindex.(sol.discretes.collection[1].u, 2) |
| 647 | + ud2val = getindex.(sol.discretes.collection[2].u, 1) |
| 648 | + xd2val = getindex.(sol.discretes.collection[2].u, 2) |
| 649 | + |
| 650 | + for (sym, timeseries_index, val, buffer, isobs, check_inference) in [(ud1, |
| 651 | + 1, |
| 652 | + ud1val, |
| 653 | + zeros(length(ud1val)), |
| 654 | + false, |
| 655 | + true) |
| 656 | + ([ud1, xd1], |
| 657 | + 1, |
| 658 | + vcat.(ud1val, |
| 659 | + xd1val), |
| 660 | + map( |
| 661 | + _ -> zeros(2), |
| 662 | + ud1val), |
| 663 | + false, |
| 664 | + true) |
| 665 | + ((ud2, xd2), |
| 666 | + 2, |
| 667 | + tuple.(ud2val, |
| 668 | + xd2val), |
| 669 | + map( |
| 670 | + _ -> zeros(2), |
| 671 | + ud2val), |
| 672 | + false, |
| 673 | + true) |
| 674 | + (ud2 + xd2, |
| 675 | + 2, |
| 676 | + ud2val .+ |
| 677 | + xd2val, |
| 678 | + zeros(length(ud2val)), |
| 679 | + true, |
| 680 | + true) |
| 681 | + ( |
| 682 | + [ud2 + xd2, |
| 683 | + ud2 * xd2], |
| 684 | + 2, |
| 685 | + vcat.( |
| 686 | + ud2val .+ |
| 687 | + xd2val, |
| 688 | + ud2val .* |
| 689 | + xd2val), |
| 690 | + map( |
| 691 | + _ -> zeros(2), |
| 692 | + ud2val), |
| 693 | + true, |
| 694 | + true) |
| 695 | + ( |
| 696 | + (ud1 + xd1, |
| 697 | + ud1 * xd1), |
| 698 | + 1, |
| 699 | + tuple.( |
| 700 | + ud1val .+ |
| 701 | + xd1val, |
| 702 | + ud1val .* |
| 703 | + xd1val), |
| 704 | + map( |
| 705 | + _ -> zeros(2), |
| 706 | + ud1val), |
| 707 | + true, |
| 708 | + true)] |
| 709 | + getter = getp(sys, sym) |
| 710 | + if check_inference |
| 711 | + @inferred getter(sol) |
| 712 | + @inferred getter(deepcopy(buffer), sol) |
| 713 | + if !isobs |
| 714 | + @inferred getter(parameter_values(sol)) |
| 715 | + if !(eltype(val) <: Number) |
| 716 | + @inferred getter(deepcopy(buffer[1]), parameter_values(sol)) |
| 717 | + end |
| 718 | + end |
| 719 | + end |
| 720 | + |
| 721 | + @test getter(sol) == val |
| 722 | + if eltype(val) <: Number |
| 723 | + target = val |
| 724 | + else |
| 725 | + target = collect.(val) |
| 726 | + end |
| 727 | + tmp = deepcopy(buffer) |
| 728 | + getter(tmp, sol) |
| 729 | + @test tmp == target |
| 730 | + |
| 731 | + if !isobs |
| 732 | + @test getter(parameter_values(sol)) == val[end] |
| 733 | + if !(eltype(val) <: Number) |
| 734 | + target = collect(val[end]) |
| 735 | + tmp = deepcopy(buffer)[end] |
| 736 | + getter(tmp, parameter_values(sol)) |
| 737 | + @test tmp == target |
| 738 | + end |
| 739 | + end |
| 740 | + |
| 741 | + for subidx in [ |
| 742 | + 1, CartesianIndex(2), :, rand(Bool, length(val)), rand(eachindex(val), 4), 2:5] |
| 743 | + if check_inference |
| 744 | + @inferred getter(sol, subidx) |
| 745 | + if !isa(val[subidx], Number) |
| 746 | + @inferred getter(deepcopy(buffer[subidx]), sol, subidx) |
| 747 | + end |
| 748 | + end |
| 749 | + @test getter(sol, subidx) == val[subidx] |
| 750 | + tmp = deepcopy(buffer[subidx]) |
| 751 | + if val[subidx] isa Number |
| 752 | + continue |
| 753 | + end |
| 754 | + target = val[subidx] |
| 755 | + if eltype(target) <: Number |
| 756 | + target = collect(target) |
| 757 | + else |
| 758 | + target = collect.(target) |
| 759 | + end |
| 760 | + getter(tmp, sol, subidx) |
| 761 | + @test tmp == target |
| 762 | + end |
| 763 | + end |
| 764 | + |
| 765 | + for sym in [ |
| 766 | + [ud1, xd1, ud2], |
| 767 | + (ud2, xd1, xd2), |
| 768 | + ud1 + ud2, |
| 769 | + [ud1 + ud2, ud1 * xd1], |
| 770 | + (ud1 + ud2, ud1 * xd1)] |
| 771 | + getter = getp(sys, sym) |
| 772 | + @test_throws Exception getter(sol) |
| 773 | + @test_throws Exception getter([], sol) |
| 774 | + for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2] |
| 775 | + @test_throws Exception getter(sol, subidx) |
| 776 | + @test_throws Exception getter([], sol, subidx) |
| 777 | + end |
| 778 | + end |
| 779 | + |
| 780 | + kpval = sol.prob.p[5] |
| 781 | + xval = getindex.(sol.u) |
| 782 | + |
| 783 | + for (sym, val_is_timeseries, val, check_inference) in [ |
| 784 | + (kp, false, kpval, true), |
| 785 | + ([kp, kp], false, [kpval, kpval], true), |
| 786 | + ((kp, kp), false, (kpval, kpval), true), |
| 787 | + (ud2, true, ud2val, true), |
| 788 | + ([ud2, kp], true, vcat.(ud2val, kpval), false), |
| 789 | + ((ud1, kp), true, tuple.(ud1val, kpval), false), |
| 790 | + ([kp, x], true, vcat.(kpval, xval), false), |
| 791 | + ((kp, x), true, tuple.(kpval, xval), false), |
| 792 | + (2ud2, true, 2 .* ud2val, true), |
| 793 | + ([kp, 2ud1], true, vcat.(kpval, 2 .* ud1val), false), |
| 794 | + ((kp, 2ud1), true, tuple.(kpval, 2 .* ud1val), false) |
| 795 | + ] |
| 796 | + getter = getu(sys, sym) |
| 797 | + if check_inference |
| 798 | + @inferred getter(sol) |
| 799 | + end |
| 800 | + @test getter(sol) == val |
| 801 | + reference = val_is_timeseries ? val : xval |
| 802 | + for subidx in [ |
| 803 | + 1, CartesianIndex(2), :, rand(Bool, length(reference)), |
| 804 | + rand(eachindex(reference), 4), 2:6 |
| 805 | + ] |
| 806 | + if check_inference |
| 807 | + @inferred getter(sol, subidx) |
| 808 | + end |
| 809 | + target = if val_is_timeseries |
| 810 | + val[subidx] |
| 811 | + else |
| 812 | + val |
| 813 | + end |
| 814 | + @test getter(sol, subidx) == target |
| 815 | + end |
| 816 | + end |
| 817 | + |
| 818 | + _xval = xval[1] |
| 819 | + _ud1val = ud1val[1] |
| 820 | + _ud2val = ud2val[1] |
| 821 | + _xd1val = xd1val[1] |
| 822 | + _xd2val = xd2val[1] |
| 823 | + integ = init(prob, Tsit5(); callback = CallbackSet(cb1, cb2)) |
| 824 | + for (sym, val, check_inference) in [ |
| 825 | + ([x, ud1], [_xval, _ud1val], false), |
| 826 | + ((x, ud1), (_xval, _ud1val), true), |
| 827 | + (x + ud2, _xval + _ud2val, true), |
| 828 | + ([2x, 3xd1], [2_xval, 3_xd1val], true), |
| 829 | + ((2x, 3xd2), (2_xval, 3_xd2val), true) |
| 830 | + ] |
| 831 | + getter = getu(sys, sym) |
| 832 | + @test_throws Exception getter(sol) |
| 833 | + for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2] |
| 834 | + @test_throws Exception getter(sol, subidx) |
| 835 | + end |
| 836 | + |
| 837 | + if check_inference |
| 838 | + @inferred getter(integ) |
| 839 | + end |
| 840 | + @test getter(integ) == val |
| 841 | + end |
| 842 | + |
| 843 | + xinterp = sol(0.1:0.1:0.3, idxs = x) |
| 844 | + xinterp2 = sol(sol.discretes.collection[2].t[2:4], idxs = x) |
| 845 | + ud1interp = ud1val[2:4] |
| 846 | + ud2interp = ud2val[2:4] |
| 847 | + |
| 848 | + c1 = SciMLBase.Clock(0.1) |
| 849 | + c2 = SciMLBase.SolverStepClock |
| 850 | + for (sym, t, val) in [ |
| 851 | + (x, c1[2], xinterp[1]), |
| 852 | + (x, c1[2:4], xinterp), |
| 853 | + ([x, ud1], c1[2], [xinterp[1], ud1interp[1]]), |
| 854 | + ([x, ud1], c1[2:4], vcat.(xinterp, ud1interp)), |
| 855 | + (x, c2[2], xinterp2[1]), |
| 856 | + (x, c2[2:4], xinterp2), |
| 857 | + ([x, ud2], c2[2], [xinterp2[1], ud2interp[1]]), |
| 858 | + ([x, ud2], c2[2:4], vcat.(xinterp2, ud2interp)) |
| 859 | + ] |
| 860 | + res = sol(t, idxs = sym) |
| 861 | + if res isa DiffEqArray |
| 862 | + res = res.u |
| 863 | + end |
| 864 | + @test res == val |
| 865 | + end |
| 866 | +end |
0 commit comments