|
36 | 36 |
|
37 | 37 | import abc
|
38 | 38 |
|
39 |
| -from copy import copy |
40 | 39 | from functools import singledispatch
|
41 |
| -from typing import Callable, List, Sequence, Tuple |
| 40 | +from typing import Sequence, Tuple |
42 | 41 |
|
43 |
| -from pytensor.graph.basic import Apply, Variable |
44 | 42 | from pytensor.graph.op import Op
|
45 | 43 | from pytensor.graph.utils import MetaType
|
46 | 44 | from pytensor.tensor import TensorVariable
|
@@ -135,107 +133,6 @@ class MeasurableVariable(abc.ABC):
|
135 | 133 | MeasurableVariable.register(RandomVariable)
|
136 | 134 |
|
137 | 135 |
|
138 |
| -class UnmeasurableMeta(MetaType): |
139 |
| - def __new__(cls, name, bases, dict): |
140 |
| - if "id_obj" not in dict: |
141 |
| - dict["id_obj"] = None |
142 |
| - |
143 |
| - return super().__new__(cls, name, bases, dict) |
144 |
| - |
145 |
| - def __eq__(self, other): |
146 |
| - if isinstance(other, UnmeasurableMeta): |
147 |
| - return hash(self.id_obj) == hash(other.id_obj) |
148 |
| - return False |
149 |
| - |
150 |
| - def __hash__(self): |
151 |
| - return hash(self.id_obj) |
152 |
| - |
153 |
| - |
154 |
| -class UnmeasurableVariable(metaclass=UnmeasurableMeta): |
155 |
| - """ |
156 |
| - id_obj is an attribute, i.e. tuple of length two, of the unmeasurable class object. |
157 |
| - e.g. id_obj = (NormalRV, noop_measurable_outputs_fn) |
158 |
| - """ |
159 |
| - |
160 |
| - |
161 |
| -def get_measurable_outputs(op: Op, node: Apply) -> List[Variable]: |
162 |
| - """Return only the outputs that are measurable.""" |
163 |
| - if isinstance(op, MeasurableVariable): |
164 |
| - return _get_measurable_outputs(op, node) |
165 |
| - else: |
166 |
| - return [] |
167 |
| - |
168 |
| - |
169 |
| -@singledispatch |
170 |
| -def _get_measurable_outputs(op, node): |
171 |
| - return node.outputs |
172 |
| - |
173 |
| - |
174 |
| -@_get_measurable_outputs.register(RandomVariable) |
175 |
| -def _get_measurable_outputs_RandomVariable(op, node): |
176 |
| - return node.outputs[1:] |
177 |
| - |
178 |
| - |
179 |
| -def noop_measurable_outputs_fn(*args, **kwargs): |
180 |
| - return [] |
181 |
| - |
182 |
| - |
183 |
| -def assign_custom_measurable_outputs( |
184 |
| - node: Apply, |
185 |
| - measurable_outputs_fn: Callable = noop_measurable_outputs_fn, |
186 |
| - type_prefix: str = "Unmeasurable", |
187 |
| -) -> Apply: |
188 |
| - """Assign a custom ``_get_measurable_outputs`` dispatch function to a measurable variable instance. |
189 |
| -
|
190 |
| - The node is cloned and a custom `Op` that's a copy of the original node's |
191 |
| - `Op` is created. That custom `Op` replaces the old `Op` in the cloned |
192 |
| - node, and then a custom dispatch implementation is created for the clone |
193 |
| - `Op` in `_get_measurable_outputs`. |
194 |
| -
|
195 |
| - If `measurable_outputs_fn` isn't specified, a no-op is used; the result is |
196 |
| - a clone of `node` that will effectively be ignored by |
197 |
| - `factorized_joint_logprob`. |
198 |
| -
|
199 |
| - Parameters |
200 |
| - ---------- |
201 |
| - node |
202 |
| - The node to recreate with a new cloned `Op`. |
203 |
| - measurable_outputs_fn |
204 |
| - The function that will be assigned to the new cloned `Op` in the |
205 |
| - `_get_measurable_outputs` dispatcher. |
206 |
| - The default is a no-op function (i.e. no measurable outputs) |
207 |
| - type_prefix |
208 |
| - The prefix used for the new type's name. |
209 |
| - The default is ``"Unmeasurable"``, which matches the default |
210 |
| - ``"measurable_outputs_fn"``. |
211 |
| - """ |
212 |
| - |
213 |
| - new_node = node.clone() |
214 |
| - op_type = type(new_node.op) |
215 |
| - |
216 |
| - if op_type in _get_measurable_outputs.registry.keys() and isinstance(op_type, UnmeasurableMeta): |
217 |
| - if _get_measurable_outputs.registry[op_type] != measurable_outputs_fn: |
218 |
| - raise ValueError( |
219 |
| - f"The type {op_type.__name__} with hash value {hash(op_type)} " |
220 |
| - "has already been dispatched a measurable outputs function." |
221 |
| - ) |
222 |
| - return node |
223 |
| - |
224 |
| - new_op_dict = op_type.__dict__.copy() |
225 |
| - new_op_dict["id_obj"] = (new_node.op, measurable_outputs_fn) |
226 |
| - new_op_dict.setdefault("original_op_type", op_type) |
227 |
| - |
228 |
| - new_op_type = type( |
229 |
| - f"{type_prefix}{op_type.__name__}", (op_type, UnmeasurableVariable), new_op_dict |
230 |
| - ) |
231 |
| - new_node.op = copy(new_node.op) |
232 |
| - new_node.op.__class__ = new_op_type |
233 |
| - |
234 |
| - _get_measurable_outputs.register(new_op_type)(measurable_outputs_fn) |
235 |
| - |
236 |
| - return new_node |
237 |
| - |
238 |
| - |
239 | 136 | class MeasurableElemwise(Elemwise):
|
240 | 137 | """Base class for Measurable Elemwise variables"""
|
241 | 138 |
|
|
0 commit comments