Skip to content

Commit ed55be2

Browse files
authored
Merge pull request #3652 from rpgoldman/fix-context-stack
Fix context stack
2 parents 9c4b740 + 55e6f59 commit ed55be2

File tree

6 files changed

+185
-68
lines changed

6 files changed

+185
-68
lines changed

pymc3/data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Dict, List, Any
12
from copy import copy
23
import io
34
import os
@@ -232,7 +233,7 @@ class Minibatch(tt.TensorVariable):
232233
>>> assert x.eval().shape == (2, 20, 20, 40, 10)
233234
"""
234235

235-
RNG = collections.defaultdict(list)
236+
RNG = collections.defaultdict(list) # type: Dict[str, List[Any]]
236237

237238
@theano.configparser.change_flags(compute_test_value='raise')
238239
def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='Minibatch',

pymc3/distributions/distribution.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ..memoize import memoize
88
from ..model import (
99
Model, get_named_nodes_and_relations, FreeRV,
10-
ObservedRV, MultiObservedRV, Context, InitContextMeta
10+
ObservedRV, MultiObservedRV, ContextMeta
1111
)
1212
from ..vartypes import string_types, theano_constant
1313
from .shape_utils import (
@@ -449,23 +449,14 @@ def random(self, point=None, size=None, **kwargs):
449449
"Define a custom random method and pass it as kwarg random")
450450

451451

452-
class _DrawValuesContext(Context, metaclass=InitContextMeta):
452+
class _DrawValuesContext(metaclass=ContextMeta, context_class='_DrawValuesContext'):
453453
""" A context manager class used while drawing values with draw_values
454454
"""
455455

456456
def __new__(cls, *args, **kwargs):
457457
# resolves the parent instance
458458
instance = super().__new__(cls)
459-
if cls.get_contexts():
460-
potential_parent = cls.get_contexts()[-1]
461-
# We have to make sure that the context is a _DrawValuesContext
462-
# and not a Model
463-
if isinstance(potential_parent, _DrawValuesContext):
464-
instance._parent = potential_parent
465-
else:
466-
instance._parent = None
467-
else:
468-
instance._parent = None
459+
instance._parent = cls.get_context(error_if_none=False)
469460
return instance
470461

471462
def __init__(self):
@@ -485,7 +476,7 @@ def parent(self):
485476
return self._parent
486477

487478

488-
class _DrawValuesContextBlocker(_DrawValuesContext, metaclass=InitContextMeta):
479+
class _DrawValuesContextBlocker(_DrawValuesContext):
489480
"""
490481
Context manager that starts a new drawn variables context disregarding all
491482
parent contexts. This can be used inside a random method to ensure that

pymc3/model.py

Lines changed: 134 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import itertools
44
import threading
55
import warnings
6-
from typing import Optional
6+
from typing import Optional, TypeVar, Type, List, Union, TYPE_CHECKING, Any, cast
7+
from sys import modules
78

89
import numpy as np
910
from pandas import Series
@@ -55,10 +56,10 @@ def __call__(self, *args, **kwargs):
5556
return getattr(self.obj, self.method_name)(*args, **kwargs)
5657

5758

58-
def incorporate_methods(source, destination, methods, default=None,
59+
def incorporate_methods(source, destination, methods,
5960
wrapper=None, override=False):
6061
"""
61-
Add attributes to a destination object which points to
62+
Add attributes to a destination object which point to
6263
methods from from a source object.
6364
6465
Parameters
@@ -69,8 +70,6 @@ def incorporate_methods(source, destination, methods, default=None,
6970
The destination object for the methods.
7071
methods : list of str
7172
Names of methods to incorporate.
72-
default : object
73-
The value used if the source does not have one of the listed methods.
7473
wrapper : function
7574
An optional function to allow the source method to be
7675
wrapped. Should take the form my_wrapper(source, method_name)
@@ -162,49 +161,131 @@ def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
162161
node_children.update(temp_tree)
163162
return leaf_nodes, node_parents, node_children
164163

164+
T = TypeVar('T', bound='ContextMeta')
165165

166-
class Context:
166+
167+
class ContextMeta(type):
167168
"""Functionality for objects that put themselves in a context using
168169
the `with` statement.
169170
"""
170-
contexts = threading.local()
171-
172-
def __enter__(self):
173-
type(self).get_contexts().append(self)
174-
# self._theano_config is set in Model.__new__
175-
if hasattr(self, '_theano_config'):
176-
self._old_theano_config = set_theano_conf(self._theano_config)
177-
return self
178-
179-
def __exit__(self, typ, value, traceback):
180-
type(self).get_contexts().pop()
181-
# self._theano_config is set in Model.__new__
182-
if hasattr(self, '_old_theano_config'):
183-
set_theano_conf(self._old_theano_config)
184171

185-
@classmethod
186-
def get_contexts(cls):
187-
# no race-condition here, cls.contexts is a thread-local object
172+
def __new__(cls, name, bases, dct, **kargs): # pylint: disable=unused-argument
173+
"Add __enter__ and __exit__ methods to the class."
174+
def __enter__(self):
175+
self.__class__.context_class.get_contexts().append(self)
176+
# self._theano_config is set in Model.__new__
177+
if hasattr(self, '_theano_config'):
178+
self._old_theano_config = set_theano_conf(self._theano_config)
179+
return self
180+
181+
def __exit__(self, typ, value, traceback): # pylint: disable=unused-argument
182+
self.__class__.context_class.get_contexts().pop()
183+
# self._theano_config is set in Model.__new__
184+
if hasattr(self, '_old_theano_config'):
185+
set_theano_conf(self._old_theano_config)
186+
187+
dct[__enter__.__name__] = __enter__
188+
dct[__exit__.__name__] = __exit__
189+
190+
# We strip off keyword args, per the warning from
191+
# StackExchange:
192+
# DO NOT send "**kargs" to "type.__new__". It won't catch them and
193+
# you'll get a "TypeError: type() takes 1 or 3 arguments" exception.
194+
return super().__new__(cls, name, bases, dct)
195+
196+
# FIXME: is there a more elegant way to automatically add methods to the class that
197+
# are instance methods instead of class methods?
198+
def __init__(cls, name, bases, nmspc, context_class: Optional[Type]=None, **kwargs): # pylint: disable=unused-argument
199+
"""Add ``__enter__`` and ``__exit__`` methods to the new class automatically."""
200+
if context_class is not None:
201+
cls._context_class = context_class
202+
super().__init__(name, bases, nmspc)
203+
204+
205+
206+
def get_context(cls, error_if_none=True) -> Optional[T]:
207+
"""Return the most recently pushed context object of type ``cls``
208+
on the stack, or ``None``. If ``error_if_none`` is True (default),
209+
raise a ``TypeError`` instead of returning ``None``."""
210+
idx = -1
211+
while True:
212+
try:
213+
candidate = cls.get_contexts()[idx] # type: Optional[T]
214+
except IndexError as e:
215+
# Calling code expects to get a TypeError if the entity
216+
# is unfound, and there's too much to fix.
217+
if error_if_none:
218+
raise TypeError("No %s on context stack"%str(cls))
219+
return None
220+
return candidate
221+
idx = idx - 1
222+
223+
def get_contexts(cls) -> List[T]:
224+
"""Return a stack of context instances for the ``context_class``
225+
of ``cls``."""
226+
# This lazily creates the context class's contexts
227+
# thread-local object, as needed. This seems inelegant to me,
228+
# but since the context class is not guaranteed to exist when
229+
# the metaclass is being instantiated, I couldn't figure out a
230+
# better way. [2019/10/11:rpg]
231+
232+
# no race-condition here, contexts is a thread-local object
188233
# be sure not to override contexts in a subclass however!
189-
if not hasattr(cls.contexts, 'stack'):
190-
cls.contexts.stack = []
191-
return cls.contexts.stack
192-
193-
@classmethod
194-
def get_context(cls):
195-
"""Return the deepest context on the stack."""
196-
try:
197-
return cls.get_contexts()[-1]
198-
except IndexError:
199-
raise TypeError("No context on context stack")
234+
context_class = cls.context_class
235+
assert isinstance(context_class, type), \
236+
"Name of context class, %s was not resolvable to a class"%context_class
237+
if not hasattr(context_class, 'contexts'):
238+
context_class.contexts = threading.local()
239+
240+
contexts = context_class.contexts
241+
242+
if not hasattr(contexts, 'stack'):
243+
contexts.stack = []
244+
return contexts.stack
245+
246+
# the following complex property accessor is necessary because the
247+
# context_class may not have been created at the point it is
248+
# specified, so the context_class may be a class *name* rather
249+
# than a class.
250+
@property
251+
def context_class(cls) -> Type:
252+
def resolve_type(c: Union[Type, str]) -> Type:
253+
if isinstance(c, str):
254+
c = getattr(modules[cls.__module__], c)
255+
if isinstance(c, type):
256+
return c
257+
raise ValueError("Cannot resolve context class %s"%c)
258+
assert cls is not None
259+
if isinstance(cls._context_class, str):
260+
cls._context_class = resolve_type(cls._context_class)
261+
if not isinstance(cls._context_class, (str, type)):
262+
raise ValueError("Context class for %s, %s, is not of the right type"%\
263+
(cls.__name__, cls._context_class))
264+
return cls._context_class
265+
266+
# Inherit context class from parent
267+
def __init_subclass__(cls, **kwargs):
268+
super().__init_subclass__(**kwargs)
269+
cls.context_class = super().context_class
270+
271+
# Initialize object in its own context...
272+
# Merged from InitContextMeta in the original.
273+
def __call__(cls, *args, **kwargs):
274+
instance = cls.__new__(cls, *args, **kwargs)
275+
with instance: # appends context
276+
instance.__init__(*args, **kwargs)
277+
return instance
200278

201279

202280
def modelcontext(model: Optional['Model']) -> 'Model':
203-
"""return the given model or try to find it in the context if there was
204-
none supplied.
281+
"""
282+
Return the given model or, if none was supplied, try to find one in
283+
the context stack.
205284
"""
206285
if model is None:
207-
return Model.get_context()
286+
model = Model.get_context(error_if_none=False)
287+
if model is None:
288+
raise ValueError("No model on context stack.")
208289
return model
209290

210291

@@ -292,15 +373,6 @@ def logp_nojact(self):
292373
return logp
293374

294375

295-
class InitContextMeta(type):
296-
"""Metaclass that executes `__init__` of instance in it's context"""
297-
def __call__(cls, *args, **kwargs):
298-
instance = cls.__new__(cls, *args, **kwargs)
299-
with instance: # appends context
300-
instance.__init__(*args, **kwargs)
301-
return instance
302-
303-
304376
def withparent(meth):
305377
"""Helper wrapper that passes calls to parent's instance"""
306378
def wrapped(self, *args, **kwargs):
@@ -346,11 +418,18 @@ def __setitem__(self, key, value):
346418
' able to determine '
347419
'appropriate logic for it')
348420

349-
def __imul__(self, other):
421+
# Added this because mypy didn't like having __imul__ without __mul__
422+
# This is my best guess about what this should do. I might be happier
423+
# to kill both of these if they are not used.
424+
def __mul__ (self, other) -> 'treelist':
425+
return cast('treelist', list.__mul__(self, other))
426+
427+
def __imul__(self, other) -> 'treelist':
350428
t0 = len(self)
351429
list.__imul__(self, other)
352430
if self.parent is not None:
353431
self.parent.extend(self[t0:])
432+
return self # python spec says should return the result.
354433

355434

356435
class treedict(dict):
@@ -555,7 +634,7 @@ def _build_joined(self, cost, args, vmap):
555634
return args_joined, theano.clone(cost, replace=replace)
556635

557636

558-
class Model(Context, Factor, WithMemoization, metaclass=InitContextMeta):
637+
class Model(Factor, WithMemoization, metaclass=ContextMeta, context_class='Model'):
559638
"""Encapsulates the variables and likelihood factors of a model.
560639
561640
Model class can be used for creating class based models. To create
@@ -643,15 +722,18 @@ def __init__(self, mean=0, sigma=1, name='', model=None):
643722
CustomModel(mean=1, name='first')
644723
CustomModel(mean=2, name='second')
645724
"""
725+
726+
if TYPE_CHECKING:
727+
def __enter__(self: 'Model') -> 'Model': ...
728+
def __exit__(self: 'Model', *exc: Any) -> bool: ...
729+
646730
def __new__(cls, *args, **kwargs):
647731
# resolves the parent instance
648732
instance = super().__new__(cls)
649733
if kwargs.get('model') is not None:
650734
instance._parent = kwargs.get('model')
651-
elif cls.get_contexts():
652-
instance._parent = cls.get_contexts()[-1]
653735
else:
654-
instance._parent = None
736+
instance._parent = cls.get_context(error_if_none=False)
655737
theano_config = kwargs.get('theano_config', None)
656738
if theano_config is None or 'compute_test_value' not in theano_config:
657739
theano_config = {'compute_test_value': 'raise'}
@@ -694,7 +776,7 @@ def root(self):
694776
def isroot(self):
695777
return self.parent is None
696778

697-
@property
779+
@property # type: ignore -- mypy can't handle decorated types.
698780
@memoize(bound=True)
699781
def bijection(self):
700782
vars = inputvars(self.vars)

pymc3/tests/test_data_container.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_sample_after_set_data(self):
7777
atol=1e-1)
7878

7979
def test_creation_of_data_outside_model_context(self):
80-
with pytest.raises(TypeError) as error:
80+
with pytest.raises((IndexError, TypeError)) as error:
8181
pm.Data('data', [1.1, 2.2, 3.3])
8282
error.match('No model on context stack')
8383

pymc3/tests/test_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,15 @@ def test_setattr_properly_works(self):
5252

5353
def test_context_passes_vars_to_parent_model(self):
5454
with pm.Model() as model:
55+
assert pm.model.modelcontext(None) == model
56+
assert pm.Model.get_context() == model
5557
# a set of variables is created
56-
NewModel()
58+
nm = NewModel()
59+
assert pm.Model.get_context() == model
5760
# another set of variables are created but with prefix 'another'
5861
usermodel2 = NewModel(name='another')
62+
assert pm.Model.get_context() == model
63+
assert usermodel2._parent == model
5964
# you can enter in a context with submodel
6065
with usermodel2:
6166
usermodel2.Var('v3', pm.Normal.dist())

0 commit comments

Comments
 (0)