Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/aggregations/aggregation_stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,18 @@ AggregationStack(fs::AbstractAggregation...) = AggregationStack(fs)

Flux.@layer :ignore AggregationStack

function (a::AggregationStack)(x::Maybe{AbstractArray}, bags::AbstractBags, args...)
reduce(vcat, (f(x, bags, args...) for f in a.fs))
# function (a::AggregationStack)(x::Maybe{AbstractArray}, bags::AbstractBags, args...)
# reduce(vcat, (f(x, bags, args...) for f in a.fs))
# end

@generated function (a::AggregationStack{T})(x::Maybe{AbstractArray}, bags::AbstractBags, args...) where {T<:Tuple}
l = T.parameters |> length
chs = map(1:l) do i
:(a.fs[$i](x, bags, args...))
end
quote
vcat($(chs...))
end
end

Flux.@forward AggregationStack.fs Base.getindex, Base.firstindex, Base.lastindex, Base.first,
Expand Down
2 changes: 1 addition & 1 deletion src/aggregations/aggregations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ abstract type AbstractAggregation end
@inline _weightsum(ws::AbstractVector, i) = ws[i]

# more stable definitions for r_map and p_map
ChainRulesCore.rrule(::typeof(softplus), x) = softplus.(x), Δ -> (NoTangent(), Δ .* σ.(x))
ChainRulesCore.rrule(::typeof(softplus), x) = softplus.(x), Δ -> (NoTangent(), unthunk(Δ) .* σ.(x))

# our definition of type min for Maybe{...} types
_typemin(t::Type) = typemin(t)
Expand Down
4 changes: 2 additions & 2 deletions src/aggregations/segmented_lse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,13 @@ function ChainRulesCore.rrule(::typeof(segmented_lse_forw),
x::AbstractMatrix, ψ::AbstractVector, r::AbstractVector, bags::AbstractBags)
M = _lse_precomp(x, r, bags)
y = _segmented_lse_norm(x, ψ, r, bags, M)
grad = Δ -> (NoTangent(), segmented_lse_back(Δ, y, x, ψ, r, bags, M)...)
grad = Δ -> (NoTangent(), segmented_lse_back(unthunk(Δ), y, x, ψ, r, bags, M)...)
y, grad
end

function ChainRulesCore.rrule(::typeof(segmented_lse_forw),
x::Missing, ψ::AbstractVector, r::AbstractVector, bags::AbstractBags)
y = segmented_lse_forw(x, ψ, r, bags)
grad = Δ -> (NoTangent(), segmented_lse_back(Δ, x, ψ, bags)...)
grad = Δ -> (NoTangent(), segmented_lse_back(unthunk(Δ), x, ψ, bags)...)
y, grad
end
2 changes: 1 addition & 1 deletion src/aggregations/segmented_max.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,6 @@ end

function ChainRulesCore.rrule(::typeof(segmented_max_forw), args...)
y = segmented_max_forw(args...)
grad = Δ -> (NoTangent(), segmented_max_back(Δ, y, args...)...)
grad = Δ -> (NoTangent(), segmented_max_back(unthunk(Δ), y, args...)...)
y, grad
end
2 changes: 1 addition & 1 deletion src/aggregations/segmented_mean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@ end

function ChainRulesCore.rrule(::typeof(segmented_mean_forw), args...)
y = segmented_mean_forw(args...)
grad = Δ -> (NoTangent(), segmented_mean_back(Δ, y, args...)...)
grad = Δ -> (NoTangent(), segmented_mean_back(unthunk(Δ), y, args...)...)
y, grad
end
4 changes: 2 additions & 2 deletions src/aggregations/segmented_pnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,12 @@ end
function ChainRulesCore.rrule(::typeof(segmented_pnorm_forw), a::AbstractMatrix, ψ, p, bags, w)
M = _pnorm_precomp(a, bags)
y = _segmented_pnorm_norm(a, ψ, p, bags, w, M)
grad = Δ -> (NoTangent(), segmented_pnorm_back(Δ, y, a, ψ, p, bags, w, M)...)
grad = Δ -> (NoTangent(), segmented_pnorm_back(unthunk(Δ), y, a, ψ, p, bags, w, M)...)
y, grad
end

function ChainRulesCore.rrule(::typeof(segmented_pnorm_forw), a::Missing, ψ, p, bags, w)
y = segmented_pnorm_forw(a, ψ, p, bags, w)
grad = Δ -> (NoTangent(), segmented_pnorm_back(Δ, y, ψ, bags)...)
grad = Δ -> (NoTangent(), segmented_pnorm_back(unthunk(Δ), y, ψ, bags)...)
y, grad
end
3 changes: 2 additions & 1 deletion src/aggregations/segmented_sum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ function segmented_sum_forw(x::AbstractMatrix, ψ::AbstractVector, bags::Abstrac
end

function segmented_sum_back(Δ, y, x, ψ, bags, w)
Δ = unthunk(Δ)
dx = zero(x)
dψ = zero(ψ)
dw = isnothing(w) ? ZeroTangent() : zero(w)
Expand Down Expand Up @@ -96,6 +97,6 @@ end

function ChainRulesCore.rrule(::typeof(segmented_sum_forw), args...)
y = segmented_sum_forw(args...)
grad = Δ -> (NoTangent(), segmented_sum_back(Δ, y, args...)...)
grad = Δ -> (NoTangent(), segmented_sum_back(unthunk(Δ), y, args...)...)
y, grad
end
Loading