Skip to content

Commit 501146f

Browse files
committed
Metaclass-based solution.
Alternative approach to solving the context stack issues. The previous version was a dead end, because there was no effective way to check the appropriateness of the classes on a single context stack. So in this version, I split the context stacks into two, one for the pm.Models, and one for pm.distributions.distributions._DrawValuesContext. This works better, but involved replacing the Context *class* with the ContextMeta parameterized metaclass.
1 parent 8185443 commit 501146f

File tree

5 files changed

+97
-52
lines changed

5 files changed

+97
-52
lines changed

pymc3/distributions/distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ..memoize import memoize
88
from ..model import (
99
Model, get_named_nodes_and_relations, FreeRV,
10-
ObservedRV, MultiObservedRV, Context, InitContextMeta
10+
ObservedRV, MultiObservedRV, ContextMeta
1111
)
1212
from ..vartypes import string_types, theano_constant
1313
from .shape_utils import (
@@ -449,7 +449,7 @@ def random(self, point=None, size=None, **kwargs):
449449
"Define a custom random method and pass it as kwarg random")
450450

451451

452-
class _DrawValuesContext(Context, metaclass=InitContextMeta):
452+
class _DrawValuesContext(metaclass=ContextMeta, context_class='_DrawValuesContext'):
453453
""" A context manager class used while drawing values with draw_values
454454
"""
455455

@@ -476,7 +476,7 @@ def parent(self):
476476
return self._parent
477477

478478

479-
class _DrawValuesContextBlocker(_DrawValuesContext, metaclass=InitContextMeta):
479+
class _DrawValuesContextBlocker(_DrawValuesContext):
480480
"""
481481
Context manager that starts a new drawn variables context disregarding all
482482
parent contexts. This can be used inside a random method to ensure that

pymc3/model.py

Lines changed: 80 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import itertools
44
import threading
55
import warnings
6-
from typing import Optional, Tuple, TypeVar, Type, List
6+
from typing import Optional, Tuple, TypeVar, Type, List, Union
7+
from sys import modules
78

89
import numpy as np
910
from pandas import Series
@@ -162,59 +163,103 @@ def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
162163
node_children.update(temp_tree)
163164
return leaf_nodes, node_parents, node_children
164165

165-
T = TypeVar('T', bound='Context')
166+
T = TypeVar('T', bound='ContextMeta')
166167

167-
class Context:
168+
169+
class ContextMeta(type):
168170
"""Functionality for objects that put themselves in a context using
169171
the `with` statement.
170172
"""
171-
contexts = threading.local()
172-
173-
def __enter__(self):
174-
type(self).get_contexts().append(self)
175-
# self._theano_config is set in Model.__new__
176-
if hasattr(self, '_theano_config'):
177-
self._old_theano_config = set_theano_conf(self._theano_config)
178-
return self
179-
180-
def __exit__(self, typ, value, traceback):
181-
type(self).get_contexts().pop()
182-
# self._theano_config is set in Model.__new__
183-
if hasattr(self, '_old_theano_config'):
184-
set_theano_conf(self._old_theano_config)
185-
186-
@classmethod
187-
def get_contexts(cls: Type[T]) -> List[T]:
188-
# no race-condition here, cls.contexts is a thread-local object
189-
# be sure not to override contexts in a subclass however!
190-
if not hasattr(cls.contexts, 'stack'):
191-
cls.contexts.stack = []
192-
return cls.contexts.stack
193-
194-
@classmethod
195-
def get_context(cls: Type[T], error_if_none=True) -> Optional[T]:
173+
_context_class = None # type: Union[Type, str]
174+
175+
def __new__(cls, name, bases, dct, **kargs):
176+
# DO NOT send "**kargs" to "type.__new__". It won't catch them and
177+
# you'll get a "TypeError: type() takes 1 or 3 arguments" exception.
178+
# dct['get_context'] = classmethod(_get_context)
179+
# dct['get_contexts'] = classmethod(_get_contexts)
180+
return super().__new__(cls, name, bases, dct)
181+
182+
def __init__(cls, name, bases, nmspc, context_class: Optional[Type]=None, **kwargs):
183+
if context_class is not None:
184+
cls._context_class = context_class
185+
super().__init__(name, bases, nmspc)
186+
cls.contexts = threading.local()
187+
def __enter__(self):
188+
self.__class__.context_class.get_contexts().append(self)
189+
# self._theano_config is set in Model.__new__
190+
if hasattr(self, '_theano_config'):
191+
self._old_theano_config = set_theano_conf(self._theano_config)
192+
return self
193+
194+
def __exit__(self, typ, value, traceback):
195+
self.__class__.context_class.get_contexts().pop()
196+
# self._theano_config is set in Model.__new__
197+
if hasattr(self, '_old_theano_config'):
198+
set_theano_conf(self._old_theano_config)
199+
200+
cls.__enter__ = __enter__
201+
cls.__exit__ = __exit__
202+
203+
def get_context(cls, error_if_none=True) -> Optional[T]:
196204
"""Return the most recently pushed context object of type ``cls``
197205
on the stack, or ``None``."""
198206
idx = -1
199207
while True:
200208
try:
201-
candidate = cls.get_contexts()[idx]
209+
candidate = cls.get_contexts()[idx] # type: Optional[T]
202210
except IndexError as e:
211+
# Calling code expects to get a TypeError if the entity
212+
# is unfound, and there's too much to fix.
203213
if error_if_none:
204-
raise e
214+
raise TypeError("No %s on context stack"%str(cls))
205215
return None
206-
if isinstance(candidate, cls):
207-
return candidate
216+
return candidate
208217
idx = idx - 1
209218

219+
def get_contexts(cls) -> List[T]:
220+
# no race-condition here, cls.contexts is a thread-local object
221+
# be sure not to override contexts in a subclass however!
222+
if not hasattr(cls.context_class, 'stack'):
223+
cls.context_class.stack = []
224+
return cls.context_class.stack
225+
226+
@property
227+
def context_class(cls) -> Type:
228+
def resolve_type(c: Union[Type, str]) -> Type:
229+
if isinstance(c, str):
230+
c = getattr(modules[cls.__module__], c)
231+
if isinstance(c, type):
232+
return c
233+
raise ValueError("Cannot resolve context class %s"%c)
234+
assert cls is not None
235+
if isinstance(cls._context_class, str):
236+
cls._context_class = resolve_type(cls._context_class)
237+
if not isinstance(cls._context_class, (str, type)):
238+
raise ValueError("Context class for %s, %s, is not of the right type"%\
239+
(cls.__name__, cls._context_class))
240+
return cls._context_class
241+
242+
def __init_subclass__(cls, **kwargs):
243+
super().__init_subclass__(**kwargs)
244+
cls.context_class = super().context_class
245+
246+
# Initialize object in its own context...
247+
def __call__(cls, *args, **kwargs):
248+
instance = cls.__new__(cls, *args, **kwargs)
249+
with instance: # appends context
250+
instance.__init__(*args, **kwargs)
251+
return instance
252+
253+
254+
210255
def modelcontext(model: Optional['Model']) -> Optional['Model']:
211256
"""return the given model or try to find it in the context if there was
212257
none supplied.
213258
"""
214259
if model is None:
215260
found: Optional['Model'] = Model.get_context(error_if_none=False)
216261
if found is None:
217-
raise ValueError("No pymc3 model object on context stack.")
262+
raise ValueError("No model on context stack.")
218263
return found
219264
return model
220265

@@ -303,15 +348,6 @@ def logp_nojact(self):
303348
return logp
304349

305350

306-
class InitContextMeta(type):
307-
"""Metaclass that executes `__init__` of instance in it's context"""
308-
def __call__(cls, *args, **kwargs):
309-
instance = cls.__new__(cls, *args, **kwargs)
310-
with instance: # appends context
311-
instance.__init__(*args, **kwargs)
312-
return instance
313-
314-
315351
def withparent(meth):
316352
"""Helper wrapper that passes calls to parent's instance"""
317353
def wrapped(self, *args, **kwargs):
@@ -566,7 +602,7 @@ def _build_joined(self, cost, args, vmap):
566602
return args_joined, theano.clone(cost, replace=replace)
567603

568604

569-
class Model(Context, Factor, WithMemoization, metaclass=InitContextMeta):
605+
class Model(Factor, WithMemoization, metaclass=ContextMeta, context_class='Model'):
570606
"""Encapsulates the variables and likelihood factors of a model.
571607
572608
Model class can be used for creating class based models. To create
@@ -660,7 +696,7 @@ def __new__(cls, *args, **kwargs):
660696
if kwargs.get('model') is not None:
661697
instance._parent = kwargs.get('model')
662698
else:
663-
instance._parent = cls.get_context()
699+
instance._parent = cls.get_context(error_if_none=False)
664700
theano_config = kwargs.get('theano_config', None)
665701
if theano_config is None or 'compute_test_value' not in theano_config:
666702
theano_config = {'compute_test_value': 'raise'}

pymc3/tests/test_data_container.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_sample_after_set_data(self):
7777
atol=1e-1)
7878

7979
def test_creation_of_data_outside_model_context(self):
80-
with pytest.raises(TypeError) as error:
80+
with pytest.raises((IndexError, TypeError)) as error:
8181
pm.Data('data', [1.1, 2.2, 3.3])
8282
error.match('No model on context stack')
8383

pymc3/tests/test_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,15 @@ def test_setattr_properly_works(self):
5252

5353
def test_context_passes_vars_to_parent_model(self):
5454
with pm.Model() as model:
55+
assert pm.model.modelcontext(None) == model
56+
assert pm.Model.get_context() == model
5557
# a set of variables is created
56-
NewModel()
58+
nm = NewModel()
59+
assert pm.Model.get_context() == model
5760
# another set of variables are created but with prefix 'another'
5861
usermodel2 = NewModel(name='another')
62+
assert pm.Model.get_context() == model
63+
assert usermodel2._parent == model
5964
# you can enter in a context with submodel
6065
with usermodel2:
6166
usermodel2.Var('v3', pm.Normal.dist())

pymc3/tests/test_modelcontext.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,18 @@ def test_mixed_contexts():
6565
assert _DrawValuesContext.get_context() == dvcb
6666
assert _DrawValuesContextBlocker.get_context() == dvcb
6767
assert _DrawValuesContext.get_context() == dvc
68-
assert _DrawValuesContextBlocker.get_context() is None
68+
assert _DrawValuesContextBlocker.get_context() is dvc
6969
assert Model.get_context() == modelB
7070
assert modelcontext(None) == modelB
71-
assert _DrawValuesContext.get_context() is None
71+
assert _DrawValuesContext.get_context(error_if_none=False) is None
72+
with raises(TypeError):
73+
_DrawValuesContext.get_context()
7274
assert Model.get_context() == modelB
7375
assert modelcontext(None) == modelB
7476
assert Model.get_context() == modelA
7577
assert modelcontext(None) == modelA
76-
assert Model.get_context() is None
78+
assert Model.get_context(error_if_none=False) is None
79+
with raises(TypeError):
80+
Model.get_context(error_if_none=True)
7781
with raises((ValueError, TypeError)):
7882
modelcontext(None)

0 commit comments

Comments
 (0)