diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index a88d678392..f3e842cf6f 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -2753,8 +2753,10 @@ def logsumexp(x, axis=None, keepdims=False): tensor """ - - return log(sum(exp(x), axis=axis, keepdims=keepdims)) + result = log( + sum(exp(x - max(x, axis=axis, keepdims=True)), axis=axis, keepdims=keepdims) + ) + return result + max(x, axis=axis, keepdims=keepdims) # Predefine all batched variations of Dot