Skip to content

Commit 14cddf0

Browse files
author
Michael Abbott
committed
use nograd for functions too
1 parent 5439bc3 commit 14cddf0

File tree

1 file changed

+22
-21
lines changed

1 file changed

+22
-21
lines changed

src/symbolic.jl

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ function insert_symbolic_gradient(axislist, store)
3232

3333
inbody, prebody = [], []
3434
for (dt, t) in unique(targets)
35-
drdt = leibnitz(store.right, t)
35+
drdt = leibnitz(store.right, t, store.nograd)
3636
deltar = if store.finaliser == :identity
3737
simplitimes(simpliconj(drdt), :($dZ[$(store.leftraw...)]))
3838
else
3939
rhs = :($ZED[$(store.leftraw...)])
40-
dldr = leibfinal(store.finaliser, rhs)
40+
dldr = leibfinal(store.finaliser, rhs, store.nograd)
4141
simplitimes(simpliconj(drdt), simpliconj(dldr), :($dZ[$(store.leftraw...)]))
4242
end
4343
if store.redfun == :+
@@ -84,16 +84,16 @@ function insert_symbolic_gradient(axislist, store)
8484

8585
end
8686

87-
leibfinal(fun::Symbol, res) =
87+
leibfinal(fun::Symbol, res, no=()) =
8888
if fun == :log
8989
:(exp(-$res)) # this exp gets done at every element :(
9090
# :(inv(exp($res)))
9191
else
92-
_leibfinal(:($fun($RHS)), res)
92+
_leibfinal(:($fun($RHS)), res, no)
9393
end
9494

95-
_leibfinal(out, res) = begin
96-
grad1 = leibnitz(out, RHS)
95+
_leibfinal(out, res, no) = begin
96+
grad1 = leibnitz(out, RHS, no)
9797
grad2 = MacroTools_postwalk(grad1) do ex
9898
# @show ex ex == out
9999
ex == out ? res : ex
@@ -103,13 +103,13 @@ _leibfinal(out, res) = begin
103103
end
104104
end
105105

106-
leibfinal(ex::Expr, res) = begin
106+
leibfinal(ex::Expr, res, no=()) = begin
107107
if ex.head == :call && ex.args[1] isa Expr &&
108108
ex.args[1].head == :(->) && ex.args[1].args[1] == RHS # then it came from underscores
109109
inner = ex.args[1].args[2]
110110
if inner isa Expr && inner.head == :block
111111
lines = filter(a -> !(a isa LineNumberNode), inner.args)
112-
length(lines) == 1 && return _leinfinal(first(lines), res)
112+
length(lines) == 1 && return _leibfinal(first(lines), res, no) # not tested!
113113
end
114114
end
115115
throw("couldn't understand finaliser")
@@ -191,9 +191,9 @@ symbwalk(targets, store) = ex -> begin
191191
return ex
192192
end
193193

194-
leibnitz(s::Number, target) = 0
195-
leibnitz(s::Symbol, target) = s == target ? 1 : 0
196-
leibnitz(ex::Expr, target) = begin
194+
leibnitz(s::Number, target, no=()) = 0
195+
leibnitz(s::Symbol, target, no=()) = s == target ? 1 : 0
196+
leibnitz(ex::Expr, target, no=()) = begin
197197
ex == target && return 1
198198
@capture_(ex, B_[ijk__]) && return 0
199199
if ex.head == Symbol("'")
@@ -202,34 +202,35 @@ leibnitz(ex::Expr, target) = begin
202202
end
203203
ex.head == :call || throw("expected a functionn call, got $ex.")
204204
fun = ex.args[1]
205+
fun in no && return 0
205206
if fun == :log # catch log(a*b) and especially log(a/b)
206207
arg = ex.args[2]
207208
if arg isa Expr && arg.args[1] == :* && length(arg.args) == 3
208209
newex = :(log($(arg.args[2])) + log($(arg.args[3])))
209-
return leibnitz(newex, target)
210+
return leibnitz(newex, target, no)
210211
elseif arg isa Expr && arg.args[1] == :/
211212
newex = :(log($(arg.args[2])) - log($(arg.args[3])))
212-
return leibnitz(newex, target)
213+
return leibnitz(newex, target, no)
213214
end
214215
end
215216
if length(ex.args) == 2 # one-arg function
216217
fx = mydiffrule(fun, ex.args[2])
217-
dx = leibnitz(ex.args[2], target)
218+
dx = leibnitz(ex.args[2], target, no)
218219
return simplitimes(fx, dx)
219220
elseif length(ex.args) == 3 # two-arg function
220221
fx, fy = mydiffrule(fun, ex.args[2:end]...)
221-
dx = leibnitz(ex.args[2], target)
222-
dy = leibnitz(ex.args[3], target)
222+
dx = leibnitz(ex.args[2], target, no)
223+
dy = leibnitz(ex.args[3], target, no)
223224
return simpliplus(simplitimes(fx, dx), simplitimes(fy, dy))
224225
elseif fun in [:+, :*]
225-
fun == :* && return leibnitz(:(*($(ex.args[2]), *($(ex.args[3:end]...)))), target)
226-
dxs = [leibnitz(x, target) for x in ex.args[2:end]]
226+
fun == :* && return leibnitz(:(*($(ex.args[2]), *($(ex.args[3:end]...)))), target, no)
227+
dxs = [leibnitz(x, target, no) for x in ex.args[2:end]]
227228
fun == :+ && return simpliplus(dxs...)
228229
elseif length(ex.args) == 4 # three-arg function such as ifelse
229230
fx, fy, fz = mydiffrule(fun, ex.args[2:end]...)
230-
dx = leibnitz(ex.args[2], target)
231-
dy = leibnitz(ex.args[3], target)
232-
dz = leibnitz(ex.args[4], target)
231+
dx = leibnitz(ex.args[2], target, no)
232+
dy = leibnitz(ex.args[3], target, no)
233+
dz = leibnitz(ex.args[4], target, no)
233234
return simpliplus(simplitimes(fx, dx), simplitimes(fy, dy), simplitimes(fz, dz))
234235
end
235236
throw("don't know how to handle $ex.")

0 commit comments

Comments
 (0)