We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent eb3e8e6 commit accfef4Copy full SHA for accfef4
src/loss.jl
@@ -5,7 +5,7 @@ rkl(P,Q) = sum(tensor(Q) .* (log.(tensor(Q)) .- logsoftmax(tensor(P))), dims=1)
5
6
function scaledmaskedmean(l::AbstractArray{T}, c::Union{AbstractArray, Real}, m::Union{AbstractArray, Real}) where T
7
expanded_m = expand(m, ndims(l))
8
- mean(l .* expand(c, ndims(l)) .* expanded_m) / (mean(expanded_m) + T(1e-6))
+ T(mean(l .* expand(c, ndims(l)) .* expanded_m) / ((sum(expanded_m)/T(length(expanded_m))) + T(1e-6)))
9
end
10
11
scaledmaskedmean(l::AbstractArray, c::Union{AbstractArray, Real}, m::Nothing) = mean(l .* expand(c, ndims(l)))
0 commit comments