Skip to content

Commit fe0b52f

Browse files
committed
Check return value type in context stack.
Previously, we simply assumed that the top of the context stack would be of the right type, but that wasn't guaranteed (it just happened to work). This ensures that we don't return the wrong type from Context.get_contexts().
1 parent ccbec76 commit fe0b52f

File tree

2 files changed

+23
-25
lines changed

2 files changed

+23
-25
lines changed

pymc3/distributions/distribution.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -456,16 +456,7 @@ class _DrawValuesContext(Context, metaclass=InitContextMeta):
456456
def __new__(cls, *args, **kwargs):
457457
# resolves the parent instance
458458
instance = super().__new__(cls)
459-
if cls.get_contexts():
460-
potential_parent = cls.get_contexts()[-1]
461-
# We have to make sure that the context is a _DrawValuesContext
462-
# and not a Model
463-
if isinstance(potential_parent, _DrawValuesContext):
464-
instance._parent = potential_parent
465-
else:
466-
instance._parent = None
467-
else:
468-
instance._parent = None
459+
instance._parent = cls.get_context()
469460
return instance
470461

471462
def __init__(self):

pymc3/model.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import itertools
44
import threading
55
import warnings
6-
from typing import Optional
6+
from typing import Optional, Tuple, TypeVar, Type, List
77

88
import numpy as np
99
from pandas import Series
@@ -162,6 +162,7 @@ def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
162162
node_children.update(temp_tree)
163163
return leaf_nodes, node_parents, node_children
164164

165+
T = TypeVar('T', bound='Context')
165166

166167
class Context:
167168
"""Functionality for objects that put themselves in a context using
@@ -183,28 +184,36 @@ def __exit__(self, typ, value, traceback):
183184
set_theano_conf(self._old_theano_config)
184185

185186
@classmethod
186-
def get_contexts(cls):
187+
def get_contexts(cls: Type[T]) -> List[T]:
187188
# no race-condition here, cls.contexts is a thread-local object
188189
# be sure not to override contexts in a subclass however!
189190
if not hasattr(cls.contexts, 'stack'):
190191
cls.contexts.stack = []
191192
return cls.contexts.stack
192193

193194
@classmethod
194-
def get_context(cls):
195-
"""Return the deepest context on the stack."""
196-
try:
197-
return cls.get_contexts()[-1]
198-
except IndexError:
199-
raise TypeError("No context on context stack")
200-
201-
202-
def modelcontext(model: Optional['Model']) -> 'Model':
195+
def get_context(cls: Type[T]) -> Optional[T]:
196+
"""Return the most recently pushed context object of type ``cls``
197+
on the stack, or ``None``."""
198+
idx = -1
199+
while True:
200+
try:
201+
candidate = cls.get_contexts()[idx]
202+
except IndexError:
203+
return None
204+
if isinstance(candidate, cls):
205+
return candidate
206+
idx = idx - 1
207+
208+
def modelcontext(model: Optional['Model']) -> Optional['Model']:
203209
"""return the given model or try to find it in the context if there was
204210
none supplied.
205211
"""
206212
if model is None:
207-
return Model.get_context()
213+
found: Optional['Model'] = Model.get_context()
214+
if found is None:
215+
raise ValueError("No pymc3 model object on context stack.")
216+
return found
208217
return model
209218

210219

@@ -648,10 +657,8 @@ def __new__(cls, *args, **kwargs):
648657
instance = super().__new__(cls)
649658
if kwargs.get('model') is not None:
650659
instance._parent = kwargs.get('model')
651-
elif cls.get_contexts():
652-
instance._parent = cls.get_contexts()[-1]
653660
else:
654-
instance._parent = None
661+
instance._parent = cls.get_context()
655662
theano_config = kwargs.get('theano_config', None)
656663
if theano_config is None or 'compute_test_value' not in theano_config:
657664
theano_config = {'compute_test_value': 'raise'}

0 commit comments

Comments
 (0)