-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Description
Implementations of 6 different versions of the forward pass, one for each permutation of D (input dimension), T (time duration) and B (batch size):
"""
    fdtb(m, Xs, maxT)

Forward pass over a batch with the encoder output kept in the
D × T × B layout (feature dimension × time × batch).
Returns a vector of `maxT` prediction matrices.
"""
function fdtb(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
    batch_size = size(Xs,3)
    # Encode the input once; the encoding doubles as the attention values.
    Hs = m.listen(Xs)
    # Precompute the attention keys ψ(hᵤ), one key matrix per encoded time step.
    ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
    # Initial decoder state for the batch (zeros broadcast the singleton state columns).
    m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
    ŷs = map(1:maxT) do _
        # Query ϕ(sᵢ) from the current decoder state.
        ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
        # Energies eᵢᵤ = ⟨ϕ(sᵢ), ψ(hᵤ)⟩ for every batch element and time step.
        Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
        # Attention weights over time (other hcat/vcat layouts were benchmarked separately).
        αᵢs = softmax(hcat(Eᵢs...)')
        # Attended context: contextᵢ = Σᵤ αᵢᵤ hᵤ, contracting the time dimension.
        m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
        # Probability distribution over the character alphabet.
        m.state.prediction = m.infer([m.state.decoding; m.state.context])
        # Advance the decoder state.
        m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
        return m.state.prediction
    end
    reset!(m)
    return ŷs
end
"""
    fdbt(m, Xs, maxT)

Forward pass over a batch with the encoder output permuted into the
D × B × T layout (feature dimension × batch × time).
Returns a vector of `maxT` prediction matrices.
"""
function fdbt(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
    batch_size = size(Xs,3)
    # Encode the input and move time to the last dimension; the encoding
    # doubles as the attention values.
    Hs = permutedims(m.listen(Xs), [1,3,2])
    # Precompute the attention keys ψ(hᵤ), one key matrix per encoded time step.
    ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, :, axes(Hs,3)))
    # Initial decoder state for the batch (zeros broadcast the singleton state columns).
    m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
    ŷs = map(1:maxT) do _
        # Query ϕ(sᵢ) from the current decoder state.
        ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
        # Energies eᵢᵤ = ⟨ϕ(sᵢ), ψ(hᵤ)⟩ for every batch element and time step.
        Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
        # Attention weights over time (other hcat/vcat layouts were benchmarked separately).
        αᵢs = softmax(hcat(Eᵢs...); dims=2)
        # Attended context: contextᵢ = Σᵤ αᵢᵤ hᵤ, contracting the time dimension.
        m.state.context = dropdims(sum(reshape(αᵢs, 1, batch_size, :) .* Hs; dims=3); dims=3)
        # Probability distribution over the character alphabet.
        m.state.prediction = m.infer([m.state.decoding; m.state.context])
        # Advance the decoder state.
        m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
        return m.state.prediction
    end
    reset!(m)
    return ŷs
end
"""
    ftdb(m, Xs, maxT)

Forward pass over a batch with the encoder output permuted into the
T × D × B layout (time × feature dimension × batch).
Returns a vector of `maxT` prediction matrices.
"""
function ftdb(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
    batch_size = size(Xs,3)
    # Encode the input once; the encoding doubles as the attention values.
    Hs = m.listen(Xs)
    # Precompute the attention keys ψ(hᵤ) from the D × T × B encoding.
    ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
    # Reuse the SAME encoding, permuted to T × D × B, as attention values.
    # (Previously `m.listen(Xs)` was invoked a second time here, doubling the
    # encoder work and risking keys and values from inconsistent encoder states.)
    Hs = permutedims(Hs, [2,1,3])
    # Initial decoder state for the batch (zeros broadcast the singleton state columns).
    m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
    ŷs = map(1:maxT) do _
        # Query ϕ(sᵢ) from the current decoder state.
        ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
        # Energies eᵢᵤ = ⟨ϕ(sᵢ), ψ(hᵤ)⟩ for every batch element and time step.
        Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
        # Attention weights over time (other hcat/vcat layouts were benchmarked separately).
        αᵢs = softmax(hcat(Eᵢs...)')
        # Attended context: contextᵢ = Σᵤ αᵢᵤ hᵤ, contracting the time dimension.
        m.state.context = dropdims(sum(reshape(αᵢs, :,1, batch_size) .* Hs; dims=1); dims=1)
        # Probability distribution over the character alphabet.
        m.state.prediction = m.infer([m.state.decoding; m.state.context])
        # Advance the decoder state.
        m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
        return m.state.prediction
    end
    reset!(m)
    return ŷs
end
"""
    fbtd(m, Xs, maxT)

Forward pass over a batch with the encoder output permuted into the
B × T × D layout (batch × time × feature dimension).
Returns a vector of `maxT` prediction matrices.
"""
function fbtd(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
    batch_size = size(Xs,3)
    # Encode the input once; the encoding doubles as the attention values.
    Hs = m.listen(Xs)
    # Precompute the attention keys ψ(hᵤ) from the D × T × B encoding.
    ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
    # Reuse the SAME encoding, permuted to B × T × D, as attention values.
    # (Previously `m.listen(Xs)` was invoked a second time here, doubling the
    # encoder work and risking keys and values from inconsistent encoder states.)
    Hs = permutedims(Hs, [3,2,1])
    # Initial decoder state for the batch (zeros broadcast the singleton state columns).
    m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
    ŷs = map(1:maxT) do _
        # Query ϕ(sᵢ) from the current decoder state.
        ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
        # Energies eᵢᵤ = ⟨ϕ(sᵢ), ψ(hᵤ)⟩ for every batch element and time step.
        Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
        # Attention weights over time (other hcat/vcat layouts were benchmarked separately).
        αᵢs = softmax(hcat(Eᵢs...); dims=2)
        # Attended context: contextᵢ = Σᵤ αᵢᵤ hᵤ, contracting the time dimension,
        # then transposing B × D back to D × B.
        m.state.context = permutedims(dropdims(sum(αᵢs .* Hs; dims=2); dims=2))
        # Probability distribution over the character alphabet.
        m.state.prediction = m.infer([m.state.decoding; m.state.context])
        # Advance the decoder state.
        m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
        return m.state.prediction
    end
    reset!(m)
    return ŷs
end
"""
    ftbd(m, Xs, maxT)

Forward pass over a batch with the encoder output permuted into the
T × B × D layout (time × batch × feature dimension).
Returns a vector of `maxT` prediction matrices.
"""
function ftbd(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
    batch_size = size(Xs,3)
    # Encode the input once; the encoding doubles as the attention values.
    Hs = m.listen(Xs)
    # Precompute the attention keys ψ(hᵤ) from the D × T × B encoding.
    ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
    # Reuse the SAME encoding, permuted to T × B × D, as attention values.
    # (Previously `m.listen(Xs)` was invoked a second time here, doubling the
    # encoder work and risking keys and values from inconsistent encoder states.)
    Hs = permutedims(Hs, [2,3,1])
    # Initial decoder state for the batch (zeros broadcast the singleton state columns).
    m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
    ŷs = map(1:maxT) do _
        # Query ϕ(sᵢ) from the current decoder state.
        ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
        # Energies eᵢᵤ = ⟨ϕ(sᵢ), ψ(hᵤ)⟩ for every batch element and time step.
        Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
        # Attention weights over time (other hcat/vcat layouts were benchmarked separately).
        αᵢs = softmax(hcat(Eᵢs...)')
        # Attended context: contextᵢ = Σᵤ αᵢᵤ hᵤ, contracting the time dimension,
        # then transposing B × D back to D × B.
        m.state.context = permutedims(dropdims(sum(αᵢs .* Hs; dims=1); dims=1))
        # Probability distribution over the character alphabet.
        m.state.prediction = m.infer([m.state.decoding; m.state.context])
        # Advance the decoder state.
        m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
        return m.state.prediction
    end
    reset!(m)
    return ŷs
end
"""
    fbdt(m, Xs, maxT)

Forward pass over a batch with the encoder output permuted into the
B × D × T layout (batch × feature dimension × time).
Returns a vector of `maxT` prediction matrices.
"""
function fbdt(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
    batch_size = size(Xs,3)
    # Encode the input once; the encoding doubles as the attention values.
    Hs = m.listen(Xs)
    # Precompute the attention keys ψ(hᵤ) from the D × T × B encoding.
    ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
    # Reuse the SAME encoding, permuted to B × D × T, as attention values.
    # (Previously `m.listen(Xs)` was invoked a second time here, doubling the
    # encoder work and risking keys and values from inconsistent encoder states.)
    Hs = permutedims(Hs, [3,1,2])
    # Initial decoder state for the batch (zeros broadcast the singleton state columns).
    m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
    ŷs = map(1:maxT) do _
        # Query ϕ(sᵢ) from the current decoder state.
        ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
        # Energies eᵢᵤ = ⟨ϕ(sᵢ), ψ(hᵤ)⟩ for every batch element and time step.
        Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
        # Attention weights over time (other hcat/vcat layouts were benchmarked separately).
        αᵢs = softmax(hcat(Eᵢs...); dims=2)
        # Attended context: contextᵢ = Σᵤ αᵢᵤ hᵤ, contracting the time dimension,
        # then transposing B × D back to D × B.
        m.state.context = permutedims(dropdims(sum(reshape(αᵢs, batch_size, 1, :) .* Hs; dims=3); dims=3))
        # Probability distribution over the character alphabet.
        m.state.prediction = m.infer([m.state.decoding; m.state.context])
        # Advance the decoder state.
        m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
        return m.state.prediction
    end
    reset!(m)
    return ŷs
end
# Gradient of the summed D × T × B forward pass w.r.t. the parameters θ.
gfdtb(m, Xs, θ) = gradient(() -> sum(sum(fdtb(m, Xs))), θ)
# Gradient of the summed D × B × T forward pass w.r.t. the parameters θ.
gfdbt(m, Xs, θ) = gradient(() -> sum(sum(fdbt(m, Xs))), θ)
# Gradient of the summed T × D × B forward pass w.r.t. the parameters θ.
gftdb(m, Xs, θ) = gradient(() -> sum(sum(ftdb(m, Xs))), θ)
# Gradient of the summed B × T × D forward pass w.r.t. the parameters θ.
gfbtd(m, Xs, θ) = gradient(() -> sum(sum(fbtd(m, Xs))), θ)
# Gradient of the summed T × B × D forward pass w.r.t. the parameters θ.
gftbd(m, Xs, θ) = gradient(() -> sum(sum(ftbd(m, Xs))), θ)
# Gradient of the summed B × D × T forward pass w.r.t. the parameters θ.
gfbdt(m, Xs, θ) = gradient(() -> sum(sum(fbdt(m, Xs))), θ)

A smallish neural net with the following dimensions was used:
# Model hyper-parameters used for all benchmarks below.
# Encoder: a BLSTM front-end followed by a stack of pyramidal BLSTM layers.
encoder_dims = (
blstm = (in = 39, out = 64),
pblstms_out = (64, 64, 64)
)
# Dimension of the attention key/query space.
attention_dim = 128
# Output sizes of the two decoder (speller) layers.
decoder_out_dims = (128, 64)
# NOTE(review): `out_dim` is not defined in this snippet — presumably the
# alphabet size, defined elsewhere.
m = LAS(encoder_dims, attention_dim, decoder_out_dims, out_dim)
# Trainable parameters, gathered for the gradient benchmarks.
θ = Flux.params(m)
using BenchmarkTools

Results for xs = last(Xs_train); Xs = vecofmats2tensor(xs):
julia> reset!(m);
julia> @benchmark fdtb($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 2.94 GiB
allocs estimate: 306948
--------------
minimum time: 3.294 s (11.02% GC)
median time: 3.295 s (11.20% GC)
mean time: 3.295 s (11.20% GC)
maximum time: 3.296 s (11.39% GC)
--------------
samples: 2
evals/sample: 1
julia> reset!(m);
julia> @benchmark fdbt($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 2.94 GiB
allocs estimate: 318347
--------------
minimum time: 3.485 s (10.91% GC)
median time: 3.514 s (10.76% GC)
mean time: 3.514 s (10.76% GC)
maximum time: 3.543 s (10.61% GC)
--------------
samples: 2
evals/sample: 1
julia> reset!(m);
julia> @benchmark ftdb($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 3.54 GiB
allocs estimate: 367033
--------------
minimum time: 4.517 s (10.70% GC)
median time: 4.523 s (10.57% GC)
mean time: 4.523 s (10.57% GC)
maximum time: 4.529 s (10.44% GC)
--------------
samples: 2
evals/sample: 1
julia> reset!(m);
julia> @benchmark fbtd($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 3.55 GiB
allocs estimate: 386793
--------------
minimum time: 4.498 s (10.76% GC)
median time: 4.501 s (10.51% GC)
mean time: 4.501 s (10.51% GC)
maximum time: 4.504 s (10.26% GC)
--------------
samples: 2
evals/sample: 1
julia> reset!(m);
julia> @benchmark ftbd($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 3.55 GiB
allocs estimate: 366273
--------------
minimum time: 4.461 s (10.82% GC)
median time: 4.469 s (10.95% GC)
mean time: 4.469 s (10.95% GC)
maximum time: 4.477 s (11.07% GC)
--------------
samples: 2
evals/sample: 1
julia> reset!(m);
julia> @benchmark fbdt($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 3.55 GiB
allocs estimate: 387553
--------------
minimum time: 5.083 s (9.19% GC)
median time: 5.083 s (9.19% GC)
mean time: 5.083 s (9.19% GC)
maximum time: 5.083 s (9.19% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m);
julia> @benchmark gfdtb($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 17.50 GiB
allocs estimate: 2662077
--------------
minimum time: 30.478 s (64.75% GC)
median time: 30.478 s (64.75% GC)
mean time: 30.478 s (64.75% GC)
maximum time: 30.478 s (64.75% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m);
julia> @benchmark gfdbt($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 17.51 GiB
allocs estimate: 2689455
--------------
minimum time: 30.562 s (64.84% GC)
median time: 30.562 s (64.84% GC)
mean time: 30.562 s (64.84% GC)
maximum time: 30.562 s (64.84% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m);
julia> @benchmark gftdb($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 18.29 GiB
allocs estimate: 2968508
--------------
minimum time: 28.648 s (57.85% GC)
median time: 28.648 s (57.85% GC)
mean time: 28.648 s (57.85% GC)
maximum time: 28.648 s (57.85% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m);
julia> @benchmark gfbtd($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 18.32 GiB
allocs estimate: 3001183
--------------
minimum time: 28.857 s (57.49% GC)
median time: 28.857 s (57.49% GC)
mean time: 28.857 s (57.49% GC)
maximum time: 28.857 s (57.49% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m);
julia> @benchmark gftbd($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 18.32 GiB
allocs estimate: 2964703
--------------
minimum time: 28.671 s (57.99% GC)
median time: 28.671 s (57.99% GC)
mean time: 28.671 s (57.99% GC)
maximum time: 28.671 s (57.99% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m);
julia> @benchmark gfbdt($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 18.32 GiB
allocs estimate: 3008029
--------------
minimum time: 28.963 s (57.72% GC)
median time: 28.963 s (57.72% GC)
mean time: 28.963 s (57.72% GC)
maximum time: 28.963 s (57.72% GC)
--------------
samples: 1
evals/sample: 1

Results for xs = first(Xs_train); Xs = vecofmats2tensor(xs):
julia> reset!(m);
julia> @benchmark fdtb($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 1.72 GiB
allocs estimate: 54890
--------------
minimum time: 2.456 s (7.91% GC)
median time: 2.509 s (9.16% GC)
mean time: 2.504 s (8.79% GC)
maximum time: 2.548 s (9.26% GC)
--------------
samples: 3
evals/sample: 1
julia> reset!(m);
julia> @benchmark fdbt($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 1.72 GiB
allocs estimate: 57409
--------------
minimum time: 2.532 s (8.33% GC)
median time: 2.604 s (8.87% GC)
mean time: 2.604 s (8.87% GC)
maximum time: 2.676 s (9.38% GC)
--------------
samples: 2
evals/sample: 1
julia> reset!(m);
julia> @benchmark ftdb($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 2.28 GiB
allocs estimate: 74143
--------------
minimum time: 3.719 s (7.99% GC)
median time: 3.754 s (8.27% GC)
mean time: 3.754 s (8.27% GC)
maximum time: 3.789 s (8.54% GC)
--------------
samples: 2
evals/sample: 1
julia> reset!(m);
julia> @benchmark fbtd($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 2.29 GiB
allocs estimate: 78511
--------------
minimum time: 3.604 s (8.44% GC)
median time: 3.623 s (8.55% GC)
mean time: 3.623 s (8.55% GC)
maximum time: 3.643 s (8.66% GC)
--------------
samples: 2
evals/sample: 1
julia> reset!(m);
julia> @benchmark ftbd($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 2.29 GiB
allocs estimate: 73975
--------------
minimum time: 3.684 s (8.50% GC)
median time: 3.710 s (8.67% GC)
mean time: 3.710 s (8.67% GC)
maximum time: 3.736 s (8.84% GC)
--------------
samples: 2
evals/sample: 1
julia> reset!(m);
julia> @benchmark fbdt($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 2.29 GiB
allocs estimate: 78679
--------------
minimum time: 3.624 s (8.84% GC)
median time: 3.636 s (8.83% GC)
mean time: 3.636 s (8.83% GC)
maximum time: 3.647 s (8.81% GC)
--------------
samples: 2
evals/sample: 1
julia> reset!(m);
julia> @benchmark gfdtb($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 5.75 GiB
allocs estimate: 362721
--------------
minimum time: 8.457 s (34.28% GC)
median time: 8.457 s (34.28% GC)
mean time: 8.457 s (34.28% GC)
maximum time: 8.457 s (34.28% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m);
julia> @benchmark gfdbt($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 5.75 GiB
allocs estimate: 368787
--------------
minimum time: 8.571 s (33.81% GC)
median time: 8.571 s (33.81% GC)
mean time: 8.571 s (33.81% GC)
maximum time: 8.571 s (33.81% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m);
julia> @benchmark gftdb($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 6.48 GiB
allocs estimate: 440883
--------------
minimum time: 10.769 s (34.73% GC)
median time: 10.769 s (34.73% GC)
mean time: 10.769 s (34.73% GC)
maximum time: 10.769 s (34.73% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m);
julia> @benchmark gfbtd($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 6.50 GiB
allocs estimate: 448102
--------------
minimum time: 10.603 s (35.30% GC)
median time: 10.603 s (35.30% GC)
mean time: 10.603 s (35.30% GC)
maximum time: 10.603 s (35.30% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m);
julia> @benchmark gftbd($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 6.50 GiB
allocs estimate: 440038
--------------
minimum time: 10.768 s (34.75% GC)
median time: 10.768 s (34.75% GC)
mean time: 10.768 s (34.75% GC)
maximum time: 10.768 s (34.75% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m);
julia> @benchmark gfbdt($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 6.50 GiB
allocs estimate: 449619
--------------
minimum time: 10.641 s (35.38% GC)
median time: 10.641 s (35.38% GC)
mean time: 10.641 s (35.38% GC)
maximum time: 10.641 s (35.38% GC)
--------------
samples: 1
evals/sample: 1

Conclusion: the D × T × B and D × B × T orderings seem to be the most efficient, although there is not much difference between the versions, and as T grows all computations become dominated by garbage collection and the speed differences almost vanish.
Metadata
Metadata
Assignees
Labels
No labels