Skip to content

Commit 8968a38

Browse files
committed
Harmonize softplus implementations
1 parent 671a821 commit 8968a38

File tree

2 files changed

+31
-35
lines changed

2 files changed

+31
-35
lines changed

pytensor/link/jax/dispatch/scalar.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,18 @@ def psi(x):
125125
@jax_funcify.register(Softplus)
126126
def jax_funcify_Softplus(op, **kwargs):
127127
def softplus(x):
128-
# This expression is numerically equivalent to the PyTensor one
129-
# It just contains one "speed" optimization less than the PyTensor counterpart
130128
return jnp.where(
131-
x < -37.0, jnp.exp(x), jnp.where(x > 33.3, x, jnp.log1p(jnp.exp(x)))
129+
x < -37.0,
130+
jnp.exp(x),
131+
jnp.where(
132+
x < 18.0,
133+
jnp.log1p(jnp.exp(x)),
134+
jnp.where(
135+
x < 33.3,
136+
x + jnp.exp(-x),
137+
x,
138+
),
139+
),
132140
)
133141

134142
return softplus

pytensor/scalar/math.py

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import os
88
import warnings
9+
from textwrap import dedent
910

1011
import numpy as np
1112
import scipy.special
@@ -1134,7 +1135,8 @@ class Softplus(UnaryScalarOp):
11341135
r"""
11351136
Compute log(1 + exp(x)), also known as softplus or log1pexp
11361137
1137-
This function is numerically more stable than the naive approach.
1138+
This function is numerically faster than the naive approach, and does not overflow
1139+
for large values of x.
11381140
11391141
For details, see
11401142
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
@@ -1172,52 +1174,38 @@ def grad(self, inp, grads):
11721174
def c_code(self, node, name, inp, out, sub):
11731175
(x,) = inp
11741176
(z,) = out
1175-
# The boundary constants were obtained by looking at the output of
1176-
# python commands like:
1177-
# import numpy, pytensor
1178-
# dt='float32' # or float64
1179-
# for i in range(750):
1180-
# print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt))))
1181-
# the upper boundary check prevents us from generating inf, whereas the
1182-
# the lower boundary check prevents using exp when the result will be 0 anyway.
1183-
# The intermediate constants are taken from Machler (2012).
1184-
1185-
# We use the float32 limits for float16 for now as the
1186-
# computation will happen in float32 anyway.
1177+
# We use the same limits for all precisions, which may be suboptimal. The reference
1178+
# paper only looked at double precision
11871179
if node.inputs[0].type in float_types:
11881180
if node.inputs[0].type == float64:
1189-
return (
1190-
"""
1191-
%(z)s = (
1192-
%(x)s < -745.0 ? 0.0 :
1193-
%(x)s < -37.0 ? exp(%(x)s) :
1194-
%(x)s < 18.0 ? log1p(exp(%(x)s)) :
1195-
%(x)s < 33.3 ? %(x)s + exp(-%(x)s) :
1196-
%(x)s
1181+
return dedent(
1182+
f"""
1183+
{z} = (
1184+
{x} < -37.0 ? exp({x}) :
1185+
{x} < 18.0 ? log1p(exp({x})) :
1186+
{x} < 33.3 ? {x} + exp(-{x}) :
1187+
{x}
11971188
);
11981189
"""
1199-
% locals()
12001190
)
12011191
else:
1202-
return (
1203-
"""
1204-
%(z)s = (
1205-
%(x)s < -103.0f ? 0.0 :
1206-
%(x)s < -37.0f ? exp(%(x)s) :
1207-
%(x)s < 18.0f ? log1p(exp(%(x)s)) :
1208-
%(x)s < 33.3f ? %(x)s + exp(-%(x)s) :
1209-
%(x)s
1192+
return dedent(
1193+
f"""
1194+
{z} = (
1195+
{x} < -37.0f ? exp({x}) :
1196+
{x} < 18.0f ? log1p(exp({x})) :
1197+
{x} < 33.3f ? {x} + exp(-{x}) :
1198+
{x}
12101199
);
12111200
"""
1212-
% locals()
12131201
)
12141202
else:
12151203
raise NotImplementedError("only floatingpoint is implemented")
12161204

12171205
def c_code_cache_version(self):
12181206
v = super().c_code_cache_version()
12191207
if v:
1220-
return (2,) + v
1208+
return (3,) + v
12211209
else:
12221210
return v
12231211

0 commit comments

Comments (0)