Skip to content

Commit b47a2ef

Browse files
docstring, type annotation (#170)
* Missing comma in Base.show for ConvDims
* Simplify definitions
* Add default parameters to description
* Update src/activation.jl
Co-Authored-By: Carlo Lucibello <[email protected]>
* Trim spaces
Co-authored-by: Carlo Lucibello <[email protected]>
1 parent a428988 commit b47a2ef

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

src/activation.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ relu(x::Real) = max(zero(x), x)
4545

4646

4747
"""
48-
leakyrelu(x) = max(0.01x, x)
48+
leakyrelu(x, a=0.01) = max(a*x, x)
4949
5050
Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
5151
activation function.
@@ -54,40 +54,41 @@ You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.
5454
leakyrelu(x::Real, a = oftype(x / 1, 0.01)) = max(a * x, x / one(x))
5555

5656
"""
57-
relu6(x) = min(max(0, x),6)
57+
relu6(x) = min(max(0, x), 6)
5858
5959
[Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
60-
activation function.
60+
activation function capped at 6.
61+
See [Convolutional Deep Belief Networks on CIFAR-10](http://www.cs.utoronto.ca/%7Ekriz/conv-cifar10-aug2010.pdf)
6162
"""
62-
relu6(x::Real) = min(relu(x), one(x)*oftype(x, 6))
63+
relu6(x::Real) = min(relu(x), oftype(x, 6))
6364

6465
"""
65-
rrelu(x) = max(ax, x)
66+
rrelu(x, l=1/8, u=1/3) = max(a*x, x)
6667
67-
a = randomly sampled from uniform distribution U(l,u)
68+
a = randomly sampled from uniform distribution U(l, u)
6869
6970
Randomized Leaky [Rectified Linear Unit](https://arxiv.org/pdf/1505.00853.pdf)
7071
activation function.
7172
You can also specify the bound explicitly, e.g. `rrelu(x, 0.0, 1.0)`.
7273
"""
7374
function rrelu(x::Real, l::Real = 1 / 8.0, u::Real = 1 / 3.0)
74-
a = oftype(x /1, (u - l) * rand() + l)
75+
a = oftype(x / 1, (u - l) * rand() + l)
7576
return leakyrelu(x, a)
7677
end
7778

7879
"""
79-
elu(x, α = 1) =
80+
elu(x, α=1) =
8081
x > 0 ? x : α * (exp(x) - 1)
8182
8283
Exponential Linear Unit activation function.
8384
See [Fast and Accurate Deep Network Learning by Exponential Linear Units](https://arxiv.org/abs/1511.07289).
8485
You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
8586
"""
86-
elu(x, α = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x) - one(x)))
87+
elu(x::Real, α = one(x)) = ifelse(x ≥ 0, x / one(x), α * (exp(x) - one(x)))
8788

8889

8990
"""
90-
gelu(x) = 0.5x*(1 + tanh(√(2/π)*(x + 0.044715x^3)))
91+
gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))
9192
9293
[Gaussian Error Linear Unit](https://arxiv.org/pdf/1606.08415.pdf)
9394
activation function.
@@ -125,7 +126,8 @@ function selu(x::Real)
125126
end
126127

127128
"""
128-
celu(x) = (x ≥ 0 ? x : α * (exp(x/α) - 1))
129+
celu(x, α=1) =
130+
(x ≥ 0 ? x : α * (exp(x/α) - 1))
129131
130132
Continuously Differentiable Exponential Linear Units
131133
See [Continuously Differentiable Exponential Linear Units](https://arxiv.org/pdf/1704.07483.pdf).
@@ -155,7 +157,7 @@ softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
155157
156158
Return `log(cosh(x))` which is computed in a numerically stable way.
157159
"""
158-
logcosh(x::T) where T = x + softplus(-2x) - log(convert(T, 2))
160+
logcosh(x::Real) = x + softplus(-2x) - log(oftype(x, 2))
159161

160162

161163
"""
@@ -174,7 +176,8 @@ See [Tanhshrink Activation Function](https://www.gabormelli.com/RKB/Tanhshrink_A
174176
tanhshrink(x::Real) = x - tanh(x)
175177

176178
"""
177-
softshrink = (x ≥ λ ? x-λ : (-λ ≥ x ? x+λ : 0))
179+
softshrink(x, λ=0.5) =
180+
(x ≥ λ ? x - λ : (-λ ≥ x ? x + λ : 0))
178181
179182
See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_Activation_Function)
180183
"""

src/dim_helpers/ConvDims.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,5 +131,5 @@ function Base.show(io::IO, cdims::C) where {C <: ConvDims}
131131
P = padding(cdims)
132132
D = dilation(cdims)
133133
F = flipkernel(cdims)
134-
print(io, "$(basetype(C)): $I * $K -> $O, stride: $S pad: $P, dil: $D, flip: $F")
134+
print(io, "$(basetype(C)): $I * $K -> $O, stride: $S, pad: $P, dil: $D, flip: $F")
135135
end

0 commit comments

Comments
 (0)