Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies:
- python>=3.10
- compilers
- numpy>=1.17.0
- scipy>=0.14
- scipy>=1.14.0
- filelock
- etuples
- logical-unification
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ keywords = [
]
dependencies = [
"setuptools>=59.0.0",
"scipy>=0.14",
"scipy>=1.14.0",
"numpy>=1.17.0,<2",
"filelock",
"etuples",
Expand Down
11 changes: 11 additions & 0 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,17 @@ def upgrade_to_float_no_complex(*types):
return upgrade_to_float(*types)


def upgrade_to_float64_no_complex(*types):
"""
Don't accept complex, otherwise call upgrade_to_float64().

"""
for type in types:
if type in complex_types:
raise TypeError("complex argument not supported")
return upgrade_to_float64(*types)


def same_out_nocomplex(type):
if type in complex_types:
raise TypeError("complex argument not supported")
Expand Down
16 changes: 9 additions & 7 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
upcast,
upgrade_to_float,
upgrade_to_float64,
upgrade_to_float64_no_complex,
upgrade_to_float_no_complex,
)
from pytensor.scalar.basic import abs as scalar_abs
Expand Down Expand Up @@ -323,7 +324,7 @@ def c_code(self, node, name, inputs, outputs, sub):
raise NotImplementedError("only floating point is implemented")


gamma = Gamma(upgrade_to_float, name="gamma")
gamma = Gamma(upgrade_to_float64, name="gamma")


class GammaLn(UnaryScalarOp):
Expand Down Expand Up @@ -460,7 +461,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floating point is implemented")


psi = Psi(upgrade_to_float, name="psi")
psi = Psi(upgrade_to_float64, name="psi")


class TriGamma(UnaryScalarOp):
Expand Down Expand Up @@ -549,7 +550,7 @@ def c_code(self, node, name, inp, out, sub):


# Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410
tri_gamma = TriGamma(upgrade_to_float_no_complex, name="tri_gamma")
tri_gamma = TriGamma(upgrade_to_float64_no_complex, name="tri_gamma")


class PolyGamma(BinaryScalarOp):
Expand Down Expand Up @@ -880,7 +881,7 @@ def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x):

def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x):
term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1)
sum_b += term
sum_b += term.astype(dtype)

log_gamma_k_plus_n_plus_1 += log1p(k_plus_n)
n += 1
Expand Down Expand Up @@ -1051,7 +1052,7 @@ def __hash__(self):
return hash(type(self))


gammau = GammaU(upgrade_to_float, name="gammau")
gammau = GammaU(upgrade_to_float64, name="gammau")


class GammaL(BinaryScalarOp):
Expand Down Expand Up @@ -1089,7 +1090,7 @@ def __hash__(self):
return hash(type(self))


gammal = GammaL(upgrade_to_float, name="gammal")
gammal = GammaL(upgrade_to_float64, name="gammal")


class Jv(BinaryScalarOp):
Expand Down Expand Up @@ -1335,7 +1336,7 @@ def c_code_cache_version(self):
return v


sigmoid = Sigmoid(upgrade_to_float, name="sigmoid")
sigmoid = Sigmoid(upgrade_to_float64, name="sigmoid")


class Softplus(UnaryScalarOp):
Expand Down Expand Up @@ -1631,6 +1632,7 @@ def _betainc_db_n_dq(f, p, q, n):
dK = log(x) - reciprocal(p) + psi(p + q) - psi(p)
else:
dK = log1p(-x) + psi(p + q) - psi(q)
dK = dK.astype(dtype)

derivative = np.array(0, dtype=dtype)
n = np.array(1, dtype="int16") # Enough for 200 max iters
Expand Down