@@ -226,103 +226,6 @@ logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y))
226
226
"""
227
227
$(SIGNATURES)
228
228
229
- Compute `log(sum(exp, X))` in a numerically stable way that avoids intermediate over- and
230
- underflow.
231
-
232
- `X` should be an iterator of real numbers. The result is computed using a single pass over
233
- the data.
234
-
235
- # References
236
-
237
- [Sebastian Nowozin: Streaming Log-sum-exp Computation](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
238
- """
239
- logsumexp (X) = _logsumexp_onepass (X)
240
-
241
- """
242
- $(SIGNATURES)
243
-
244
- Compute `log.(sum(exp.(X); dims=dims))` in a numerically stable way that avoids
245
- intermediate over- and underflow.
246
-
247
- The result is computed using a single pass over the data.
248
-
249
- # References
250
-
251
- [Sebastian Nowozin: Streaming Log-sum-exp Computation](http://www.nowozin.net/sebastian/blog/streaming-log-sum-exp-computation.html)
252
- """
253
- logsumexp (X:: AbstractArray{<:Real} ; dims= :) = _logsumexp (X, dims)
254
-
255
- _logsumexp (X:: AbstractArray{<:Real} , :: Colon ) = _logsumexp_onepass (X)
256
- function _logsumexp (X:: AbstractArray{<:Real} , dims)
257
- # Do not use log(zero(eltype(X))) directly to avoid issues with ForwardDiff (#82)
258
- FT = float (eltype (X))
259
- xmax_r = reduce (_logsumexp_onepass_op, X; dims= dims, init= (FT (- Inf ), zero (FT)))
260
- return @. first (xmax_r) + log1p (last (xmax_r))
261
- end
262
-
263
- function _logsumexp_onepass (X)
264
- # fallback for empty collections
265
- isempty (X) && return log (sum (X))
266
- return _logsumexp_onepass_result (_logsumexp_onepass_reduce (X, Base. IteratorEltype (X)))
267
- end
268
-
269
- # function barrier for reductions with single element and without initial element
270
- _logsumexp_onepass_result (x) = float (x)
271
- _logsumexp_onepass_result ((xmax, r):: Tuple ) = xmax + log1p (r)
272
-
273
- # iterables with known element type
274
- function _logsumexp_onepass_reduce (X, :: Base.HasEltype )
275
- # do not perform type computations if element type is abstract
276
- T = eltype (X)
277
- isconcretetype (T) || return _logsumexp_onepass_reduce (X, Base. EltypeUnknown ())
278
-
279
- FT = float (T)
280
- return reduce (_logsumexp_onepass_op, X; init= (FT (- Inf ), zero (FT)))
281
- end
282
-
283
- # iterables without known element type
284
- _logsumexp_onepass_reduce (X, :: Base.EltypeUnknown ) = reduce (_logsumexp_onepass_op, X)
285
-
286
- # # Reductions for one-pass algorithm: avoid expensive multiplications if numbers are reduced
287
-
288
- # reduce two numbers
289
- function _logsumexp_onepass_op (x1, x2)
290
- a = x1 == x2 ? zero (x1 - x2) : - abs (x1 - x2)
291
- xmax = x1 > x2 ? oftype (a, x1) : oftype (a, x2)
292
- r = exp (a)
293
- return xmax, r
294
- end
295
-
296
- # reduce a number and a partial sum
297
- function _logsumexp_onepass_op (x, (xmax, r):: Tuple )
298
- a = x == xmax ? zero (x - xmax) : - abs (x - xmax)
299
- if x > xmax
300
- _xmax = oftype (a, x)
301
- _r = (r + one (r)) * exp (a)
302
- else
303
- _xmax = oftype (a, xmax)
304
- _r = r + exp (a)
305
- end
306
- return _xmax, _r
307
- end
308
- _logsumexp_onepass_op (xmax_r:: Tuple , x) = _logsumexp_onepass_op (x, xmax_r)
309
-
310
- # reduce two partial sums
311
- function _logsumexp_onepass_op ((xmax1, r1):: Tuple , (xmax2, r2):: Tuple )
312
- a = xmax1 == xmax2 ? zero (xmax1 - xmax2) : - abs (xmax1 - xmax2)
313
- if xmax1 > xmax2
314
- xmax = oftype (a, xmax1)
315
- r = r1 + (r2 + one (r2)) * exp (a)
316
- else
317
- xmax = oftype (a, xmax2)
318
- r = r2 + (r1 + one (r1)) * exp (a)
319
- end
320
- return xmax, r
321
- end
322
-
323
- """
324
- $(SIGNATURES)
325
-
326
229
Overwrite `r` with the `softmax` (or _normalized exponential_) transformation of `x`
327
230
328
231
That is, `r` is overwritten with `exp.(x)`, normalized to sum to 1.
0 commit comments