@@ -32,12 +32,12 @@ function insert_symbolic_gradient(axislist, store)
32
32
33
33
inbody, prebody = [], []
34
34
for (dt, t) in unique (targets)
35
- drdt = leibnitz (store. right, t)
35
+ drdt = leibnitz (store. right, t, store . nograd )
36
36
deltar = if store. finaliser == :identity
37
37
simplitimes (simpliconj (drdt), :($ dZ[$ (store. leftraw... )]))
38
38
else
39
39
rhs = :($ ZED[$ (store. leftraw... )])
40
- dldr = leibfinal (store. finaliser, rhs)
40
+ dldr = leibfinal (store. finaliser, rhs, store . nograd )
41
41
simplitimes (simpliconj (drdt), simpliconj (dldr), :($ dZ[$ (store. leftraw... )]))
42
42
end
43
43
if store. redfun == :+
@@ -84,16 +84,16 @@ function insert_symbolic_gradient(axislist, store)
84
84
85
85
end
86
86
87
- leibfinal (fun:: Symbol , res) =
87
+ leibfinal (fun:: Symbol , res, no = () ) =
88
88
if fun == :log
89
89
:(exp (- $ res)) # this exp gets done at every element :(
90
90
# :(inv(exp($res)))
91
91
else
92
- _leibfinal (:($ fun ($ RHS)), res)
92
+ _leibfinal (:($ fun ($ RHS)), res, no )
93
93
end
94
94
95
- _leibfinal (out, res) = begin
96
- grad1 = leibnitz (out, RHS)
95
+ _leibfinal (out, res, no ) = begin
96
+ grad1 = leibnitz (out, RHS, no )
97
97
grad2 = MacroTools_postwalk (grad1) do ex
98
98
# @show ex ex == out
99
99
ex == out ? res : ex
@@ -103,13 +103,13 @@ _leibfinal(out, res) = begin
103
103
end
104
104
end
105
105
106
- leibfinal (ex:: Expr , res) = begin
106
+ leibfinal (ex:: Expr , res, no = () ) = begin
107
107
if ex. head == :call && ex. args[1 ] isa Expr &&
108
108
ex. args[1 ]. head == :(-> ) && ex. args[1 ]. args[1 ] == RHS # then it came from underscores
109
109
inner = ex. args[1 ]. args[2 ]
110
110
if inner isa Expr && inner. head == :block
111
111
lines = filter (a -> ! (a isa LineNumberNode), inner. args)
112
- length (lines) == 1 && return _leinfinal (first (lines), res)
112
+ length (lines) == 1 && return _leibfinal (first (lines), res, no) # not tested!
113
113
end
114
114
end
115
115
throw (" couldn't understand finaliser" )
@@ -191,9 +191,9 @@ symbwalk(targets, store) = ex -> begin
191
191
return ex
192
192
end
193
193
194
- leibnitz (s:: Number , target) = 0
195
- leibnitz (s:: Symbol , target) = s == target ? 1 : 0
196
- leibnitz (ex:: Expr , target) = begin
194
+ leibnitz (s:: Number , target, no = () ) = 0
195
+ leibnitz (s:: Symbol , target, no = () ) = s == target ? 1 : 0
196
+ leibnitz (ex:: Expr , target, no = () ) = begin
197
197
ex == target && return 1
198
198
@capture_ (ex, B_[ijk__]) && return 0
199
199
if ex. head == Symbol (" '" )
@@ -202,34 +202,35 @@ leibnitz(ex::Expr, target) = begin
202
202
end
203
203
ex. head == :call || throw (" expected a functionn call, got $ex ." )
204
204
fun = ex. args[1 ]
205
+ fun in no && return 0
205
206
if fun == :log # catch log(a*b) and especially log(a/b)
206
207
arg = ex. args[2 ]
207
208
if arg isa Expr && arg. args[1 ] == :* && length (arg. args) == 3
208
209
newex = :(log ($ (arg. args[2 ])) + log ($ (arg. args[3 ])))
209
- return leibnitz (newex, target)
210
+ return leibnitz (newex, target, no )
210
211
elseif arg isa Expr && arg. args[1 ] == :/
211
212
newex = :(log ($ (arg. args[2 ])) - log ($ (arg. args[3 ])))
212
- return leibnitz (newex, target)
213
+ return leibnitz (newex, target, no )
213
214
end
214
215
end
215
216
if length (ex. args) == 2 # one-arg function
216
217
fx = mydiffrule (fun, ex. args[2 ])
217
- dx = leibnitz (ex. args[2 ], target)
218
+ dx = leibnitz (ex. args[2 ], target, no )
218
219
return simplitimes (fx, dx)
219
220
elseif length (ex. args) == 3 # two-arg function
220
221
fx, fy = mydiffrule (fun, ex. args[2 : end ]. .. )
221
- dx = leibnitz (ex. args[2 ], target)
222
- dy = leibnitz (ex. args[3 ], target)
222
+ dx = leibnitz (ex. args[2 ], target, no )
223
+ dy = leibnitz (ex. args[3 ], target, no )
223
224
return simpliplus (simplitimes (fx, dx), simplitimes (fy, dy))
224
225
elseif fun in [:+ , :* ]
225
- fun == :* && return leibnitz (:(* ($ (ex. args[2 ]), * ($ (ex. args[3 : end ]. .. )))), target)
226
- dxs = [leibnitz (x, target) for x in ex. args[2 : end ]]
226
+ fun == :* && return leibnitz (:(* ($ (ex. args[2 ]), * ($ (ex. args[3 : end ]. .. )))), target, no )
227
+ dxs = [leibnitz (x, target, no ) for x in ex. args[2 : end ]]
227
228
fun == :+ && return simpliplus (dxs... )
228
229
elseif length (ex. args) == 4 # three-arg function such as ifelse
229
230
fx, fy, fz = mydiffrule (fun, ex. args[2 : end ]. .. )
230
- dx = leibnitz (ex. args[2 ], target)
231
- dy = leibnitz (ex. args[3 ], target)
232
- dz = leibnitz (ex. args[4 ], target)
231
+ dx = leibnitz (ex. args[2 ], target, no )
232
+ dy = leibnitz (ex. args[3 ], target, no )
233
+ dz = leibnitz (ex. args[4 ], target, no )
233
234
return simpliplus (simplitimes (fx, dx), simplitimes (fy, dy), simplitimes (fz, dz))
234
235
end
235
236
throw (" don't know how to handle $ex ." )
0 commit comments