We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 920f043 commit 8adacfaCopy full SHA for 8adacfa
pymc/distributions/transforms.py
@@ -300,7 +300,7 @@ def __init__(self, zerosum_axes):
300
301
@staticmethod
302
def extend_axis(array, axis):
303
- n = (array.shape[axis] + 1).astype("floatX")
+ n = pt.cast(array.shape[axis] + 1, "floatX")
304
sum_vals = array.sum(axis, keepdims=True)
305
norm = sum_vals / (pt.sqrt(n) + n)
306
fill_val = norm - sum_vals / pt.sqrt(n)
0 commit comments