Skip to content

Commit accfef4

Browse files
authored
Fixing type issue in scaledmaskmean
1 parent eb3e8e6 commit accfef4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/loss.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ rkl(P,Q) = sum(tensor(Q) .* (log.(tensor(Q)) .- logsoftmax(tensor(P))), dims=1)
55

66
function scaledmaskedmean(l::AbstractArray{T}, c::Union{AbstractArray, Real}, m::Union{AbstractArray, Real}) where T
77
expanded_m = expand(m, ndims(l))
8-
mean(l .* expand(c, ndims(l)) .* expanded_m) / (mean(expanded_m) + T(1e-6))
8+
T(mean(l .* expand(c, ndims(l)) .* expanded_m) / ((sum(expanded_m)/T(length(expanded_m))) + T(1e-6)))
99
end
1010

1111
scaledmaskedmean(l::AbstractArray, c::Union{AbstractArray, Real}, m::Nothing) = mean(l .* expand(c, ndims(l)))

0 commit comments

Comments
 (0)