Skip to content

Commit 4160122

Browse files
committed
how about we never re-use a variable name for a different thing
1 parent 2a46f82 commit 4160122

File tree

2 files changed

+31
-31
lines changed

2 files changed

+31
-31
lines changed

src/interface.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ function state(o, x)
66
elseif isleaf(x)
77
return nothing
88
else
9-
x, _ = functor(x)
10-
return map(x -> state(o, x), x)
9+
x, _ = functor(x)
10+
return map(xᵢ -> state(o, xᵢ), x)
1111
end
1212
end
1313

1414
function _update(o, st, x, x̄s...)
15-
st, x̄ = apply(o, st, x, x̄s...)
16-
return st, patch(x, x̄)
15+
st, x̄ = apply(o, st, x, x̄s...)
16+
return st, patch(x, x̄)
1717
end
1818

1919
function update(o, state, x::T, x̄s...) where T
@@ -22,10 +22,10 @@ function update(o, state, x::T, x̄s...) where T
2222
elseif isnumeric(x)
2323
return _update(o, state, x, x̄s...)
2424
else
25-
x̄s = map(x̄ -> functor(typeof(x), x̄)[1], x̄s)
26-
x, restructure = functor(typeof(x), x)
27-
xstate = map((state, x, x̄s...) -> update(o, state, x, x̄s...), state, x, x̄s...)
28-
return map(first, xstate), restructure(map(last, xstate))
25+
x̄s = map(x̄ -> functor(typeof(x), x̄)[1], x̄s)
26+
x′, re = functor(typeof(x), x)
27+
xstate = map((stᵢ, xᵢ, x̄sᵢ...) -> update(o, stᵢ, xᵢ, x̄sᵢ...), state, x, x̄s...)
28+
return map(first, xstate), re(map(last, xstate))
2929
end
3030
end
3131

src/rules.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ init(o::RMSProp, x::AbstractArray) = zero(x)
108108
function apply(o::RMSProp, state, x, dx)
109109
η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state
110110
@. acc = ρ * acc + (1 - ρ) * dx^2
111-
dx = @. dx * (η / (sqrt(acc) + ϵ))
111+
dx = @. dx * (η / (sqrt(acc) + ϵ))
112112

113-
return acc, dx
113+
return acc, dx
114114
end
115115

116116
(o::RMSProp)(state, m, dm) = update(o, state, m, dm)
@@ -145,9 +145,9 @@ function apply(o::ADAM{T}, state, x, dx) where T
145145

146146
@. mt = β[1] * mt + (one(T) - β[1]) * dx
147147
@. vt = β[2] * vt + (one(T) - β[2]) * dx ^ 2
148-
dx = @. mt / (one(T) - βt[1]) / (sqrt(vt / (one(T) - βt[2])) + ϵ) * η
148+
dx = @. mt / (one(T) - βt[1]) / (sqrt(vt / (one(T) - βt[2])) + ϵ) * η
149149

150-
return (mt, vt, βt .* β), dx
150+
return (mt, vt, βt .* β), dx
151151
end
152152

153153
"""
@@ -185,12 +185,12 @@ function apply(o::RADAM, state, x, dx)
185185
ρ = ρ∞ - 2*t * βt[2] / (1 - βt[2])
186186
if ρ > 4
187187
r = sqrt((ρ - 4) * (ρ - 2) * ρ∞/((ρ∞ - 4) * (ρ∞ - 2) * ρ))
188-
dx = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η * r
188+
dx = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η * r
189189
else
190-
dx = @. mt / (1 - βt[1]) * η
190+
dx = @. mt / (1 - βt[1]) * η
191191
end
192192

193-
return (mt, vt, βt .* β, t + 1), dx
193+
return (mt, vt, βt .* β, t + 1), dx
194194
end
195195

196196
"""
@@ -224,9 +224,9 @@ function apply(o::AdaMax, state, x, dx)
224224

225225
@. mt = β[1] * mt + (1 - β[1]) * dx
226226
@. ut = max(β[2] * ut, abs(dx))
227-
dx = @. (η/(1 - βt[1])) * mt/(ut + ϵ)
227+
dx = @. (η/(1 - βt[1])) * mt/(ut + ϵ)
228228

229-
return (mt, ut, βt .* β), dx
229+
return (mt, ut, βt .* β), dx
230230
end
231231

232232
"""
@@ -263,9 +263,9 @@ function apply(o::OADAM, state, x, dx)
263263
@. vt = β[2] * vt + (1 - β[2]) * dx^2
264264
@. dx = -dx_
265265
@. dx_ = η * mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
266-
dx = @. dx + 2*dx_
266+
dx = @. dx + 2*dx_
267267

268-
return (mt, vt, βt .* β, dx_), dx
268+
return (mt, vt, βt .* β, dx_), dx
269269
end
270270

271271
"""
@@ -296,9 +296,9 @@ function apply(o::ADAGrad, state, x, dx)
296296
acc = state
297297

298298
@. acc += dx^2
299-
dx = @. dx * η / (sqrt(acc) + ϵ)
299+
dx = @. dx * η / (sqrt(acc) + ϵ)
300300

301-
return acc, dx
301+
return acc, dx
302302
end
303303

304304
"""
@@ -330,10 +330,10 @@ function apply(o::ADADelta, state, x, dx)
330330
@. acc = ρ * acc + (1 - ρ) * dx^2
331331
# DON'T remove epsilon from numerator
332332
# or even out of the square roots
333-
dx = @. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)
333+
dx = @. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)
334334
@. Δacc = ρ * Δacc + (1 - ρ) * dx^2
335335

336-
return (acc, Δacc), dx
336+
return (acc, Δacc), dx
337337
end
338338

339339
"""
@@ -370,9 +370,9 @@ function apply(o::AMSGrad, state, x, dx)
370370
@. mt = β[1] * mt + (1 - β[1]) * dx
371371
@. vt = β[2] * vt + (1 - β[2]) * dx ^ 2
372372
@. v̂t = max(v̂t, vt)
373-
dx = @. η * mt / (sqrt(v̂t) + ϵ)
373+
dx = @. η * mt / (sqrt(v̂t) + ϵ)
374374

375-
return (mt, vt, v̂t), dx
375+
return (mt, vt, v̂t), dx
376376
end
377377

378378
"""
@@ -407,10 +407,10 @@ function apply(o::NADAM, state, x, dx)
407407

408408
@. mt = β[1] * mt + (1 - β[1]) * dx
409409
@. vt = β[2] * vt + (1 - β[2]) * dx^2
410-
dx = @. (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
410+
dx = @. (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
411411
(sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η
412412

413-
return (mt, vt, βt .* β), dx
413+
return (mt, vt, βt .* β), dx
414414
end
415415

416416
"""
@@ -462,9 +462,9 @@ function apply(o::AdaBelief, state, x, dx)
462462

463463
@. mt = β[1] * mt + (1 - β[1]) * dx
464464
@. st = β[2] * st + (1 - β[2]) * (dx - mt)^2
465-
dx = @. η * mt / (sqrt(st) + ϵ)
465+
dx = @. η * mt / (sqrt(st) + ϵ)
466466

467-
return (mt, st), dx
467+
return (mt, st), dx
468468
end
469469

470470
"""
@@ -485,9 +485,9 @@ init(o::WeightDecay, x::AbstractArray) = nothing
485485
(o::WeightDecay)(state, m, dm) = update(o, state, m, dm)
486486

487487
function apply(o::WeightDecay, state, x, dx)
488-
dx = @. dx + o.wd * x
488+
dx = @. dx + o.wd * x
489489

490-
return state, dx
490+
return state, dx
491491
end
492492

493493
"""

0 commit comments

Comments
 (0)