Skip to content

Commit fa3ed8c

Browse files
committed
Raise explicitly on Python methods that are incompatible with lazy variables
Notably changes the behavior of `__bool__` to always raise. Before there was a hack based on whether a variable had been compared to somethig before.
1 parent 884dee9 commit fa3ed8c

File tree

2 files changed

+62
-26
lines changed

2 files changed

+62
-26
lines changed

pytensor/scalar/basic.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,37 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType:
725725

726726

727727
class _scalar_py_operators:
728+
# These can't work because Python requires native output types
729+
def __bool__(self):
730+
raise TypeError(
731+
"ScalarVariable cannot be converted to Python boolean. "
732+
"Call `.astype(bool)` for the symbolic equivalent."
733+
)
734+
735+
def __index__(self):
736+
raise TypeError(
737+
"ScalarVariable cannot be converted to Python integer. "
738+
"Call `.astype(int)` for the symbolic equivalent."
739+
)
740+
741+
def __int__(self):
742+
raise TypeError(
743+
"ScalarVariable cannot be converted to Python integer. "
744+
"Call `.astype(int)` for the symbolic equivalent."
745+
)
746+
747+
def __float__(self):
748+
raise TypeError(
749+
"ScalarVariable cannot be converted to Python float. "
750+
"Call `.astype(float)` for the symbolic equivalent."
751+
)
752+
753+
def __complex__(self):
754+
raise TypeError(
755+
"ScalarVariable cannot be converted to Python complex number. "
756+
"Call `.astype(complex)` for the symbolic equivalent."
757+
)
758+
728759
# So that we can simplify checking code when we have a mixture of ScalarType
729760
# variables and Tensor variables
730761
ndim = 0

pytensor/tensor/variable.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,54 +26,59 @@
2626

2727

2828
class _tensor_py_operators:
29+
# These can't work because Python requires native output types
30+
def __bool__(self):
31+
raise TypeError(
32+
"TensorVariable cannot be converted to Python boolean. "
33+
"Call `.astype(bool)` for the symbolic equivalent."
34+
)
35+
36+
def __index__(self):
37+
raise TypeError(
38+
"TensorVariable cannot be converted to Python integer. "
39+
"Call `.astype(int)` for the symbolic equivalent."
40+
)
41+
42+
def __int__(self):
43+
raise TypeError(
44+
"TensorVariable cannot be converted to Python integer. "
45+
"Call `.astype(int)` for the symbolic equivalent."
46+
)
47+
48+
def __float__(self):
49+
raise TypeError(
50+
"TensorVariables cannot be converted to Python float. "
51+
"Call `.astype(float)` for the symbolic equivalent."
52+
)
53+
54+
def __complex__(self):
55+
raise TypeError(
56+
"TensorVariables cannot be converted to Python complex number. "
57+
"Call `.astype(complex)` for the symbolic equivalent."
58+
)
59+
2960
def __abs__(self):
3061
return pt.math.abs(self)
3162

3263
def __neg__(self):
3364
return pt.math.neg(self)
3465

35-
# These won't work because Python requires an int return value
36-
# def __int__(self): return convert_to_int32(self)
37-
# def __float__(self): return convert_to_float64(self)
38-
# def __complex__(self): return convert_to_complex128(self)
39-
40-
_is_nonzero = True
41-
4266
def __lt__(self, other):
4367
rval = pt.math.lt(self, other)
44-
rval._is_nonzero = False
4568
return rval
4669

4770
def __le__(self, other):
4871
rval = pt.math.le(self, other)
49-
rval._is_nonzero = False
5072
return rval
5173

5274
def __gt__(self, other):
5375
rval = pt.math.gt(self, other)
54-
rval._is_nonzero = False
5576
return rval
5677

5778
def __ge__(self, other):
5879
rval = pt.math.ge(self, other)
59-
rval._is_nonzero = False
6080
return rval
6181

62-
def __bool__(self):
63-
# This is meant to prohibit stuff like a < b < c, which is internally
64-
# implemented as (a < b) and (b < c). The trouble with this is the
65-
# side-effect that checking for a non-NULL a by typing "if a: ..."
66-
# uses the same __nonzero__ method. We want these both to work, but
67-
# it seems impossible. Currently, all vars evaluate to nonzero except
68-
# the return values of comparison operators, which raise this
69-
# exception. If you can think of a better solution, go for it!
70-
#
71-
# __bool__ is Python 3.x data model. __nonzero__ is Python 2.x.
72-
if self._is_nonzero:
73-
return True
74-
else:
75-
raise TypeError("Variables do not support boolean operations.")
76-
7782
def __invert__(self):
7883
return pt.math.invert(self)
7984

0 commit comments

Comments
 (0)