Skip to content

Commit a9a73af

Browse files
Merge pull request #1480 from AayushSabharwal/as/fix-mutation
fix: remove mutation of `BasicSymbolic`
2 parents d4a7b38 + cefcf72 commit a9a73af

File tree

5 files changed

+48
-39
lines changed

5 files changed

+48
-39
lines changed

src/latexify_recipes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ function _toexpr(O)
216216
while num isa Term && num.f isa Differential
217217
deg += 1
218218
den *= num.f.x
219-
num = num.arguments[1]
219+
num = first(arguments(num))
220220
end
221221
return :(_derivative($(_toexpr(num)), $den, $deg))
222222
elseif op isa Integral

src/solver/ia_main.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
3333
for i in eachindex(lhs_roots)
3434
for j in eachindex(rhs)
3535
if iscall(lhs_roots[i]) && operation(lhs_roots[i]) == RootsOf
36-
lhs_roots[i].arguments[1] = substitute(lhs_roots[i].arguments[1], Dict(new_var=>rhs[j]), fold=false)
36+
_args = copy(parent(arguments(lhs_roots[i])))
37+
_args[1] = substitute(_args[1], Dict(new_var => rhs[j]), fold = false)
38+
T = typeof(lhs_roots[i])
39+
_op = operation(lhs_roots[i])
40+
_meta = metadata(lhs_roots[i])
41+
lhs_roots[i] = maketerm(T, _op, _args, _meta)
3742
push!(roots, lhs_roots[i])
3843
else
3944
push!(roots, substitute(lhs_roots[i], Dict(new_var=>rhs[j]), fold=false))
@@ -86,8 +91,9 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
8691
end
8792

8893
elseif oper === (^)
89-
if any(isequal(x, var) for x in get_variables(args[1])) &&
90-
n_occurrences(args[2], var) == 0 && args[2] isa Integer
94+
var_in_base = any(isequal(x, var) for x in get_variables(args[1]))
95+
var_in_pow = n_occurrences(args[2], var) != 0
96+
if var_in_base && !var_in_pow && args[2] isa Integer
9197
lhs = args[1]
9298
power = args[2]
9399
new_roots = []
@@ -111,11 +117,10 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
111117
end
112118
rhs = []
113119
append!(rhs, new_roots)
114-
elseif any(isequal(x, var) for x in get_variables(args[1])) &&
115-
n_occurrences(args[2], var) == 0
120+
elseif var_in_base && !var_in_pow
116121
lhs = args[1]
117-
s, args[2] = filter_stuff(args[2])
118-
rhs = map(sol -> term(^, sol, 1 // args[2]), rhs)
122+
s, power = filter_stuff(args[2])
123+
rhs = map(sol -> term(^, sol, 1 // power), rhs)
119124
else
120125
lhs = args[2]
121126
rhs = map(sol -> term(/, term(slog, sol), term(slog, args[1])), rhs)

src/solver/polynomialization.jl

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function turn_to_poly(expr, var)
4343
expr = unwrap(expr)
4444
!iscall(expr) && return (expr, Dict())
4545

46-
args = arguments(expr)
46+
args = copy(parent(arguments(expr)))
4747

4848
sub = 0
4949
broken = Ref(false)
@@ -53,12 +53,12 @@ function turn_to_poly(expr, var)
5353
arg_oper = operation(arg)
5454

5555
if arg_oper === (^)
56-
tp = trav_pow(args, i, var, broken, sub)
56+
args[i], tp = trav_pow(args[i], var, broken, sub)
5757
sub = isequal(tp, false) ? sub : tp
5858
continue
5959
end
6060
if arg_oper === (*)
61-
sub = trav_mult(arg, var, broken, sub)
61+
args[i], sub = trav_mult(arg, var, broken, sub)
6262
continue
6363
end
6464
isequal(add_sub(sub, arg, var, broken), false) && continue
@@ -77,16 +77,17 @@ function turn_to_poly(expr, var)
7777

7878
new_var = gensym()
7979
new_var = (@variables $new_var)[1]
80+
expr = maketerm(typeof(expr), operation(expr), args, metadata(expr))
8081
return ssubs(expr, Dict(sub => new_var)), Dict{Any, Any}(new_var => sub)
8182
end
8283

8384
"""
84-
trav_pow(args, index, var, broken, sub)
85+
trav_pow(arg, var, broken, sub)
8586
86-
Traverses an argument passed from ``turn_to_poly`` if it
87-
satisfies ``oper === (^)``. Returns sub if changed from 0
88-
to a new transcendental function or its value is
89-
kept the same, and false if these 2 cases do not occur.
87+
Traverses an argument `arg` passed from ``turn_to_poly`` if it satisfies
88+
``oper === (^)``. Returns the new `arg` and `sub` if `sub` is changed from 0 to a new
89+
transcendental function or its value is kept the same, or else `false` if these 2 cases
90+
do not occur.
9091
9192
# Arguments
9293
- args: The original arguments array of the expression passed to ``turn_to_poly``
@@ -97,20 +98,20 @@ kept the same, and false if these 2 cases do not occur.
9798
9899
# Examples
99100
```jldoctest
100-
julia> trav_pow([unwrap(9^x)], 1, x, Ref(false), 3^x)
101-
3^x
101+
julia> trav_pow(unwrap(9^x), x, Ref(false), 3^x)
102+
(9^x, 3^x)
102103
103-
julia> trav_pow([unwrap(x^2)], 1, x, Ref(false), 3^x)
104-
false
104+
julia> trav_pow(unwrap(x^2), x, Ref(false), 3^x)
105+
(x^2, false)
105106
```
106107
"""
107-
function trav_pow(args, index, var, broken, sub)
108-
args_arg = arguments(args[index])
108+
function trav_pow(arg, var, broken, sub)
109+
args_arg = arguments(arg)
109110
base = args_arg[1]
110111
power = args_arg[2]
111112

112113
# case 1: log(x)^2 .... 9^x = 3^2^x = 3^2x = (3^x)^2
113-
!isequal(add_sub(sub, base, var, broken), false) && power isa Integer && return base
114+
!isequal(add_sub(sub, base, var, broken), false) && power isa Integer && return arg, base
114115

115116
# case 2: int^f(x)
116117
# n_func_occ may not be strictly 1, we could attempt attracting it after solving
@@ -122,21 +123,20 @@ function trav_pow(args, index, var, broken, sub)
122123
sub = isequal(sub, 0) ? new_b : sub
123124
if !isequal(sub, new_b)
124125
broken[] = true
125-
return false
126+
return arg, false
126127
end
127128
new_b = term(^, new_b, p)
128-
args[index] = new_b
129-
return sub
129+
return new_b, sub
130130
end
131131

132-
return false
132+
return arg, false
133133
end
134134

135135
"""
136136
trav_mult(arg, var, broken, sub)
137137
138138
Traverses an argument passed from ``turn_to_poly`` if it
139-
satisfies ``oper === (*)``. Returns sub whether its changed from 0
139+
satisfies ``oper === (*)``. Returns the new `arg` and `sub` if its changed from 0
140140
to a new transcendental function or its value is
141141
kept the same, but changes broken if these 2 cases do not occur. It
142142
traverses the * argument by sub_arg and compares it to sub using
@@ -151,32 +151,33 @@ the function ``add_sub``
151151
# Examples
152152
```jldoctest
153153
julia> trav_mult(unwrap(9*log(x)), x, Ref(false), log(x))
154-
log(x)
154+
(9log(x), log(x))
155155
156156
julia> trav_mult(unwrap(9*log(x)^2), x, Ref(false), log(x))
157-
log(x)
157+
(9(log(x)^2), log(x))
158158
159159
# value of broken is changed here to true
160160
julia> trav_mult(unwrap(9*log(x+1)), x, Ref(false), log(x))
161-
log(x)
161+
(9log(x + 1), log(x))
162162
```
163163
"""
164164
function trav_mult(arg, var, broken, sub)
165-
args_arg = arguments(arg)
165+
args_arg = copy(parent(arguments(arg)))
166166
for (i, arg2) in enumerate(args_arg)
167167
!iscall(arg2) && continue
168168

169169
oper = operation(arg2)
170170
if oper === (^)
171-
tp = trav_pow(args_arg, i, var, broken, sub)
171+
args_arg[i], tp = trav_pow(args_arg[i], var, broken, sub)
172172
sub = isequal(tp, false) ? sub : tp
173173
continue
174174
end
175175

176176
isequal(add_sub(sub, arg2, var, broken), false) && continue
177177
sub = arg2
178178
end
179-
return sub
179+
arg = maketerm(typeof(arg), operation(arg), args_arg, metadata(arg))
180+
return arg, sub
180181
end
181182

182183
"""

src/solver/preprocess.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function _filter_poly(expr, var)
116116
return filter_stuff(expr)
117117
end
118118

119-
args = arguments(expr)
119+
args = copy(parent(arguments(expr)))
120120
if expr isa ComplexTerm
121121
subs1, subs2 = Dict(), Dict()
122122
expr1, expr2 = 0, 0
@@ -165,7 +165,7 @@ function _filter_poly(expr, var)
165165
end
166166

167167
oper = operation(arg)
168-
monomial = arguments(arg)
168+
monomial = copy(parent(arguments(arg)))
169169
if oper === (^)
170170
if any(arg -> isequal(arg, var), monomial)
171171
continue
@@ -175,6 +175,7 @@ function _filter_poly(expr, var)
175175
subs2, monomial[2] = _filter_poly(monomial[2], var)
176176

177177
merge!(subs, merge(subs1, subs2))
178+
args[i] = maketerm(typeof(arg), oper, monomial, metadata(arg))
178179
continue
179180
end
180181

@@ -196,6 +197,7 @@ function _filter_poly(expr, var)
196197
merge!(subs_of_monom, new_subs)
197198
end
198199
merge!(subs, subs_of_monom)
200+
args[i] = maketerm(typeof(arg), oper, monomial, metadata(arg))
199201
continue
200202
end
201203

@@ -208,9 +210,9 @@ function _filter_poly(expr, var)
208210
end
209211
end
210212

211-
args = map(unwrap, arguments(expr))
213+
args = map(unwrap, args)
212214
oper = operation(expr)
213-
expr = term(oper, args...)
215+
expr = maketerm(typeof(expr), oper, args, metadata(expr))
214216
return subs, expr
215217
end
216218

src/solver/solve_helpers.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,11 @@ function bigify(n)
116116

117117
if n isa SymbolicUtils.BasicSymbolic
118118
!iscall(n) && return n
119-
args = arguments(n)
119+
args = copy(parent(arguments(n)))
120120
for i in eachindex(args)
121121
args[i] = bigify(args[i])
122122
end
123+
n = maketerm(typeof(n), operation(n), args, metadata(n))
123124
return n
124125
end
125126

0 commit comments

Comments
 (0)