207
207
Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/underflow, and handling non-finite values.
208
208
"""
209
209
function logaddexp (x:: Real , y:: Real )
210
- # ensure Δ = 0 if x = y = Inf
210
+ # ensure Δ = 0 if x = y = ± Inf
211
211
Δ = ifelse (x == y, zero (x - y), abs (x - y))
212
212
max (x, y) + log1pexp (- Δ)
213
213
end
@@ -224,28 +224,99 @@ logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y))
224
224
"""
225
225
logsumexp(X)
226
226
227
- Compute `log(sum(exp, X))`, evaluated avoiding intermediate overflow/undeflow.
227
+ Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and
228
+ underflow.
229
+
230
+ `X` should be an iterator of real numbers. The result is computed using a single pass over
231
+ the data.
232
+
233
+ # References
234
+
235
+ [Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
236
+ """
237
+ logsumexp (X) = _logsumexp_onepass (X)
228
238
229
- `X` should be an iterator of real numbers.
230
239
"""
231
- function logsumexp (X)
240
+ logsumexp(X::AbstractArray{<:Real}; dims=:)
241
+
242
+ Compute `log.(sum(exp.(X); dims=dims))` in a numerically stable way that avoids
243
+ intermediate over- and underflow.
244
+
245
+ The result is computed using a single pass over the data.
246
+
247
+ # References
248
+
249
+ [Sebastian Nowozin: Streaming Log-sum-exp Computation.](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
250
+ """
251
+ logsumexp (X:: AbstractArray{<:Real} ; dims= :) = _logsumexp (X, dims)
252
+
253
+ _logsumexp (X:: AbstractArray{<:Real} , :: Colon ) = _logsumexp_onepass (X)
254
+ function _logsumexp (X:: AbstractArray{<:Real} , dims)
255
+ # Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82)
256
+ FT = float (eltype (X))
257
+ xmax_r = reduce (_logsumexp_onepass_op, X; dims= dims, init= (FT (- Inf ), zero (FT)))
258
+ return @. first (xmax_r) + log1p (last (xmax_r))
259
+ end
260
+
261
+ function _logsumexp_onepass (X)
262
+ # fallback for empty collections
232
263
isempty (X) && return log (sum (X))
233
- reduce (logaddexp, X )
264
+ return _logsumexp_onepass_result ( _logsumexp_onepass_reduce (X, Base . IteratorEltype (X)) )
234
265
end
235
- function logsumexp (X:: AbstractArray{T} ; dims= :) where {T<: Real }
236
- # Do not use log(zero(T)) directly to avoid issues with ForwardDiff (#82)
237
- u = reduce (max, X, dims= dims, init= oftype (log (zero (T)), - Inf ))
238
- u isa AbstractArray || isfinite (u) || return float (u)
239
- let u= u # avoid https://github.com/JuliaLang/julia/issues/15276
240
- # TODO : remove the branch when JuliaLang/julia#31020 is merged.
241
- if u isa AbstractArray
242
- u .+ log .(sum (exp .(X .- u); dims= dims))
243
- else
244
- u + log (sum (x -> exp (x- u), X))
245
- end
246
- end
266
+
267
+ # function barrier for reductions with single element and without initial element
268
+ _logsumexp_onepass_result (x) = float (x)
269
+ _logsumexp_onepass_result ((xmax, r):: Tuple ) = xmax + log1p (r)
270
+
271
+ # iterables with known element type
272
+ function _logsumexp_onepass_reduce (X, :: Base.HasEltype )
273
+ # do not perform type computations if element type is abstract
274
+ T = eltype (X)
275
+ isconcretetype (T) || return _logsumexp_onepass_reduce (X, Base. EltypeUnknown ())
276
+
277
+ FT = float (T)
278
+ return reduce (_logsumexp_onepass_op, X; init= (FT (- Inf ), zero (FT)))
279
+ end
280
+
281
+ # iterables without known element type
282
+ _logsumexp_onepass_reduce (X, :: Base.EltypeUnknown ) = reduce (_logsumexp_onepass_op, X)
283
+
284
+ # # Reductions for one-pass algorithm: avoid expensive multiplications if numbers are reduced
285
+
286
+ # reduce two numbers
287
+ function _logsumexp_onepass_op (x1, x2)
288
+ a = x1 == x2 ? zero (x1 - x2) : - abs (x1 - x2)
289
+ xmax = x1 > x2 ? oftype (a, x1) : oftype (a, x2)
290
+ r = exp (a)
291
+ return xmax, r
247
292
end
248
293
294
+ # reduce a number and a partial sum
295
+ function _logsumexp_onepass_op (x, (xmax, r):: Tuple )
296
+ a = x == xmax ? zero (x - xmax) : - abs (x - xmax)
297
+ if x > xmax
298
+ _xmax = oftype (a, x)
299
+ _r = (r + one (r)) * exp (a)
300
+ else
301
+ _xmax = oftype (a, xmax)
302
+ _r = r + exp (a)
303
+ end
304
+ return _xmax, _r
305
+ end
306
+ _logsumexp_onepass_op (xmax_r:: Tuple , x) = _logsumexp_onepass_op (x, xmax_r)
307
+
308
+ # reduce two partial sums
309
+ function _logsumexp_onepass_op ((xmax1, r1):: Tuple , (xmax2, r2):: Tuple )
310
+ a = xmax1 == xmax2 ? zero (xmax1 - xmax2) : - abs (xmax1 - xmax2)
311
+ if xmax1 > xmax2
312
+ xmax = oftype (a, xmax1)
313
+ r = r1 + (r2 + one (r2)) * exp (a)
314
+ else
315
+ xmax = oftype (a, xmax2)
316
+ r = r2 + (r1 + one (r1)) * exp (a)
317
+ end
318
+ return xmax, r
319
+ end
249
320
250
321
"""
251
322
softmax!(r::AbstractArray, x::AbstractArray)
0 commit comments