
Commit 4b9e2fb

Add doctests in normalise.jl
1 parent 2a0ed9b commit 4b9e2fb

1 file changed: +131 −7 lines

1 file changed

+131
-7
lines changed

src/layers/normalise.jl

Lines changed: 131 additions & 7 deletions
@@ -55,7 +55,7 @@ ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
 """
     Dropout(p; dims=:, rng = rng_from_array())

-Dropout layer. In the forward pass, apply the [`Flux.dropout`](@ref) function on the input.
+Dropout layer. In the forward pass, applies the [`Flux.dropout`](@ref) function on the input.

 To apply dropout along certain dimension(s), specify the `dims` keyword.
 e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
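A minimal sketch of the `dims` behaviour mentioned above (not part of this commit; the shapes and `p = 0.5` are illustrative):

```julia
using Flux

# Dropping along dims = 3 zeroes whole channels of a WHCN input.
x = randn(Float32, 4, 4, 3, 2)       # width × height × channel × batch
d = Dropout(0.5; dims = 3)
Flux.trainmode!(d)                   # dropout only acts in training mode
y = d(x)

# Each channel is either zeroed entirely or kept and rescaled by 1/(1 - p),
# so the expected activation is unchanged.
for c in 1:3
    slice = y[:, :, c, :]
    @assert all(iszero, slice) || slice ≈ x[:, :, c, :] ./ (1 - 0.5)
end
```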
@@ -65,6 +65,35 @@ Specify `rng` to use a custom RNG instead of the default.
 Custom RNGs are only supported on the CPU.

 Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
+
+# Examples
+```jldoctest
+julia> m = Chain(Dense(2 => 2), Dropout(1))
+Chain(
+  Dense(2 => 2),  # 6 parameters
+  Dropout(1),
+)
+
+julia> Flux.trainmode!(m); # activating the layer without actually training it
+
+julia> m([1, 2]) # drops neurons with a probability of 1
+2-element Vector{Float32}:
+ -0.0
+ -0.0
+
+julia> m = Chain(Dense(2 => 2), Dropout(0.5))
+Chain(
+  Dense(2 => 2),  # 6 parameters
+  Dropout(0.5),
+)
+
+julia> Flux.trainmode!(m); # activating the layer without actually training it
+
+julia> m([1, 2]) # drops neurons with a probability of 0.5
+2-element Vector{Float32}:
+ -4.537827
+ -0.0
+```
 """
 mutable struct Dropout{F,D,R<:AbstractRNG}
   p::F
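The doctest needs `trainmode!` because, as the docstring says, the layer does nothing once `testmode!` is set; a small sketch of that (not from the commit):

```julia
using Flux

m = Dropout(0.5)
x = randn(Float32, 2, 4)

Flux.testmode!(m)
@assert m(x) == x    # an inactive Dropout is the identity

Flux.trainmode!(m)
y = m(x)             # now roughly half of the entries are zeroed
```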
@@ -105,6 +134,33 @@ The AlphaDropout layer ensures that mean and variance of activations
 remain the same as before.

 Does nothing to the input once [`testmode!`](@ref) is true.
+
+# Examples
+```jldoctest
+julia> x = randn(20,1);
+
+julia> m = Chain(Dense(20 => 10, selu), AlphaDropout(0.5))
+Chain(
+  Dense(20 => 10, selu),  # 210 parameters
+  AlphaDropout{Float64, Random.TaskLocalRNG}(0.5, nothing, Random.TaskLocalRNG()),
+)
+
+julia> Flux.trainmode!(m);
+
+julia> y = m(x);
+
+julia> Flux.std(x)
+1.097500619939126
+
+julia> Flux.std(y) # maintains the standard deviation of the input
+1.1504012188827453
+
+julia> Flux.mean(x) # maintains the mean of the input
+-0.3217018554158738
+
+julia> Flux.mean(y)
+-0.2526866470385106
+```
 """
 mutable struct AlphaDropout{F,R<:AbstractRNG}
   p::F
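The exact numbers above depend on the default RNG; on a larger sample the property being shown, that mean and variance stay roughly the same, is easier to see. A sketch (not part of the commit):

```julia
using Flux, Statistics

x = randn(Float32, 1000)     # standard-normal input, as after a selu layer
m = AlphaDropout(0.5)
Flux.trainmode!(m)
y = m(x)

# both statistics stay close to those of the input (≈ 0 and ≈ 1)
println("mean: ", mean(y), "  std: ", std(y))
```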
@@ -154,6 +210,27 @@ If `affine=true`, it also applies a learnable shift and rescaling
 using the [`Scale`](@ref) layer.

 See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`normalise`](@ref).
+
+# Examples
+```jldoctest
+julia> xs = rand(3, 3, 3, 2); # a batch of 2 3X3X3 images
+
+julia> m = LayerNorm(3);
+
+julia> y = m(xs);
+
+julia> Flux.std(xs[:, :, :, 1])
+0.28713812337208383
+
+julia> Flux.std(y[:, :, :, 1]) # normalises each image (or all channels in an image)
+1.018993632693022
+
+julia> Flux.std(xs[:, :, :, 2])
+0.22540260537916373
+
+julia> Flux.std(y[:, :, :, 2]) # normalises each image (or all channels in an image)
+1.018965249873791
+```
 """
 struct LayerNorm{F,D,T,N}
   λ::F
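A sketch (not from this commit) of what `LayerNorm(3)` does to this input, assuming the default affine `Scale` is still at its identity initialisation:

```julia
using Flux, Statistics

xs = rand(Float32, 3, 3, 3, 2)
y = LayerNorm(3)(xs)

# every length-3 fibre along the first (normalised) dimension now has mean ≈ 0;
# up to the identity-initialised Scale this matches Flux.normalise(xs; dims = 1)
@assert all(abs.(mean(y; dims = 1)) .< 1f-4)
```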
@@ -256,12 +333,17 @@ Use [`testmode!`](@ref) during inference.

 # Examples
 ```julia
-m = Chain(
-  Dense(28^2 => 64),
-  BatchNorm(64, relu),
-  Dense(64 => 10),
-  BatchNorm(10),
-  softmax)
+julia> xs = rand(3, 3, 3, 2); # a batch of 2 3X3X3 images
+
+julia> Flux.std(xs)
+2.6822461565718467
+
+julia> m = BatchNorm(3);
+
+julia> Flux.trainmode!(m); # activating the layer without actually training it
+
+julia> Flux.std(m(xs)) # normalises the complete batch
+1.0093209961092855
 ```
 """
 mutable struct BatchNorm{F,V,N,W}
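The `Chain` example removed by this hunk is still a valid usage pattern; a sketch of it, with an assumed batch of 16 flattened 28×28 inputs:

```julia
using Flux

m = Chain(
    Dense(28^2 => 64),
    BatchNorm(64, relu),    # normalise the 64 features, then apply relu
    Dense(64 => 10),
    BatchNorm(10),
    softmax)

y = m(rand(Float32, 28^2, 16))    # output size (10, 16)
```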
@@ -339,6 +421,27 @@ that will be used to renormalize the input in test phase.

 **Warning**: the defaults for `affine` and `track_stats` used to be `true`
 in previous Flux versions (< v0.12).
+
+# Examples
+```jldoctest
+julia> xs = rand(3, 3, 3, 2); # a batch of 2 3X3X3 images
+
+julia> m = InstanceNorm(3);
+
+julia> y = m(xs);
+
+julia> Flux.std(xs[:, :, 1, 1]) # original standard deviation of the first channel of image 1
+0.2989802650787384
+
+julia> Flux.std(y[:, :, 1, 1]) # each channel of the batch is normalised
+1.0606027381538408
+
+julia> Flux.std(xs[:, :, 2, 2]) # original standard deviation of the second channel of image 2
+0.28662705400461197
+
+julia> Flux.std(y[:, :, 2, 2]) # each channel of the batch is normalised
+1.06058729821187
+```
 """
 mutable struct InstanceNorm{F,V,N,W}
   λ::F  # activation function
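A sketch (not part of the commit) of the per-channel, per-sample behaviour the doctest illustrates:

```julia
using Flux, Statistics

xs = rand(Float32, 3, 3, 3, 2)
y = InstanceNorm(3)(xs)

# each channel of each sample is standardised separately, so every
# (channel, sample) slice ends up with mean ≈ 0
for c in 1:3, n in 1:2
    @assert abs(mean(y[:, :, c, n])) < 1f-4
end
```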
@@ -416,6 +519,27 @@ through to learnable per-channel bias `β` and scale `γ` parameters.

 If `track_stats=true`, accumulates mean and var statistics in training phase
 that will be used to renormalize the input in test phase.
+
+# Examples
+```jldoctest
+julia> xs = rand(3, 3, 4, 2); # a batch of 2 3X3X4 images
+
+julia> m = GroupNorm(4, 2);
+
+julia> y = m(xs);
+
+julia> Flux.std(xs[:, :, 1:2, 1]) # original standard deviation of the first 2 channels of image 1
+0.307588490584917
+
+julia> Flux.std(y[:, :, 1:2, 1]) # normalises channels in groups of 2 (as specified)
+1.0289339365431291
+
+julia> Flux.std(xs[:, :, 3:4, 2]) # original standard deviation of the last 2 channels of image 2
+0.3111566100804274
+
+julia> Flux.std(y[:, :, 3:4, 2]) # normalises channels in groups of 2 (as specified)
+1.0289352493058574
+```
 """
 mutable struct GroupNorm{F,V,N,W}
   G::Int  # number of groups
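A sketch (not part of the commit) of the grouping shown above: `GroupNorm(4, 2)` splits the 4 channels into 2 groups of 2 and standardises each group within each sample:

```julia
using Flux, Statistics

xs = rand(Float32, 3, 3, 4, 2)
y = GroupNorm(4, 2)(xs)

for group in (1:2, 3:4), n in 1:2
    @assert abs(mean(y[:, :, group, n])) < 1f-4   # each group is centred per sample
end
```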
