35
35
# SOFTWARE.
36
36
37
37
38
+ from typing import cast
39
+
38
40
import pytensor .tensor as pt
39
41
40
- from pytensor .graph .basic import Node
42
+ from pytensor .graph .basic import Apply
41
43
from pytensor .graph .fg import FunctionGraph
42
44
from pytensor .graph .rewriting .basic import node_rewriter
43
45
from pytensor .tensor .elemwise import Elemwise
@@ -72,15 +74,15 @@ class MeasurableMaxDiscrete(Max):
72
74
73
75
74
76
@node_rewriter ([Max ])
75
- def find_measurable_max (fgraph : FunctionGraph , node : Node ) -> list [TensorVariable ] | None :
77
+ def find_measurable_max (fgraph : FunctionGraph , node : Apply ) -> list [TensorVariable ] | None :
76
78
rv_map_feature = getattr (fgraph , "preserve_rv_mappings" , None )
77
79
if rv_map_feature is None :
78
80
return None # pragma: no cover
79
81
80
82
if isinstance (node .op , MeasurableMax ):
81
83
return None # pragma: no cover
82
84
83
- base_var = node .inputs [0 ]
85
+ base_var = cast ( TensorVariable , node .inputs [0 ])
84
86
85
87
if base_var .owner is None :
86
88
return None
@@ -104,6 +106,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> list[TensorVariabl
104
106
return None
105
107
106
108
# distinguish measurable discrete and continuous (because logprob is different)
109
+ measurable_max : Max
107
110
if base_var .owner .op .dtype .startswith ("int" ):
108
111
measurable_max = MeasurableMaxDiscrete (list (axis ))
109
112
else :
@@ -173,7 +176,7 @@ class MeasurableDiscreteMaxNeg(Max):
173
176
174
177
175
178
@node_rewriter (tracks = [Max ])
176
- def find_measurable_max_neg (fgraph : FunctionGraph , node : Node ) -> list [TensorVariable ] | None :
179
+ def find_measurable_max_neg (fgraph : FunctionGraph , node : Apply ) -> list [TensorVariable ] | None :
177
180
rv_map_feature = getattr (fgraph , "preserve_rv_mappings" , None )
178
181
179
182
if rv_map_feature is None :
@@ -182,7 +185,7 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVar
182
185
if isinstance (node .op , MeasurableMaxNeg ):
183
186
return None # pragma: no cover
184
187
185
- base_var = node .inputs [0 ]
188
+ base_var = cast ( TensorVariable , node .inputs [0 ])
186
189
187
190
# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
188
191
if not (base_var .owner is not None and isinstance (base_var .owner .op , Elemwise )):
@@ -213,6 +216,7 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVar
213
216
return None
214
217
215
218
# distinguish measurable discrete and continuous (because logprob is different)
219
+ measurable_min : Max
216
220
if base_rv .owner .op .dtype .startswith ("int" ):
217
221
measurable_min = MeasurableDiscreteMaxNeg (list (axis ))
218
222
else :
0 commit comments