|
3 | 3 | import itertools
|
4 | 4 | import threading
|
5 | 5 | import warnings
|
6 |
| -from typing import Optional, Tuple, TypeVar, Type, List |
| 6 | +from typing import Optional, Tuple, TypeVar, Type, List, Union |
| 7 | +from sys import modules |
7 | 8 |
|
8 | 9 | import numpy as np
|
9 | 10 | from pandas import Series
|
@@ -162,59 +163,103 @@ def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
|
162 | 163 | node_children.update(temp_tree)
|
163 | 164 | return leaf_nodes, node_parents, node_children
|
164 | 165 |
|
165 |
| -T = TypeVar('T', bound='Context') |
| 166 | +T = TypeVar('T', bound='ContextMeta') |
166 | 167 |
|
167 |
| -class Context: |
| 168 | + |
| 169 | +class ContextMeta(type): |
168 | 170 | """Functionality for objects that put themselves in a context using
|
169 | 171 | the `with` statement.
|
170 | 172 | """
|
171 |
| - contexts = threading.local() |
172 |
| - |
173 |
| - def __enter__(self): |
174 |
| - type(self).get_contexts().append(self) |
175 |
| - # self._theano_config is set in Model.__new__ |
176 |
| - if hasattr(self, '_theano_config'): |
177 |
| - self._old_theano_config = set_theano_conf(self._theano_config) |
178 |
| - return self |
179 |
| - |
180 |
| - def __exit__(self, typ, value, traceback): |
181 |
| - type(self).get_contexts().pop() |
182 |
| - # self._theano_config is set in Model.__new__ |
183 |
| - if hasattr(self, '_old_theano_config'): |
184 |
| - set_theano_conf(self._old_theano_config) |
185 |
| - |
186 |
| - @classmethod |
187 |
| - def get_contexts(cls: Type[T]) -> List[T]: |
188 |
| - # no race-condition here, cls.contexts is a thread-local object |
189 |
| - # be sure not to override contexts in a subclass however! |
190 |
| - if not hasattr(cls.contexts, 'stack'): |
191 |
| - cls.contexts.stack = [] |
192 |
| - return cls.contexts.stack |
193 |
| - |
194 |
| - @classmethod |
195 |
| - def get_context(cls: Type[T], error_if_none=True) -> Optional[T]: |
| 173 | + _context_class = None # type: Union[Type, str] |
| 174 | + |
| 175 | + def __new__(cls, name, bases, dct, **kargs): |
| 176 | + # DO NOT send "**kargs" to "type.__new__". It won't catch them and |
| 177 | + # you'll get a "TypeError: type() takes 1 or 3 arguments" exception. |
| 178 | + # dct['get_context'] = classmethod(_get_context) |
| 179 | + # dct['get_contexts'] = classmethod(_get_contexts) |
| 180 | + return super().__new__(cls, name, bases, dct) |
| 181 | + |
| 182 | + def __init__(cls, name, bases, nmspc, context_class: Optional[Type]=None, **kwargs): |
| 183 | + if context_class is not None: |
| 184 | + cls._context_class = context_class |
| 185 | + super().__init__(name, bases, nmspc) |
| 186 | + cls.contexts = threading.local() |
| 187 | + def __enter__(self): |
| 188 | + self.__class__.context_class.get_contexts().append(self) |
| 189 | + # self._theano_config is set in Model.__new__ |
| 190 | + if hasattr(self, '_theano_config'): |
| 191 | + self._old_theano_config = set_theano_conf(self._theano_config) |
| 192 | + return self |
| 193 | + |
| 194 | + def __exit__(self, typ, value, traceback): |
| 195 | + self.__class__.context_class.get_contexts().pop() |
| 196 | + # self._theano_config is set in Model.__new__ |
| 197 | + if hasattr(self, '_old_theano_config'): |
| 198 | + set_theano_conf(self._old_theano_config) |
| 199 | + |
| 200 | + cls.__enter__ = __enter__ |
| 201 | + cls.__exit__ = __exit__ |
| 202 | + |
| 203 | + def get_context(cls, error_if_none=True) -> Optional[T]: |
196 | 204 | """Return the most recently pushed context object of type ``cls``
|
197 | 205 | on the stack, or ``None``."""
|
198 | 206 | idx = -1
|
199 | 207 | while True:
|
200 | 208 | try:
|
201 |
| - candidate = cls.get_contexts()[idx] |
| 209 | + candidate = cls.get_contexts()[idx] # type: Optional[T] |
202 | 210 | except IndexError as e:
|
| 211 | + # Calling code expects to get a TypeError if the entity |
| 212 | + # is unfound, and there's too much to fix. |
203 | 213 | if error_if_none:
|
204 |
| - raise e |
| 214 | + raise TypeError("No %s on context stack"%str(cls)) |
205 | 215 | return None
|
206 |
| - if isinstance(candidate, cls): |
207 |
| - return candidate |
| 216 | + return candidate |
208 | 217 | idx = idx - 1
|
209 | 218 |
|
| 219 | + def get_contexts(cls) -> List[T]: |
| 220 | + # no race-condition here, cls.contexts is a thread-local object |
| 221 | + # be sure not to override contexts in a subclass however! |
| 222 | + if not hasattr(cls.context_class, 'stack'): |
| 223 | + cls.context_class.stack = [] |
| 224 | + return cls.context_class.stack |
| 225 | + |
| 226 | + @property |
| 227 | + def context_class(cls) -> Type: |
| 228 | + def resolve_type(c: Union[Type, str]) -> Type: |
| 229 | + if isinstance(c, str): |
| 230 | + c = getattr(modules[cls.__module__], c) |
| 231 | + if isinstance(c, type): |
| 232 | + return c |
| 233 | + raise ValueError("Cannot resolve context class %s"%c) |
| 234 | + assert cls is not None |
| 235 | + if isinstance(cls._context_class, str): |
| 236 | + cls._context_class = resolve_type(cls._context_class) |
| 237 | + if not isinstance(cls._context_class, (str, type)): |
| 238 | + raise ValueError("Context class for %s, %s, is not of the right type"%\ |
| 239 | + (cls.__name__, cls._context_class)) |
| 240 | + return cls._context_class |
| 241 | + |
| 242 | + def __init_subclass__(cls, **kwargs): |
| 243 | + super().__init_subclass__(**kwargs) |
| 244 | + cls.context_class = super().context_class |
| 245 | + |
| 246 | + # Initialize object in its own context... |
| 247 | + def __call__(cls, *args, **kwargs): |
| 248 | + instance = cls.__new__(cls, *args, **kwargs) |
| 249 | + with instance: # appends context |
| 250 | + instance.__init__(*args, **kwargs) |
| 251 | + return instance |
| 252 | + |
| 253 | + |
| 254 | + |
210 | 255 | def modelcontext(model: Optional['Model']) -> Optional['Model']:
|
211 | 256 | """return the given model or try to find it in the context if there was
|
212 | 257 | none supplied.
|
213 | 258 | """
|
214 | 259 | if model is None:
|
215 | 260 | found: Optional['Model'] = Model.get_context(error_if_none=False)
|
216 | 261 | if found is None:
|
217 |
| - raise ValueError("No pymc3 model object on context stack.") |
| 262 | + raise ValueError("No model on context stack.") |
218 | 263 | return found
|
219 | 264 | return model
|
220 | 265 |
|
@@ -303,15 +348,6 @@ def logp_nojact(self):
|
303 | 348 | return logp
|
304 | 349 |
|
305 | 350 |
|
306 |
| -class InitContextMeta(type): |
307 |
| - """Metaclass that executes `__init__` of instance in it's context""" |
308 |
| - def __call__(cls, *args, **kwargs): |
309 |
| - instance = cls.__new__(cls, *args, **kwargs) |
310 |
| - with instance: # appends context |
311 |
| - instance.__init__(*args, **kwargs) |
312 |
| - return instance |
313 |
| - |
314 |
| - |
315 | 351 | def withparent(meth):
|
316 | 352 | """Helper wrapper that passes calls to parent's instance"""
|
317 | 353 | def wrapped(self, *args, **kwargs):
|
@@ -566,7 +602,7 @@ def _build_joined(self, cost, args, vmap):
|
566 | 602 | return args_joined, theano.clone(cost, replace=replace)
|
567 | 603 |
|
568 | 604 |
|
569 |
| -class Model(Context, Factor, WithMemoization, metaclass=InitContextMeta): |
| 605 | +class Model(Factor, WithMemoization, metaclass=ContextMeta, context_class='Model'): |
570 | 606 | """Encapsulates the variables and likelihood factors of a model.
|
571 | 607 |
|
572 | 608 | Model class can be used for creating class based models. To create
|
@@ -660,7 +696,7 @@ def __new__(cls, *args, **kwargs):
|
660 | 696 | if kwargs.get('model') is not None:
|
661 | 697 | instance._parent = kwargs.get('model')
|
662 | 698 | else:
|
663 |
| - instance._parent = cls.get_context() |
| 699 | + instance._parent = cls.get_context(error_if_none=False) |
664 | 700 | theano_config = kwargs.get('theano_config', None)
|
665 | 701 | if theano_config is None or 'compute_test_value' not in theano_config:
|
666 | 702 | theano_config = {'compute_test_value': 'raise'}
|
|
0 commit comments