Skip to content

Commit 8185443

Browse files
committed
Allow caller control of get_context() errors.
Some code expects `Model.get_context()` to error out if the return is None. Some code expects `_DrawValuesHandler.get_context()` to return None. To harmonize this, added optional argument to control error signaling. Defaults to signaling error, because this seems to be the general case.
1 parent fe0b52f commit 8185443

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

pymc3/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ class _DrawValuesContext(Context, metaclass=InitContextMeta):
456456
def __new__(cls, *args, **kwargs):
457457
# resolves the parent instance
458458
instance = super().__new__(cls)
459-
instance._parent = cls.get_context()
459+
instance._parent = cls.get_context(error_if_none=False)
460460
return instance
461461

462462
def __init__(self):

pymc3/model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,14 +192,16 @@ def get_contexts(cls: Type[T]) -> List[T]:
192192
return cls.contexts.stack
193193

194194
@classmethod
195-
def get_context(cls: Type[T]) -> Optional[T]:
195+
def get_context(cls: Type[T], error_if_none=True) -> Optional[T]:
196196
"""Return the most recently pushed context object of type ``cls``
197197
on the stack, or ``None``."""
198198
idx = -1
199199
while True:
200200
try:
201201
candidate = cls.get_contexts()[idx]
202-
except IndexError:
202+
except IndexError as e:
203+
if error_if_none:
204+
raise e
203205
return None
204206
if isinstance(candidate, cls):
205207
return candidate
@@ -210,7 +212,7 @@ def modelcontext(model: Optional['Model']) -> Optional['Model']:
210212
none supplied.
211213
"""
212214
if model is None:
213-
found: Optional['Model'] = Model.get_context()
215+
found: Optional['Model'] = Model.get_context(error_if_none=False)
214216
if found is None:
215217
raise ValueError("No pymc3 model object on context stack.")
216218
return found

0 commit comments

Comments
 (0)