Skip to content

Commit ecb8f25

Browse files
Merge pull request #782 from AayushSabharwal/as/fix-remake-buffer
fix: update to new `remake_buffer` signature
2 parents a121f72 + 0326659 commit ecb8f25

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ StableRNGs = "1.0"
8989
StaticArrays = "1.7"
9090
StaticArraysCore = "1.4"
9191
Statistics = "1.10"
92-
SymbolicIndexingInterface = "0.3.26"
92+
SymbolicIndexingInterface = "0.3.30"
9393
Tables = "1.11"
9494
Zygote = "0.6.67"
9595
julia = "1.10"

src/remake.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -617,56 +617,56 @@ end
617617

618618
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})
619619
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
620-
isdep || return remake_buffer(prob, state_values(prob), u0), p
620+
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
621621

622622
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
623623
for (k, v) in u0)
624624

625625
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
626-
isdep || return remake_buffer(prob, state_values(prob), u0), p
626+
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
627627

628628
# FIXME: need to provide `u` since the observed function expects it.
629629
# This is sort of an implicit dependency on MTK. The values of `u` won't actually be
630630
# used, since any state symbols in the expression were substituted out earlier.
631631
temp_state = ProblemState(; u = state_values(prob), p = p)
632632
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
633633
for (k, v) in u0)
634-
return remake_buffer(prob, state_values(prob), u0), p
634+
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
635635
end
636636

637637
function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true})
638638
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
639-
isdep || return u0, remake_buffer(prob, parameter_values(prob), p)
639+
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
640640

641641
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
642642
for (k, v) in p)
643643

644644
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
645-
isdep || return u0, remake_buffer(prob, parameter_values(prob), p)
645+
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
646646

647647
# FIXME: need to provide `p` since the observed function expects an `MTKParameters`
648648
# this is sort of an implicit dependency on MTK. The values of `p` won't actually be
649649
# used, since any parameter symbols in the expression were substituted out earlier.
650650
temp_state = ProblemState(; u = u0, p = parameter_values(prob))
651651
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
652652
for (k, v) in p)
653-
return u0, remake_buffer(prob, parameter_values(prob), p)
653+
return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
654654
end
655655

656656
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
657657
isu0dep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
658658
ispdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
659659

660660
if !isu0dep && !ispdep
661-
return remake_buffer(prob, state_values(prob), u0),
662-
remake_buffer(prob, parameter_values(prob), p)
661+
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)),
662+
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
663663
end
664664
if !isu0dep
665-
u0 = remake_buffer(prob, state_values(prob), u0)
665+
u0 = remake_buffer(prob, state_values(prob), keys(u0), values(u0))
666666
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true))
667667
end
668668
if !ispdep
669-
p = remake_buffer(prob, parameter_values(prob), p)
669+
p = remake_buffer(prob, parameter_values(prob), keys(p), values(p))
670670
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false))
671671
end
672672

@@ -675,8 +675,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
675675
for (k, v) in u0)
676676
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
677677
for (k, v) in p)
678-
return remake_buffer(prob, state_values(prob), u0),
679-
remake_buffer(prob, parameter_values(prob), p)
678+
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)),
679+
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
680680
end
681681

682682
function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults = false)

0 commit comments

Comments
 (0)