Skip to content

Commit 6278916

Browse files
Revert "remove flatten"
This reverts commit d91c203.
1 parent d91c203 commit 6278916

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

perf/vgg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function vgg16()
3838
Conv((3, 3), 512 => 512, relu, pad=(1, 1), stride=(1, 1)),
3939
BatchNorm(512),
4040
MaxPool((2,2)),
41-
Flux.flatten,
41+
flatten,
4242
Dense(512, 4096, relu),
4343
Dropout(0.5),
4444
Dense(4096, 4096, relu),

src/layers/stateless.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,31 @@
1+
"""
2+
flatten(x::AbstractArray)
3+
4+
Reshape arbitrarly-shaped input into a matrix-shaped output,
5+
preserving the size of the last dimension.
6+
7+
See also [`unsqueeze`](@ref).
8+
9+
# Examples
10+
```jldoctest
11+
julia> rand(3,4,5) |> Flux.flatten |> size
12+
(12, 5)
13+
14+
julia> xs = rand(Float32, 10,10,3,7);
15+
16+
julia> m = Chain(Conv((3,3), 3 => 4, pad=1), Flux.flatten, Dense(400 => 33));
17+
18+
julia> xs |> m[1] |> size
19+
(10, 10, 4, 7)
20+
21+
julia> xs |> m |> size
22+
(33, 7)
23+
```
24+
"""
25+
function flatten(x::AbstractArray)
26+
return reshape(x, :, size(x)[end])
27+
end
28+
129
"""
230
normalise(x; dims=ndims(x), ϵ=1e-5)
331

0 commit comments

Comments
 (0)