3
3
import itertools
4
4
import threading
5
5
import warnings
6
- from typing import Optional
6
+ from typing import Optional , Tuple , TypeVar , Type , List
7
7
8
8
import numpy as np
9
9
from pandas import Series
@@ -162,6 +162,7 @@ def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
162
162
node_children .update (temp_tree )
163
163
return leaf_nodes , node_parents , node_children
164
164
165
+ T = TypeVar ('T' , bound = 'Context' )
165
166
166
167
class Context :
167
168
"""Functionality for objects that put themselves in a context using
@@ -183,28 +184,36 @@ def __exit__(self, typ, value, traceback):
183
184
set_theano_conf (self ._old_theano_config )
184
185
185
186
@classmethod
186
- def get_contexts (cls ) :
187
+ def get_contexts (cls : Type [ T ]) -> List [ T ] :
187
188
# no race-condition here, cls.contexts is a thread-local object
188
189
# be sure not to override contexts in a subclass however!
189
190
if not hasattr (cls .contexts , 'stack' ):
190
191
cls .contexts .stack = []
191
192
return cls .contexts .stack
192
193
193
194
@classmethod
194
- def get_context (cls ):
195
- """Return the deepest context on the stack."""
196
- try :
197
- return cls .get_contexts ()[- 1 ]
198
- except IndexError :
199
- raise TypeError ("No context on context stack" )
200
-
201
-
202
- def modelcontext (model : Optional ['Model' ]) -> 'Model' :
195
+ def get_context (cls : Type [T ]) -> Optional [T ]:
196
+ """Return the most recently pushed context object of type ``cls``
197
+ on the stack, or ``None``."""
198
+ idx = - 1
199
+ while True :
200
+ try :
201
+ candidate = cls .get_contexts ()[idx ]
202
+ except IndexError :
203
+ return None
204
+ if isinstance (candidate , cls ):
205
+ return candidate
206
+ idx = idx - 1
207
+
208
+ def modelcontext (model : Optional ['Model' ]) -> Optional ['Model' ]:
203
209
"""return the given model or try to find it in the context if there was
204
210
none supplied.
205
211
"""
206
212
if model is None :
207
- return Model .get_context ()
213
+ found : Optional ['Model' ] = Model .get_context ()
214
+ if found is None :
215
+ raise ValueError ("No pymc3 model object on context stack." )
216
+ return found
208
217
return model
209
218
210
219
@@ -648,10 +657,8 @@ def __new__(cls, *args, **kwargs):
648
657
instance = super ().__new__ (cls )
649
658
if kwargs .get ('model' ) is not None :
650
659
instance ._parent = kwargs .get ('model' )
651
- elif cls .get_contexts ():
652
- instance ._parent = cls .get_contexts ()[- 1 ]
653
660
else :
654
- instance ._parent = None
661
+ instance ._parent = cls . get_context ()
655
662
theano_config = kwargs .get ('theano_config' , None )
656
663
if theano_config is None or 'compute_test_value' not in theano_config :
657
664
theano_config = {'compute_test_value' : 'raise' }
0 commit comments