|
| 1 | +function old_sparsejacobian(ops::AbstractVector, vars::AbstractVector) |
| 2 | + sp = Symbolics.jacobian_sparsity(ops, vars) |
| 3 | + I,J,_ = findnz(sp) |
| 4 | + |
| 5 | + exprs = old_sparsejacobian_vals(ops, vars, I, J) |
| 6 | + |
| 7 | + sparse(I, J, exprs, length(ops), length(vars)) |
| 8 | +end |
| 9 | + |
| 10 | +function old_sparsejacobian_vals(ops::AbstractVector, vars::AbstractVector, I::AbstractVector, J::AbstractVector; simplify::Bool=false, kwargs...) |
| 11 | + exprs = Num[] |
| 12 | + sizehint!(exprs, length(I)) |
| 13 | + |
| 14 | + for (i,j) in zip(I, J) |
| 15 | + push!(exprs, Num(old_expand_derivatives(Differential(vars[j])(ops[i]), simplify; kwargs...))) |
| 16 | + end |
| 17 | + exprs |
| 18 | +end |
| 19 | + |
| 20 | + |
| 21 | +function old_expand_derivatives(O::SymbolicUtils.Symbolic, simplify=false; throw_no_derivative=false) |
| 22 | + if iscall(O) && isa(operation(O), Differential) |
| 23 | + arg = only(arguments(O)) |
| 24 | + arg = old_expand_derivatives(arg, false; throw_no_derivative) |
| 25 | + return old_executediff(operation(O), arg, simplify; throw_no_derivative) |
| 26 | + elseif iscall(O) && isa(operation(O), Integral) |
| 27 | + return operation(O)(old_expand_derivatives(arguments(O)[1]; throw_no_derivative)) |
| 28 | + elseif !Symbolics.hasderiv(O) |
| 29 | + return O |
| 30 | + else |
| 31 | + args = map(a->old_expand_derivatives(a, false; throw_no_derivative), arguments(O)) |
| 32 | + O1 = operation(O)(args...) |
| 33 | + return simplify ? SymbolicUtils.simplify(O1) : O1 |
| 34 | + end |
| 35 | +end |
| 36 | +function old_expand_derivatives(n::Num, simplify=false; kwargs...) |
| 37 | + Symbolics.wrap(old_expand_derivatives(Symbolics.value(n), simplify; kwargs...)) |
| 38 | +end |
| 39 | + |
| 40 | +function old_occursin_info(x, expr, fail = true) |
| 41 | + if SymbolicUtils.symtype(expr) <: AbstractArray |
| 42 | + if fail |
| 43 | + error("Differentiation with array expressions is not yet supported") |
| 44 | + else |
| 45 | + return occursin(x, expr) |
| 46 | + end |
| 47 | + end |
| 48 | + |
| 49 | + # Allow scalarized expressions |
| 50 | + function is_scalar_indexed(ex) |
| 51 | + (iscall(ex) && operation(ex) == getindex && !(SymbolicUtils.symtype(ex) <: AbstractArray)) || |
| 52 | + (iscall(ex) && (SymbolicUtils.issym(operation(ex)) || iscall(operation(ex))) && |
| 53 | + is_scalar_indexed(operation(ex))) |
| 54 | + end |
| 55 | + |
| 56 | + # x[1] == x[1] but not x[2] |
| 57 | + if is_scalar_indexed(x) && is_scalar_indexed(expr) && |
| 58 | + isequal(first(arguments(x)), first(arguments(expr))) |
| 59 | + return isequal(operation(x), operation(expr)) && |
| 60 | + isequal(arguments(x), arguments(expr)) |
| 61 | + end |
| 62 | + |
| 63 | + if is_scalar_indexed(x) && is_scalar_indexed(expr) && |
| 64 | + !occursin(first(arguments(x)), first(arguments(expr))) |
| 65 | + return false |
| 66 | + end |
| 67 | + |
| 68 | + if is_scalar_indexed(expr) && !is_scalar_indexed(x) && !occursin(x, expr) |
| 69 | + return false |
| 70 | + end |
| 71 | + |
| 72 | + !iscall(expr) && return isequal(x, expr) |
| 73 | + if isequal(x, expr) |
| 74 | + true |
| 75 | + else |
| 76 | + args = map(a->old_occursin_info(x, a, operation(expr) !== getindex), arguments(expr)) |
| 77 | + if all(_isfalse, args) |
| 78 | + return false |
| 79 | + end |
| 80 | + Term{Real}(true, args) |
| 81 | + end |
| 82 | +end |
| 83 | + |
| 84 | +function old_occursin_info(x, expr::Sym, fail) |
| 85 | + if SymbolicUtils.symtype(expr) <: AbstractArray && fail |
| 86 | + error("Differentiation of expressions involving arrays and array variables is not yet supported.") |
| 87 | + end |
| 88 | + isequal(x, expr) |
| 89 | +end |
| 90 | + |
| 91 | +_isfalse(occ::Bool) = occ === false |
| 92 | +_isfalse(occ::SymbolicUtils.Symbolic) = iscall(occ) && _isfalse(operation(occ)) |
| 93 | + |
| 94 | +_iszero(x) = false |
| 95 | +_isone(x) = false |
| 96 | +_iszero(x::Number) = iszero(x) |
| 97 | +_isone(x::Number) = isone(x) |
| 98 | +_iszero(::SymbolicUtils.Symbolic) = false |
| 99 | +_isone(::SymbolicUtils.Symbolic) = false |
| 100 | +_iszero(x::Num) = _iszero(value(x))::Bool |
| 101 | +_isone(x::Num) = _isone(value(x))::Bool |
| 102 | + |
| 103 | + |
| 104 | +function old_executediff(D, arg, simplify=false; occurrences=nothing, throw_no_derivative=false) |
| 105 | + if occurrences == nothing |
| 106 | + occurrences = old_occursin_info(D.x, arg) |
| 107 | + end |
| 108 | + |
| 109 | + _isfalse(occurrences) && return 0 |
| 110 | + occurrences isa Bool && return 1 # means it's a `true` |
| 111 | + |
| 112 | + if !iscall(arg) |
| 113 | + return D(arg) # Cannot expand |
| 114 | + elseif (op = operation(arg); SymbolicUtils.issym(op)) |
| 115 | + inner_args = arguments(arg) |
| 116 | + if any(isequal(D.x), inner_args) |
| 117 | + return D(arg) # base case if any argument is directly equal to the i.v. |
| 118 | + else |
| 119 | + return sum(inner_args, init=0) do a |
| 120 | + return old_executediff(Differential(a), arg; throw_no_derivative) * |
| 121 | + old_executediff(D, a; throw_no_derivative) |
| 122 | + end |
| 123 | + end |
| 124 | + elseif op === getindex |
| 125 | + inner_args = arguments(arguments(arg)[1]) |
| 126 | + c = 0 |
| 127 | + for a in inner_args |
| 128 | + if isequal(a, D.x) |
| 129 | + return D(arg) |
| 130 | + else |
| 131 | + c += Differential(a)(arg) * D(a) |
| 132 | + end |
| 133 | + end |
| 134 | + return old_expand_derivatives(c) |
| 135 | + elseif op === ifelse |
| 136 | + args = arguments(arg) |
| 137 | + O = op(args[1], |
| 138 | + old_executediff(D, args[2], simplify; occurrences=arguments(occurrences)[2], throw_no_derivative), |
| 139 | + old_executediff(D, args[3], simplify; occurrences=arguments(occurrences)[3], throw_no_derivative)) |
| 140 | + return O |
| 141 | + elseif isa(op, Differential) |
| 142 | + # The recursive expand_derivatives was not able to remove |
| 143 | + # a nested Differential. We can attempt to differentiate the |
| 144 | + # inner expression wrt to the outer iv. And leave the |
| 145 | + # unexpandable Differential outside. |
| 146 | + if isequal(op.x, D.x) |
| 147 | + return D(arg) |
| 148 | + else |
| 149 | + inner = old_executediff(D, arguments(arg)[1], false; throw_no_derivative) |
| 150 | + # if the inner expression is not expandable either, return |
| 151 | + if iscall(inner) && operation(inner) isa Differential |
| 152 | + return D(arg) |
| 153 | + else |
| 154 | + # otherwise give the nested Differential another try |
| 155 | + return old_executediff(op, inner, simplify; throw_no_derivative) |
| 156 | + end |
| 157 | + end |
| 158 | + elseif isa(op, Integral) |
| 159 | + if isa(op.domain.domain, Symbolics.AbstractInterval) |
| 160 | + domain = op.domain.domain |
| 161 | + a, b = Symbolics.DomainSets.endpoints(domain) |
| 162 | + c = 0 |
| 163 | + inner_function = arguments(arg)[1] |
| 164 | + if iscall(value(a)) |
| 165 | + t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a))) |
| 166 | + t2 = D(a) |
| 167 | + c -= t1*t2 |
| 168 | + end |
| 169 | + if iscall(value(b)) |
| 170 | + t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b))) |
| 171 | + t2 = D(b) |
| 172 | + c += t1*t2 |
| 173 | + end |
| 174 | + inner = old_executediff(D, arguments(arg)[1]; throw_no_derivative) |
| 175 | + c += op(inner) |
| 176 | + return Symbolics.value(c) |
| 177 | + end |
| 178 | + end |
| 179 | + |
| 180 | + inner_args = arguments(arg) |
| 181 | + l = length(inner_args) |
| 182 | + exprs = [] |
| 183 | + c = 0 |
| 184 | + |
| 185 | + for i in 1:l |
| 186 | + t2 = old_executediff(D, inner_args[i],false; occurrences=arguments(occurrences)[i], throw_no_derivative) |
| 187 | + |
| 188 | + x = if _iszero(t2) |
| 189 | + t2 |
| 190 | + elseif _isone(t2) |
| 191 | + d = Symbolics.derivative_idx(arg, i) |
| 192 | + if d isa Symbolics.NoDeriv |
| 193 | + throw_no_derivative && error((arg, i)) |
| 194 | + D(arg) |
| 195 | + else |
| 196 | + d |
| 197 | + end |
| 198 | + else |
| 199 | + t1 = Symbolics.derivative_idx(arg, i) |
| 200 | + t1 = if t1 isa Symbolics.NoDeriv |
| 201 | + throw_no_derivative && error((arg, i)) |
| 202 | + D(arg) |
| 203 | + else |
| 204 | + t1 |
| 205 | + end |
| 206 | + t1 * t2 |
| 207 | + end |
| 208 | + |
| 209 | + if _iszero(x) |
| 210 | + continue |
| 211 | + elseif x isa SymbolicUtils.Symbolic |
| 212 | + push!(exprs, x) |
| 213 | + else |
| 214 | + c += x |
| 215 | + end |
| 216 | + end |
| 217 | + |
| 218 | + if isempty(exprs) |
| 219 | + return c |
| 220 | + elseif length(exprs) == 1 |
| 221 | + term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1]) |
| 222 | + return _iszero(c) ? term : c + term |
| 223 | + else |
| 224 | + x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...) |
| 225 | + return simplify ? SymbolicUtils.simplify(x) : x |
| 226 | + end |
| 227 | +end |
0 commit comments