Skip to content

Commit ebd03df

Browse files
Specialize where it may be worthwhile
1 parent 8506c50 commit ebd03df

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

src/vstats.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ _klterm(x::T, y::T) where {T} = _xlogy(x, x) - _xlogy(x, y)
1616

1717
################
1818
# Means
19-
function vmean(f, A; dims=:)
19+
function vmean(f::F, A; dims=:) where {F}
2020
c = 1 / _denom(A, dims)
2121
vmapreducethen(f, +, x -> c * x, A, dims=dims)
2222
end
2323
vmean(A; dims=:) = vmean(identity, A, dims=dims)
2424

25-
function vtmean(f, A; dims=:)
25+
function vtmean(f::F, A; dims=:) where {F}
2626
c = 1 / _denom(A, dims)
2727
vtmapreducethen(f, +, x -> c * x, A, dims=dims)
2828
end
@@ -54,25 +54,29 @@ function vtharmmean(A; dims=:)
5454
vtmapreducethen(inv, +, x -> inv(c * x), A, dims=dims)
5555
end
5656

57-
# Mean on the log scaley
58-
function vmean_log(f, A; dims=:)
57+
# Mean on the log scale
58+
function vmean_log(f::F, A; dims=:) where {F}
5959
c = log(_denom(A, dims))
6060
vmapreducethen(f, +, x -> log(x) - c, A, dims=dims)
6161
end
62+
function vtmean_log(f::F, A; dims=:) where {F}
63+
c = log(_denom(A, dims))
64+
vtmapreducethen(f, +, x -> log(x) - c, A, dims=dims)
65+
end
6266

6367
################
6468
# logsumexp (the naive and unsafe version)
6569
# Naturally, faster than the overflow/underflow-safe logsumexp, but if one can tolerate it...
6670
vlse(A; dims=:) = vmapreducethen(exp, +, log, A, dims=dims)
6771
vtlse(A; dims=:) = vtmapreducethen(exp, +, log, A, dims=dims)
68-
vlse(f, A; dims=:) = vmapreducethen(x -> exp(f(x)), +, log, A, dims=dims)
69-
vtlse(f, A; dims=:) = vtmapreducethen(x -> exp(f(x)), +, log, A, dims=dims)
72+
vlse(f::F, A; dims=:) where {F} = vmapreducethen(x -> exp(f(x)), +, log, A, dims=dims)
73+
vtlse(f::F, A; dims=:) where {F} = vtmapreducethen(x -> exp(f(x)), +, log, A, dims=dims)
7074

7175
function vlse_mean(A; dims=:)
7276
c = log(_denom(A, dims))
7377
vmapreducethen(exp, +, x -> log(x) - c, A, dims=dims)
7478
end
75-
function vlse_mean(f, A; dims=:)
79+
function vlse_mean(f::F, A; dims=:) where {F}
7680
c = log(_denom(A, dims))
7781
vmapreducethen(x -> exp(f(x)), +, x -> log(x) - c, A, dims=dims)
7882
end
@@ -81,7 +85,7 @@ function vtlse_mean(A; dims=:)
8185
c = log(_denom(A, dims))
8286
vtmapreducethen(exp, +, x -> log(x) - c, A, dims=dims)
8387
end
84-
function vtlse_mean(f, A; dims=:)
88+
function vtlse_mean(f::F, A; dims=:) where {F}
8589
c = log(_denom(A, dims))
8690
vtmapreducethen(x -> exp(f(x)), +, x -> log(x) - c, A, dims=dims)
8791
end
@@ -196,9 +200,6 @@ function vrenyadivergence(p, q, α::Real; dims=:)
196200
end
197201
end
198202

199-
200-
201-
202203
################
203204
# Deviations
204205

0 commit comments

Comments
 (0)