
Commit 0683976

Change cells' return to out, state (#2551)
1 parent 6041cf5 commit 0683976

6 files changed: +123 -136 lines changed


NEWS.md

Lines changed: 10 additions & 3 deletions
@@ -2,10 +2,17 @@
 
 See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.
 
-## v0.15.3
-* Add `WeightNorm` normalization layer.
+## v0.16.0 (15 December 2024)
+This release has a single **breaking change**:
 
-## v0.15.0 (December 2024)
+- The forward pass of the recurrent cells `RNNCell`, `LSTMCell`, and `GRUCell` has been changed to
+  $y_t, state_t = cell(x_t, state_{t-1})$. Previously, it was $state_t = cell(x_t, state_{t-1})$.
+
+Other highlights include:
+* Added `WeightNorm` normalization layer.
+* Added `Recurrence` layer, turning a recurrent layer into a layer processing the entire sequence at once.
+
+## v0.15.0 (5 December 2024)
 This release includes two **breaking changes**:
 - The recurrent layers have been thoroughly revised. See below and read the [documentation](https://fluxml.ai/Flux.jl/v0.15/guide/models/recurrence/) for details.
 - Flux now defines and exports its own gradient function. Consequently, using gradient in an unqualified manner (e.g., after `using Flux, Zygote`) could result in an ambiguity error.
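
For readers upgrading, here is a minimal sketch of the new calling convention described in the entry above, based on the docstrings changed in this commit (sizes and variable names are arbitrary placeholders):

```julia
using Flux

rnn = Flux.RNNCell(3 => 5)
x = rand(Float32, 3)                 # a single time step of input
h = Flux.initialstates(rnn)          # zero initial hidden state

# v0.16: every cell call returns (output, state)
y, h = rnn(x, h)                     # for RNNCell and GRUCell, output and state coincide

lstm = Flux.LSTMCell(3 => 5)
state = Flux.initialstates(lstm)     # for LSTMCell the state is a (h, c) tuple
y, (h, c) = lstm(x, state)           # the output is the new hidden state h
```

Under v0.15 the same calls returned only the updated state, so code that previously wrote `h = cell(x, h)` now needs to destructure the tuple.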

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.15.2"
+version = "0.16.0"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

docs/src/guide/models/recurrence.md

Lines changed: 11 additions & 12 deletions
@@ -21,7 +21,7 @@ b = zeros(Float32, output_size)
 
 function rnn_cell(x, h)
     h = tanh.(Wxh * x .+ Whh * h .+ b)
-    return h
+    return h, h
 end
 
 seq_len = 3
@@ -33,14 +33,14 @@ h0 = zeros(Float32, output_size)
 y = []
 ht = h0
 for xt in x
-    ht = rnn_cell(xt, ht)
-    y = [y; [ht]] # concatenate in non-mutating (AD friendly) way
+    yt, ht = rnn_cell(xt, ht)
+    y = [y; [yt]] # concatenate in non-mutating (AD friendly) way
 end
 ```
 
 Notice how the above is essentially a `Dense` layer that acts on two inputs, `xt` and `ht`.
-
-The output at each time step, called the hidden state, is used as the input to the next time step and is also the output of the model.
+The result of the forward pass at each time step is a tuple containing the output `yt` and the updated state `ht`. The updated state is used as an input in the next iteration. In the simple case of a vanilla RNN, the
+output and the state are the same. In more complex cells, such as `LSTMCell`, the state can contain multiple arrays.
 
 There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCell` and `GRUCell`, which are documented in the [layer reference](../../reference/models/layers.md). The hand-written example above can be replaced with:
 
@@ -58,8 +58,8 @@ rnn_cell = Flux.RNNCell(input_size => output_size)
 y = []
 ht = h0
 for xt in x
-    ht = rnn_cell(xt, ht)
-    y = [y; [ht]]
+    yt, ht = rnn_cell(xt, ht)
+    y = [y; [yt]]
 end
 ```
 The entire output `y` or just the last output `y[end]` can be used for further processing, such as classification or regression.
@@ -78,7 +78,7 @@ struct RecurrentCellModel{H,C,D}
 end
 
 # we choose to not train the initial hidden state
-Flux.@layer RecurrentCellModel trainable=(cell,dense)
+Flux.@layer RecurrentCellModel trainable=(cell, dense)
 
 function RecurrentCellModel(input_size::Int, hidden_size::Int)
     return RecurrentCellModel(
@@ -91,8 +91,8 @@ function (m::RecurrentCellModel)(x)
     z = []
     ht = m.h0
     for xt in x
-        ht = m.cell(xt, ht)
-        z = [z; [ht]]
+        yt, ht = m.cell(xt, ht)
+        z = [z; [yt]]
     end
     z = stack(z, dims=2) # [hidden_size, seq_len, batch_size] or [hidden_size, seq_len]
     ŷ = m.dense(z) # [1, seq_len, batch_size] or [1, seq_len]
@@ -109,7 +109,6 @@ using Optimisers: AdamW
 
 function loss(model, x, y)
     ŷ = model(x)
-    y = stack(y, dims=2)
     return Flux.mse(ŷ, y)
 end
 
@@ -123,7 +122,7 @@ model = RecurrentCellModel(input_size, 5)
 opt_state = Flux.setup(AdamW(1e-3), model)
 
 # compute the gradient and update the model
-g = gradient(m -> loss(m, x, y),model)[1]
+g = gradient(m -> loss(m, x, y), model)[1]
 Flux.update!(opt_state, model, g)
 ```
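
As a quick check of the revised guide text, the same unrolling loop also works when the cell's state is not a single array. A minimal sketch, with `unroll` as a hypothetical helper name (not part of Flux) and arbitrary sizes:

```julia
using Flux

input_size, output_size, seq_len = 4, 6, 3
x = [rand(Float32, input_size) for _ in 1:seq_len]
cell = Flux.LSTMCell(input_size => output_size)

# Hypothetical helper mirroring the guide's loop.
function unroll(cell, xs, state)
    y = []
    for xt in xs
        yt, state = cell(xt, state)   # output and state now come back separately
        y = [y; [yt]]                 # non-mutating, AD-friendly concatenation
    end
    return y
end

y = unroll(cell, x, Flux.initialstates(cell))   # initialstates gives the (h, c) tuple
size(y[end])                                    # (6,) — the output at the last time step
```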

src/layers/recurrent.jl

Lines changed: 36 additions & 42 deletions
@@ -1,12 +1,8 @@
-out_from_state(state) = state
-out_from_state(state::Tuple) = state[1]
-
 function scan(cell, x, state)
     y = []
     for x_t in eachslice(x, dims = 2)
-        state = cell(x_t, state)
-        out = out_from_state(state)
-        y = vcat(y, [out])
+        yt, state = cell(x_t, state)
+        y = vcat(y, [yt])
     end
     return stack(y, dims = 2)
 end
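
The `scan` helper above drives the sequence-level layers. A standalone sketch of the same iterate-and-stack pattern, reusing the logic shown in the hunk (`scan_like` is a hypothetical name used here so the sketch does not depend on Flux internals, and the sizes are placeholders):

```julia
using Flux
using Flux: initialstates

function scan_like(cell, x, state)
    y = []
    for x_t in eachslice(x, dims = 2)   # one slice per time step
        yt, state = cell(x_t, state)    # new convention: (output, state)
        y = vcat(y, [yt])
    end
    return stack(y, dims = 2)
end

cell = Flux.RNNCell(3 => 5)
x = rand(Float32, 3, 7, 4)              # in x seq_len x batch_size
out = scan_like(cell, x, initialstates(cell))
size(out)                               # (5, 7, 4)
```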
@@ -85,7 +81,6 @@ In the forward pass, implements the function
 ```math
 h^\prime = \sigma(W_i x + W_h h + b)
 ```
-and returns `h'`.
 
 See [`RNN`](@ref) for a layer that processes entire sequences.
 
@@ -107,6 +102,9 @@ The arguments of the forward pass are:
 - `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
   If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
+Returns a tuple `(output, state)`, where both elements are given by the updated state `h'`,
+a tensor of size `out` or `out x batch_size`.
+
 # Examples
 
 ```julia
@@ -123,10 +121,10 @@ h = zeros(Float32, 5)
 ŷ = []
 
 for x_t in x
-    h = r(x_t, h)
-    ŷ = [ŷ..., h]  # Cannot use `push!(ŷ, h)` here since mutation
-                   # is not automatic differentiation friendly yet.
-                   # Can use `y = vcat(y, [h])` as an alternative.
+    yt, h = r(x_t, h)
+    ŷ = [ŷ..., yt] # Cannot use `push!(ŷ, h)` here since mutation
+                   # is not automatic differentiation friendly yet.
+                   # Can use `y = vcat(y, [h])` as an alternative.
 end
 
 h # The final hidden state
@@ -155,40 +153,37 @@ using Flux
 rnn = RNNCell(10 => 20)
 
 # Get the initial hidden state
-h0 = initialstates(rnn)
+state = initialstates(rnn)
 
 # Get some input data
 x = rand(Float32, 10)
 
 # Run forward
-res = rnn(x, h0)
+out, state = rnn(x, state)
 ```
 """
 initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
 
 function RNNCell(
-  (in, out)::Pair,
-  σ = tanh;
-  init_kernel = glorot_uniform,
-  init_recurrent_kernel = glorot_uniform,
-  bias = true,
-)
+    (in, out)::Pair,
+    σ = tanh;
+    init_kernel = glorot_uniform,
+    init_recurrent_kernel = glorot_uniform,
+    bias = true,
+)
     Wi = init_kernel(out, in)
     Wh = init_recurrent_kernel(out, out)
     b = create_bias(Wi, bias, size(Wi, 1))
     return RNNCell(σ, Wi, Wh, b)
 end
 
-function (rnn::RNNCell)(x::AbstractVecOrMat)
-    state = initialstates(rnn)
-    return rnn(x, state)
-end
+(rnn::RNNCell)(x::AbstractVecOrMat) = rnn(x, initialstates(rnn))
 
 function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
     _size_check(m, x, 1 => size(m.Wi, 2))
     σ = NNlib.fast_act(m.σ, x)
     h = σ.(m.Wi * x .+ m.Wh * h .+ m.bias)
-    return h
+    return h, h
 end
 
 function Base.show(io::IO, m::RNNCell)
@@ -278,10 +273,7 @@ function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
     return RNN(cell)
 end
 
-function (rnn::RNN)(x::AbstractArray)
-    state = initialstates(rnn)
-    return rnn(x, state)
-end
+(rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn))
 
 function (m::RNN)(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
@@ -315,7 +307,6 @@ o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
 h_t = o_t \odot \tanh(c_t)
 ```
 
-The `LSTMCell` returns the new hidden state `h_t` and cell state `c_t` for a single time step.
 See also [`LSTM`](@ref) for a layer that processes entire sequences.
 
 # Arguments
@@ -336,7 +327,8 @@ The arguments of the forward pass are:
   They should be vectors of size `out` or matrices of size `out x batch_size`.
   If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
 
-Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`.
+Returns a tuple `(output, state)`, where `output = h'` is the new hidden state and `state = (h', c')` is the tuple of new hidden and cell states.
+These are tensors of size `out` or `out x batch_size`.
 
 # Examples
 
@@ -350,9 +342,9 @@ julia> c = zeros(Float32, 5); # cell state
 
 julia> x = rand(Float32, 3, 4); # in x batch_size
 
-julia> h′, c′ = l(x, (h, c));
+julia> y, (h′, c′) = l(x, (h, c));
 
-julia> size(h′) # out x batch_size
+julia> size(y) # out x batch_size
 (5, 4)
 ```
 """
@@ -389,9 +381,9 @@ function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
     b = m.bias
     g = m.Wi * x .+ m.Wh * h .+ b
     input, forget, cell, output = chunk(g, 4; dims = 1)
-    c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
-    h′ = @. sigmoid_fast(output) * tanh_fast(c′)
-    return h′, c′
+    c = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
+    h = @. sigmoid_fast(output) * tanh_fast(c)
+    return h, (h, c)
 end
 
 Base.show(io::IO, m::LSTMCell) =
@@ -522,7 +514,8 @@ The arguments of the forward pass are:
 - `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
   If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
-Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
+Returns the tuple `(output, state)`, where `output = h'` and `state = h'`.
+The new hidden state `h'` is an array of size `out` or `out x batch_size`.
 
 # Examples
 
@@ -534,7 +527,7 @@ julia> h = zeros(Float32, 5); # hidden state
 
 julia> x = rand(Float32, 3, 4); # in x batch_size
 
-julia> h′ = g(x, h);
+julia> y, h = g(x, h);
 ```
 """
 struct GRUCell{I, H, V}
@@ -577,8 +570,8 @@ function (m::GRUCell)(x::AbstractVecOrMat, h)
     r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
     z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2])
     h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
-    h′ = @. (1 - z) * h̃ + z * h
-    return h′
+    h = @. (1 - z) * h̃ + z * h
+    return h, h
 end
 
 Base.show(io::IO, m::GRUCell) =
@@ -693,7 +686,8 @@ The arguments of the forward pass are:
 - `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
   If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
-Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
+Returns the tuple `(output, state)`, where `output = h'` and `state = h'`.
+The new hidden state `h'` is an array of size `out` or `out x batch_size`.
 """
 struct GRUv3Cell{I, H, V, HH}
     Wi::I
@@ -736,8 +730,8 @@ function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
     r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
     z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2])
     h̃ = tanh_fast.(gxs[3] .+ (m.Wh_h̃ * (r .* h)) .+ bs[3])
-    h′ = @. (1 - z) * h̃ + z * h
-    return h′
+    h = @. (1 - z) * h̃ + z * h
+    return h, h
 end
 
 Base.show(io::IO, m::GRUv3Cell) =
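
Consistent with the updated docstrings above, a GRU cell now also returns an `(output, state)` pair even though the two coincide. A small usage sketch mirroring the GRUCell docstring example (sizes are placeholders):

```julia
using Flux

g = Flux.GRUCell(3 => 5)
h = zeros(Float32, 5)        # hidden state
x = rand(Float32, 3, 4)      # in x batch_size

y, h = g(x, h)               # new convention: (output, state); here they are the same array
size(y)                      # (5, 4)
```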

test/ext_common/recurrent_gpu_ad.jl

Lines changed: 10 additions & 8 deletions
@@ -1,11 +1,9 @@
-out_from_state(state::Tuple) = state[1]
-out_from_state(state) = state
+cell_loss(cell, x, state) = mean(cell(x, state)[1])
 
 function recurrent_cell_loss(cell, seq, state)
     out = []
     for xt in seq
-        state = cell(xt, state)
-        yt = out_from_state(state)
+        yt, state = cell(xt, state)
         out = vcat(out, [yt])
     end
     return mean(stack(out, dims = 2))
@@ -18,7 +16,8 @@ end
     h = zeros(Float32, d_out)
     # Single Step
     @test test_gradients(r, x[1], h; test_gpu=true,
-        compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS
+        compare_finite_diff=false,
+        loss=cell_loss) broken = :rnncell_single ∈ BROKEN_TESTS
     # Multiple Steps
     @test test_gradients(r, x, h; test_gpu=true,
         compare_finite_diff=false,
@@ -51,7 +50,7 @@ end
     c = zeros(Float32, d_out)
     # Single Step
     @test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
-        loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single ∈ BROKEN_TESTS
+        loss = cell_loss) broken = :lstmcell_single ∈ BROKEN_TESTS
     # Multiple Steps
     @test test_gradients(cell, x, (h, c); test_gpu=true,
         compare_finite_diff = false,
@@ -84,7 +83,9 @@ end
     r = GRUCell(d_in => d_out)
    x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
-    @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single ∈ BROKEN_TESTS
+    @test test_gradients(r, x[1], h; test_gpu=true,
+        compare_finite_diff=false,
+        loss = cell_loss) broken = :grucell_single ∈ BROKEN_TESTS
     @test test_gradients(r, x, h; test_gpu=true,
         compare_finite_diff = false,
         loss = recurrent_cell_loss) broken = :grucell_multiple ∈ BROKEN_TESTS
@@ -116,7 +117,8 @@ end
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
     @test test_gradients(r, x[1], h; test_gpu=true,
-        compare_finite_diff=false) broken = :gruv3cell_single ∈ BROKEN_TESTS
+        compare_finite_diff=false,
+        loss=cell_loss) broken = :gruv3cell_single ∈ BROKEN_TESTS
     @test test_gradients(r, x, h; test_gpu=true,
         compare_finite_diff=false,
         loss = recurrent_cell_loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
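
For context on the helpers used in these tests, `cell_loss` reduces the first element of a cell's new `(output, state)` return, while `recurrent_cell_loss` does the same over a whole sequence. A minimal sketch of the single-step helper outside the test harness (the helper is redefined here so the snippet is self-contained; sizes are arbitrary):

```julia
using Flux, Statistics

# Same definition as in the diff: [1] selects the output from (output, state).
cell_loss(cell, x, state) = mean(cell(x, state)[1])

d_in, d_out, batch_size = 3, 5, 4
r = Flux.RNNCell(d_in => d_out)
x = randn(Float32, d_in, batch_size)
h = zeros(Float32, d_out)

g = Flux.gradient(m -> cell_loss(m, x, h), r)[1]   # gradients w.r.t. the cell's parameters
```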
