
Commit 51595b7

Improve some activation function gradients (#392)
* add a few more activation rules
* nicer grad inference test
* other testing cleanup
* actually test binary rules
* only allow x2::Number, and switch to NaN
* add non-gradient tests for binary activations
* second derivative tests
* docstrings
* more gradients
* faster
* notation
* one correction
1 parent 07a5284 commit 51595b7

File tree

2 files changed: +267 −143 lines


src/activations.jl

Lines changed: 136 additions & 55 deletions
@@ -3,7 +3,7 @@
 # Some of activation functions have its wrapper function for GPU in NNlibCUDA.jl.
 # https://github.com/JuliaGPU/CuArrays.jl/issues/614

-const ACTIVATIONS = [
+ACTIVATIONS = [
     :σ, :hardσ, :hardtanh, :relu,
     :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :hardswish, :selu,
     :celu, :softplus, :softsign, :logσ, :logcosh,
@@ -32,6 +32,8 @@ The ascii name `sigmoid` is also exported.
 See also [`sigmoid_fast`](@ref).

 ```
+julia> using UnicodePlots
+
 julia> lineplot(sigmoid, -5, 5, height=7)
 ┌────────────────────────────────────────┐
 1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⣀⡠⠤⠖⠒⠒⠋⠉⠉⠉⠉⠉⠉│ σ(x)
@@ -292,9 +294,9 @@ julia> elu(-10f0, 2)
 -1.9999092f0
 ```
 """
-elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath α * (exp(x) - 1))
+elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath oftf(x, α) * (exp(x) - 1))

-deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + α)
+deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + oftype(Ω, α))

 """
     gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))
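The new `deriv_elu` expresses the derivative through the output Ω: for x < 0, elu(x) = α(eˣ − 1), so elu′(x) = αeˣ = Ω + α. A quick finite-difference check of that identity, using local Float64 stand-ins for the diff's definitions (`oftf`/`oftype` and `@fastmath` dropped, since they only affect element types and speed — this is a sketch, not NNlib itself):

```julia
# Local sketches of elu and deriv_elu from the hunk above (Float64 only).
elu(x, α=1.0) = ifelse(x ≥ 0, float(x), α * (exp(x) - 1))
deriv_elu(Ω, α=1.0) = ifelse(Ω ≥ 0, one(Ω), Ω + α)  # for x < 0: α*exp(x) == Ω + α

x, α = -1.3, 2.0
Ω = elu(x, α)
h = 1e-6
fd = (elu(x + h, α) - elu(x - h, α)) / (2h)  # central finite difference
@assert isapprox(deriv_elu(Ω, α), fd; atol=1e-5)
```

Writing the rule in terms of Ω is what lets the pullback below avoid recomputing `exp`.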
@@ -314,6 +316,21 @@ julia> lineplot(gelu, -2, 2, height=7)
 └────────────────────────────────────────┘
 ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
 ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
+
+julia> lineplot(gelu, -5, 0, height=7);
+
+julia> lineplot!(ans, swish)
+┌────────────────────────────────────────┐
+0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu(x)
+│⠑⠒⠢⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│ swish(x)
+│⠀⠀⠀⠀⠀⠈⠉⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣸⠁│
+f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢄⠀⠀⠀⠀⠀⠀⠀⠀⢠⡇⠀│
+│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⣄⠀⠀⠀⠀⠀⢠⡞⠀⠀│
+│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⢄⣀⣀⡤⢣⠃⠀⠀│
+-0.2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠓⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠇⠀⠀⠀│
+└────────────────────────────────────────┘
+⠀-5⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀
+⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
 ```
 """
 function gelu(x)
@@ -353,26 +370,47 @@ julia> lineplot(swish, -2, 2, height=7)
 """
     hardswish(x) = x * hardσ(x)

-Hard-Swish activation function
-See (["Searching for MobileNetV3"](https://arxiv.org/abs/1905.02244)).
-
-```
-julia> lineplot(hardswish, -2, 2, height = 7)
-┌────────────────────────────────────────┐
-2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀│ hardswish(x)
-│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠖⠋⠁│
-│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⠤⠒⠋⠁⠀⠀⠀⠀│
-f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⣀⡠⠤⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-│⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⣤⣤⡤⡧⠴⠶⠯⠭⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤⠤│
-│⠒⠲⠤⠤⠤⠤⠤⠤⠖⠒⠒⠒⠒⠊⠉⠉⠉⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
--1 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-└────────────────────────────────────────┘
-⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
-⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
+Hard-Swish activation function.
+See ["Searching for MobileNetV3"](https://arxiv.org/abs/1905.02244).
+
+```
+julia> lineplot(hardswish, -2, 5, height = 7)
+┌────────────────────────────────────────┐
+5 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠔⠒⠉│ hardswish(x)
+│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠔⠒⠉⠁⠀⠀⠀⠀│
+│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠖⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⢀⣀⠤⠖⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+│⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣇⣤⣤⣖⣚⣉⣁⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀│
+-1 │⠉⠒⠒⠒⠒⠉⠉⠉⠉⠁⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+└────────────────────────────────────────┘
+⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀5⠀
+⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
+
+julia> lineplot(hardswish, -4, 0, height = 7);
+
+julia> lineplot!(ans, swish)
+┌────────────────────────────────────────┐
+0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⢣⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡜│ hardswish(x)
+│⠒⠒⠢⠤⢄⣀⡀⠀⠀⠀⠀⠱⡄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠎⠀│ swish(x)
+│⠀⠀⠀⠀⠀⠀⠈⠉⠑⠒⠦⢄⣘⢄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡴⠃⠀⠀│
+f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠑⡖⠦⢄⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⢔⠏⠁⠀⠀⠀│
+│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠣⣄⠀⠉⠑⠒⠦⠤⢄⣀⣀⣀⣀⡠⠤⠖⣊⠕⠁⠀⠀⠀⠀⠀│
+│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠓⠤⡀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠖⠁⠀⠀⠀⠀⠀⠀⠀│
+-0.4 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠉⠒⠢⠤⠤⠔⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+└────────────────────────────────────────┘
+⠀-4⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀0⠀
+⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
+
+julia> hardswish.(-5:5)'
+1×11 adjoint(::Vector{Float64}) with eltype Float64:
+ -0.0  -0.0  -0.0  -0.333333  -0.333333  0.0  0.666667  1.66667  3.0  4.0  5.0
 ```
 """
 @inline hardswish(x) = x * hardσ(x)

+deriv_hardswish(x) = ifelse(x < -3, oftf(x,0), ifelse(x > 3, oftf(x,1), x/3 + 1//2))
+
 """
     lisht(x) = x * tanh(x)

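The new `deriv_hardswish` is the piecewise derivative of x·hardσ(x): zero below −3, one above 3, and on (−3, 3), where hardswish(x) = x(x+3)/6, the slope x/3 + 1/2. A finite-difference check with local Float64 sketches of the NNlib definitions (`oftf`, which just matches the constant to x's float type, replaced by literals):

```julia
# Local sketches of hardσ, hardswish, and deriv_hardswish (Float64 only).
hardσ(x) = max(0.0, min(1.0, (x + 3) / 6))
hardswish(x) = x * hardσ(x)
deriv_hardswish(x) = ifelse(x < -3, 0.0, ifelse(x > 3, 1.0, x/3 + 1/2))

h = 1e-6
for x in (-4.0, -1.2, 0.7, 5.0)  # points away from the kinks at ±3
    fd = (hardswish(x + h) - hardswish(x - h)) / (2h)
    @assert isapprox(deriv_hardswish(x), fd; atol=1e-5)
end
```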
@@ -392,6 +430,19 @@ julia> lineplot(lisht, -2, 2, height=7)
 └────────────────────────────────────────┘
 ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
 ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
+
+julia> lineplot!(ans, logcosh)
+┌────────────────────────────────────────┐
+2 │⠢⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔│ lisht(x)
+│⠀⠈⠑⢦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡤⠊⠁⠀│ logcosh(x)
+│⠢⣄⠀⠀⠈⠣⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠁⠀⠀⣀⠔│
+f(x) │⠀⠈⠑⠢⣀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⠊⠁⠀⣀⠔⠊⠁⠀│
+│⠀⠀⠀⠀⠀⠉⠢⢄⡀⠉⠢⡄⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⢀⠔⠋⠀⡠⠔⠋⠁⠀⠀⠀⠀│
+│⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠦⣌⡓⢄⡀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⢀⡠⠖⣁⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀│
+0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⠪⠷⣦⣄⣀⣀⣇⣀⣀⣤⠶⠕⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+└────────────────────────────────────────┘
+⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
+⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
 ```
 """
 lisht(x) = x * tanh_fast(x)
@@ -406,17 +457,17 @@ Scaled exponential linear units.
 See ["Self-Normalizing Neural Networks"](https://arxiv.org/abs/1706.02515).

 ```
-julia> lineplot(selu, -2, 2, height=7)
+julia> lineplot(selu, -3, 2, height=7)
 ┌────────────────────────────────────────┐
-3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ selu(x)
-│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡠⠤⠔⠒│
-│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⠤⠔⠒⠋⠉⠀⠀⠀⠀⠀│
-f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⢀⣀⡤⠤⠒⠊⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
-│⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⣉⡩⠭⠛⡏⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│
-│⠀⠀⠀⠀⠀⠀⠀⣀⣀⣀⡤⠤⠔⠒⠊⠉⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
--2 │⠒⠒⠉⠉⠉⠉⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ selu(x)
+│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠤⠒│
+│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⠤⠖⠊⠉⠀⠀⠀⠀│
+f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⣀⡠⠤⠒⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+│⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⣉⠭⠛⡏⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉│
+│⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⣀⡤⠤⠒⠊⠉⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
+-2 │⠤⠤⠖⠒⠒⠒⠒⠒⠒⠒⠉⠉⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
 └────────────────────────────────────────┘
-⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
+⠀-3⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀
 ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀

 julia> selu(-10f0)
@@ -461,7 +512,9 @@ julia> celu(-10f0)
 -0.9999546f0
 ```
 """
-celu(x, α=1) = ifelse(x ≥ 0, float(x), α * (exp(x/α) - 1))
+celu(x, α=1) = ifelse(x ≥ 0, float(x), oftf(x,α) * (exp(x/oftf(x,α)) - 1))
+
+deriv_celu(Ω, α=1) = ifelse(Ω > 0, oftf(Ω, 1), Ω / oftf(Ω, α) + 1)

 """
     trelu(x, theta=1) = x > theta ? x : 0
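Like `deriv_elu`, the new `deriv_celu` reuses the output: for x < 0, celu(x) = α(e^(x/α) − 1), so celu′(x) = e^(x/α) = Ω/α + 1. A finite-difference check with local Float64 stand-ins (the `oftf` conversions dropped, since they only pin down element types):

```julia
# Local sketches of celu and deriv_celu from the hunk above (Float64 only).
celu(x, α=1.0) = ifelse(x ≥ 0, float(x), α * (exp(x/α) - 1))
deriv_celu(Ω, α=1.0) = ifelse(Ω > 0, 1.0, Ω / α + 1)  # for x < 0: exp(x/α) == Ω/α + 1

x, α = -0.8, 1.5
Ω = celu(x, α)
h = 1e-6
fd = (celu(x + h, α) - celu(x - h, α)) / (2h)  # central finite difference
@assert isapprox(deriv_celu(Ω, α), fd; atol=1e-5)
```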
@@ -529,6 +582,8 @@ julia> softsign(100f0)
 """
 softsign(x) = x / (1 + abs(x))

+deriv_softsign(x) = 1 / (1 + abs(x))^2
+
 """
     softplus(x) = log(exp(x) + 1)

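The closed form in `deriv_softsign` follows from the quotient rule on x/(1 + |x|), which gives 1/(1 + |x|)² on both sides of zero. A quick numeric check of the formula, self-contained in plain Julia:

```julia
# softsign and its derivative, copied from the hunk above.
softsign(x) = x / (1 + abs(x))
deriv_softsign(x) = 1 / (1 + abs(x))^2

h = 1e-6
for x in (-2.5, 0.3, 4.0)
    fd = (softsign(x + h) - softsign(x - h)) / (2h)  # central finite difference
    @assert isapprox(deriv_softsign(x), fd; atol=1e-5)
end
```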
@@ -675,9 +730,9 @@ julia> softshrink.((-10f0, 10f0))
 (-9.5f0, 9.5f0)
 ```
 """
-function softshrink(x, λ=oftf(x, 0.5))
-    lo = x - λ
-    hi = x + λ
+function softshrink(x, λ = 0.5)
+    lo = x - oftf(x, λ)
+    hi = x + oftf(x, λ)
     ifelse(hi > 0, ifelse(lo < 0, zero(hi), lo), hi)
 end
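The behaviour is unchanged by this refactor (only where the `oftf` conversion happens moves); the function still shrinks values toward zero by λ and clamps the dead zone |x| ≤ λ to zero, matching the doctest above. A Float64-only sketch with the conversion dropped:

```julia
# Local sketch of softshrink (λ kept as a plain Float64 literal).
function softshrink(x, λ = 0.5)
    lo = x - λ
    hi = x + λ
    ifelse(hi > 0, ifelse(lo < 0, zero(hi), lo), hi)
end

@assert softshrink(10.0) == 9.5
@assert softshrink(-10.0) == -9.5
@assert softshrink(0.2) == 0.0   # inside the dead zone |x| ≤ λ
```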

@@ -785,58 +840,84 @@ this replacement for some array or element types.

 ## Define rrules for some activation functions, along with the
 ## broadcasted rrule activation functions.
-## TODO: add to the lists below all activations.

 ## This is a performance hack specifically for Zygote, because it doesn't handle fused
 ## broadcasts well; but it generally should be good (or at least harmless) for any AD, as
 ## it saves ADing the broadcasting machinery.
 ## Related Issue https://github.com/JuliaDiff/ChainRulesCore.jl/issues/271

-UNARY_ACTS = [ # f, df
-    (:relu, :(x > 0)),
-    (:hardtanh, :(-1 < x < 1)),
-    (:selu, :(deriv_selu(Ω))),
+## TODO: add to the lists below all activations.
+
+UNARY_ACTS = [ # f, dfdx
+    ## In the same order as above!
     (:σ, :(conj(Ω * (1 - Ω)))),
+    (:hardσ, :(ifelse((Ω>0)&(Ω<1), 1//6, 1//1))),
+    (:logσ, :(sigmoid_fast(-x))),
+    (:hardtanh, :((Ω>-1) & (Ω<1))),
+    (:relu, :(Ω > 0)),
+    (:leakyrelu, :(ifelse(Ω > 0, 1//1, 1//100))),
+    (:relu6, :((Ω>0) & (Ω<6))),
+    # rrelu is random, can't write a rule.
     (:elu, :(deriv_elu(Ω))),
-    (:softplus, :(σ(x))),
-
+    # gelu
+    (:swish, :(Ω + sigmoid_fast(x) * (1 - Ω))),
+    (:hardswish, :(deriv_hardswish(x))),
+    # lisht
+    (:selu, :(deriv_selu(Ω))),
+    (:celu, :(deriv_celu(Ω))),
+    (:trelu, :(Ω > 0)),
+    (:softsign, :(deriv_softsign(x))),
+    (:softplus, :(sigmoid_fast(x))),
+    # (:softplus, :(1 - @fastmath exp(-Ω))),  # slightly faster, check accuracy?
+    # logcosh
+    # mish
+    (:tanhshrink, :((x - Ω)^2)),
+    (:softshrink, :(Ω != 0)),
+    ## Fast variants are the same!
     (:tanh_fast, :(conj(1 - Ω^2))),
     (:sigmoid_fast, :(conj(Ω * (1 - Ω)))),
-]
+    ]

-for (f, df) in UNARY_ACTS
-    @eval @scalar_rule($f(x), $df)
+for (f, dfdx) in UNARY_ACTS
+    @eval @scalar_rule($f(x), $dfdx)

     pullback = Symbol(:broadcasted_, f, :_pullback)
     @eval function rrule(::typeof(broadcasted),
                          ::typeof($f), x::Numeric)
         Ω = $f.(x)
-        function $pullback(Δ)
+        function $pullback(dΩ)
             x_thunk = InplaceableThunk(
-                dx -> @.(dx += Δ * $df),
-                @thunk @.(Δ * $df)
+                dx -> @.(dx += dΩ * $dfdx),
+                @thunk @.(dΩ * $dfdx)
             )
             NoTangent(), NoTangent(), x_thunk
         end
         return Ω, $pullback
     end
 end

-BINARY_ACTS = [ # f, df1, df2
-    (:elu, :(deriv_elu(Ω, x2)), :(NoTangent())), # TODO use real deriv instead of DNE
-]
+# NO_ACT_GRAD = ChainRulesCore.@not_implemented "for simplicity NNlib assumes the 2nd argument of this activation function is a constant"
+NO_ACT_GRAD = NaN  ## Still reminds you not to use this, but is perhaps more GPU friendly.

-for (f, df1, df2) in BINARY_ACTS
-    @eval @scalar_rule($f(x1, x2), ($df1, $df2))
+BINARY_ACTS = [ # f, dfdx1, dfdx2
+    ## In the same order as above!
+    (:leakyrelu, :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, x2))), NO_ACT_GRAD),
+    (:elu, :(deriv_elu(Ω, x2)), NO_ACT_GRAD),
+    (:celu, :(deriv_celu(Ω, x2)), NO_ACT_GRAD),
+    (:trelu, :(Ω > 0), ZeroTangent()),
+    (:softshrink, :(Ω != 0), NO_ACT_GRAD),
+    ]

-    pullback = Symbol(:broadcasted_, f, :_pullback)
+for (f, dfdx1, dfdx2) in BINARY_ACTS
+    @eval @scalar_rule($f(x1, x2), ($dfdx1, $dfdx2))
+
+    pullback = Symbol(:broadcasted_, f, :_pullback_2arg)
     @eval function rrule(::typeof(broadcasted),
                          ::typeof($f),
-                         x1::Numeric, x2::Numeric)
+                         x1::Numeric, x2::Number)
         Ω = $f.(x1, x2)
-        function $pullback(Δ)
-            NoTangent(), NoTangent(), @.(Δ * $df1), @.(Δ * $df2)
-        end
+        ## Allowing x2::Array would allow size(Ω) != size(x1), which is not handled here:
+        $pullback(dΩ) = (NoTangent(), NoTangent(), @.(dΩ * $dfdx1), NO_ACT_GRAD)
         return Ω, $pullback
     end
 end
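Each `UNARY_ACTS` entry writes the derivative in terms of the output `Ω` wherever possible, so the generated pullback can reuse the forward result instead of recomputing the activation. For example, the `swish` entry relies on the identity swish′(x) = Ω + σ(x)(1 − Ω). A finite-difference check of that identity, with plain local definitions (`sigmoid_fast` swapped for an ordinary sigmoid in this sketch):

```julia
# Local stand-ins; NNlib's sigmoid_fast computes the same values faster.
sigmoid(x) = 1 / (1 + exp(-x))
swish(x) = x * sigmoid(x)

x = 0.9
Ω = swish(x)
dfdx = Ω + sigmoid(x) * (1 - Ω)   # the dfdx expression stored in the table
h = 1e-6
fd = (swish(x + h) - swish(x - h)) / (2h)  # central finite difference
@assert isapprox(dfdx, fd; atol=1e-5)
```

The same pattern explains the binary list: `dfdx1` is evaluated from `Ω` and `x1` inside the broadcast, while the gradient for the scalar second argument is deliberately left as `NO_ACT_GRAD`.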
