|
6 | 6 |
|
7 | 7 | import os
|
8 | 8 | import warnings
|
| 9 | +from textwrap import dedent |
9 | 10 |
|
10 | 11 | import numpy as np
|
11 | 12 | import scipy.special
|
@@ -1134,7 +1135,8 @@ class Softplus(UnaryScalarOp):
|
1134 | 1135 | r"""
|
1135 | 1136 | Compute log(1 + exp(x)), also known as softplus or log1pexp
|
1136 | 1137 |
|
1137 |
| - This function is numerically more stable than the naive approach. |
| 1138 | + This function is numerically faster than the naive approach, and does not overflow |
| 1139 | + for large values of x. |
1138 | 1140 |
|
1139 | 1141 | For details, see
|
1140 | 1142 | https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
|
@@ -1172,52 +1174,38 @@ def grad(self, inp, grads):
|
1172 | 1174 | def c_code(self, node, name, inp, out, sub):
|
1173 | 1175 | (x,) = inp
|
1174 | 1176 | (z,) = out
|
1175 |
| - # The boundary constants were obtained by looking at the output of |
1176 |
| - # python commands like: |
1177 |
| - # import numpy, pytensor |
1178 |
| - # dt='float32' # or float64 |
1179 |
| - # for i in range(750): |
1180 |
| - # print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt)))) |
1181 |
| - # the upper boundary check prevents us from generating inf, whereas the |
1182 |
| - # the lower boundary check prevents using exp when the result will be 0 anyway. |
1183 |
| - # The intermediate constants are taken from Machler (2012). |
1184 |
| - |
1185 |
| - # We use the float32 limits for float16 for now as the |
1186 |
| - # computation will happen in float32 anyway. |
| 1177 | + # We use the same limits for all precisions, which may be suboptimal. The reference |
| 1178 | + # paper only looked at double precision |
1187 | 1179 | if node.inputs[0].type in float_types:
|
1188 | 1180 | if node.inputs[0].type == float64:
|
1189 |
| - return ( |
1190 |
| - """ |
1191 |
| - %(z)s = ( |
1192 |
| - %(x)s < -745.0 ? 0.0 : |
1193 |
| - %(x)s < -37.0 ? exp(%(x)s) : |
1194 |
| - %(x)s < 18.0 ? log1p(exp(%(x)s)) : |
1195 |
| - %(x)s < 33.3 ? %(x)s + exp(-%(x)s) : |
1196 |
| - %(x)s |
| 1181 | + return dedent( |
| 1182 | + f""" |
| 1183 | + {z} = ( |
| 1184 | + {x} < -37.0 ? exp({x}) : |
| 1185 | + {x} < 18.0 ? log1p(exp({x})) : |
| 1186 | + {x} < 33.3 ? {x} + exp(-{x}) : |
| 1187 | + {x} |
1197 | 1188 | );
|
1198 | 1189 | """
|
1199 |
| - % locals() |
1200 | 1190 | )
|
1201 | 1191 | else:
|
1202 |
| - return ( |
1203 |
| - """ |
1204 |
| - %(z)s = ( |
1205 |
| - %(x)s < -103.0f ? 0.0 : |
1206 |
| - %(x)s < -37.0f ? exp(%(x)s) : |
1207 |
| - %(x)s < 18.0f ? log1p(exp(%(x)s)) : |
1208 |
| - %(x)s < 33.3f ? %(x)s + exp(-%(x)s) : |
1209 |
| - %(x)s |
| 1192 | + return dedent( |
| 1193 | + f""" |
| 1194 | + {z} = ( |
| 1195 | + {x} < -37.0f ? exp({x}) : |
| 1196 | + {x} < 18.0f ? log1p(exp({x})) : |
| 1197 | + {x} < 33.3f ? {x} + exp(-{x}) : |
| 1198 | + {x} |
1210 | 1199 | );
|
1211 | 1200 | """
|
1212 |
| - % locals() |
1213 | 1201 | )
|
1214 | 1202 | else:
|
1215 | 1203 | raise NotImplementedError("only floatingpoint is implemented")
|
1216 | 1204 |
|
1217 | 1205 | def c_code_cache_version(self):
|
1218 | 1206 | v = super().c_code_cache_version()
|
1219 | 1207 | if v:
|
1220 |
| - return (2,) + v |
| 1208 | + return (3,) + v |
1221 | 1209 | else:
|
1222 | 1210 | return v
|
1223 | 1211 |
|
|
0 commit comments