Skip to content

Commit cc4ab73

Browse files
feat: support substituting CallWithMetadata in expressions
1 parent 85a06f9 commit cc4ab73

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

src/num.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,18 @@ end
7878
substitute(expr, s::Pair; kw...) = substituter([s[1] => s[2]])(expr; kw...)
7979
substitute(expr, s::Vector; kw...) = substituter(s)(expr; kw...)
8080

81-
substituter(pair::Pair) = substituter((pair,))
81+
function _unwrap_callwithmeta(x)
82+
x = value(x)
83+
return x isa CallWithMetadata ? x.f : x
84+
end
85+
function subrules_to_dict(pairs)
86+
if pairs isa Pair
87+
pairs = (pairs,)
88+
end
89+
return Dict(_unwrap_callwithmeta(k) => value(v) for (k, v) in pairs)
90+
end
8291
function substituter(pairs)
83-
dict = Dict(value(k) => value(v) for (k, v) in pairs)
92+
dict = subrules_to_dict(pairs)
8493
(expr; kw...) -> SymbolicUtils.substitute(value(expr), dict; kw...)
8594
end
8695

src/variable.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ infinite loops in cases where the substitutions in `dict` are circular
526526
See also: [`fast_substitute`](@ref).
527527
"""
528528
function fixpoint_sub(x, dict; operator = Nothing, maxiters = 10000)
529+
dict = subrules_to_dict(dict)
529530
y = fast_substitute(x, dict; operator)
530531
while !isequal(x, y) && maxiters > 0
531532
y = x

test/utils.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Symbolics
2-
using Symbolics: symbolic_to_float, var_from_nested_derivative
2+
using Symbolics: symbolic_to_float, var_from_nested_derivative, unwrap
33

44
@testset "get_variables" begin
55
@variables t x y z(t)
@@ -46,3 +46,12 @@ end
4646
expr = Symbolics.fixpoint_sub(x, Dict(x => y, y => x); maxiters = 9)
4747
@test isequal(expr, y)
4848
end
49+
50+
@testset "Issue#1342 substitute working on called symbolics" begin
51+
@variables p(..) x y
52+
arg = unwrap(substitute(p(x), [p => identity]))
53+
@test iscall(arg) && operation(arg) == identity && isequal(only(arguments(arg)), x)
54+
@test unwrap(substitute(p(x), [p => sqrt, x => 4.0])) 2.0
55+
arg = Symbolics.fixpoint_sub(p(x), [p => sqrt, x => 2y + 3, y => 1.0 + p(4)])
56+
@test arg 3.0
57+
end

0 commit comments

Comments
 (0)