3
3
import itertools
4
4
import threading
5
5
import warnings
6
- from typing import Optional
6
+ from typing import Optional , TypeVar , Type , List , Union , TYPE_CHECKING , Any , cast
7
+ from sys import modules
7
8
8
9
import numpy as np
9
10
from pandas import Series
@@ -55,10 +56,10 @@ def __call__(self, *args, **kwargs):
55
56
return getattr (self .obj , self .method_name )(* args , ** kwargs )
56
57
57
58
58
- def incorporate_methods (source , destination , methods , default = None ,
59
+ def incorporate_methods (source , destination , methods ,
59
60
wrapper = None , override = False ):
60
61
"""
61
- Add attributes to a destination object which points to
62
+ Add attributes to a destination object which point to
62
63
methods from from a source object.
63
64
64
65
Parameters
@@ -69,8 +70,6 @@ def incorporate_methods(source, destination, methods, default=None,
69
70
The destination object for the methods.
70
71
methods : list of str
71
72
Names of methods to incorporate.
72
- default : object
73
- The value used if the source does not have one of the listed methods.
74
73
wrapper : function
75
74
An optional function to allow the source method to be
76
75
wrapped. Should take the form my_wrapper(source, method_name)
@@ -162,49 +161,131 @@ def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
162
161
node_children .update (temp_tree )
163
162
return leaf_nodes , node_parents , node_children
164
163
164
+ T = TypeVar ('T' , bound = 'ContextMeta' )
165
165
166
- class Context :
166
+
167
+ class ContextMeta (type ):
167
168
"""Functionality for objects that put themselves in a context using
168
169
the `with` statement.
169
170
"""
170
- contexts = threading .local ()
171
-
172
- def __enter__ (self ):
173
- type (self ).get_contexts ().append (self )
174
- # self._theano_config is set in Model.__new__
175
- if hasattr (self , '_theano_config' ):
176
- self ._old_theano_config = set_theano_conf (self ._theano_config )
177
- return self
178
-
179
- def __exit__ (self , typ , value , traceback ):
180
- type (self ).get_contexts ().pop ()
181
- # self._theano_config is set in Model.__new__
182
- if hasattr (self , '_old_theano_config' ):
183
- set_theano_conf (self ._old_theano_config )
184
171
185
- @classmethod
186
- def get_contexts (cls ):
187
- # no race-condition here, cls.contexts is a thread-local object
172
+ def __new__ (cls , name , bases , dct , ** kargs ): # pylint: disable=unused-argument
173
+ "Add __enter__ and __exit__ methods to the class."
174
+ def __enter__ (self ):
175
+ self .__class__ .context_class .get_contexts ().append (self )
176
+ # self._theano_config is set in Model.__new__
177
+ if hasattr (self , '_theano_config' ):
178
+ self ._old_theano_config = set_theano_conf (self ._theano_config )
179
+ return self
180
+
181
+ def __exit__ (self , typ , value , traceback ): # pylint: disable=unused-argument
182
+ self .__class__ .context_class .get_contexts ().pop ()
183
+ # self._theano_config is set in Model.__new__
184
+ if hasattr (self , '_old_theano_config' ):
185
+ set_theano_conf (self ._old_theano_config )
186
+
187
+ dct [__enter__ .__name__ ] = __enter__
188
+ dct [__exit__ .__name__ ] = __exit__
189
+
190
+ # We strip off keyword args, per the warning from
191
+ # StackExchange:
192
+ # DO NOT send "**kargs" to "type.__new__". It won't catch them and
193
+ # you'll get a "TypeError: type() takes 1 or 3 arguments" exception.
194
+ return super ().__new__ (cls , name , bases , dct )
195
+
196
+ # FIXME: is there a more elegant way to automatically add methods to the class that
197
+ # are instance methods instead of class methods?
198
+ def __init__ (cls , name , bases , nmspc , context_class : Optional [Type ]= None , ** kwargs ): # pylint: disable=unused-argument
199
+ """Add ``__enter__`` and ``__exit__`` methods to the new class automatically."""
200
+ if context_class is not None :
201
+ cls ._context_class = context_class
202
+ super ().__init__ (name , bases , nmspc )
203
+
204
+
205
+
206
+ def get_context (cls , error_if_none = True ) -> Optional [T ]:
207
+ """Return the most recently pushed context object of type ``cls``
208
+ on the stack, or ``None``. If ``error_if_none`` is True (default),
209
+ raise a ``TypeError`` instead of returning ``None``."""
210
+ idx = - 1
211
+ while True :
212
+ try :
213
+ candidate = cls .get_contexts ()[idx ] # type: Optional[T]
214
+ except IndexError as e :
215
+ # Calling code expects to get a TypeError if the entity
216
+ # is unfound, and there's too much to fix.
217
+ if error_if_none :
218
+ raise TypeError ("No %s on context stack" % str (cls ))
219
+ return None
220
+ return candidate
221
+ idx = idx - 1
222
+
223
+ def get_contexts (cls ) -> List [T ]:
224
+ """Return a stack of context instances for the ``context_class``
225
+ of ``cls``."""
226
+ # This lazily creates the context class's contexts
227
+ # thread-local object, as needed. This seems inelegant to me,
228
+ # but since the context class is not guaranteed to exist when
229
+ # the metaclass is being instantiated, I couldn't figure out a
230
+ # better way. [2019/10/11:rpg]
231
+
232
+ # no race-condition here, contexts is a thread-local object
188
233
# be sure not to override contexts in a subclass however!
189
- if not hasattr (cls .contexts , 'stack' ):
190
- cls .contexts .stack = []
191
- return cls .contexts .stack
192
-
193
- @classmethod
194
- def get_context (cls ):
195
- """Return the deepest context on the stack."""
196
- try :
197
- return cls .get_contexts ()[- 1 ]
198
- except IndexError :
199
- raise TypeError ("No context on context stack" )
234
+ context_class = cls .context_class
235
+ assert isinstance (context_class , type ), \
236
+ "Name of context class, %s was not resolvable to a class" % context_class
237
+ if not hasattr (context_class , 'contexts' ):
238
+ context_class .contexts = threading .local ()
239
+
240
+ contexts = context_class .contexts
241
+
242
+ if not hasattr (contexts , 'stack' ):
243
+ contexts .stack = []
244
+ return contexts .stack
245
+
246
+ # the following complex property accessor is necessary because the
247
+ # context_class may not have been created at the point it is
248
+ # specified, so the context_class may be a class *name* rather
249
+ # than a class.
250
+ @property
251
+ def context_class (cls ) -> Type :
252
+ def resolve_type (c : Union [Type , str ]) -> Type :
253
+ if isinstance (c , str ):
254
+ c = getattr (modules [cls .__module__ ], c )
255
+ if isinstance (c , type ):
256
+ return c
257
+ raise ValueError ("Cannot resolve context class %s" % c )
258
+ assert cls is not None
259
+ if isinstance (cls ._context_class , str ):
260
+ cls ._context_class = resolve_type (cls ._context_class )
261
+ if not isinstance (cls ._context_class , (str , type )):
262
+ raise ValueError ("Context class for %s, %s, is not of the right type" % \
263
+ (cls .__name__ , cls ._context_class ))
264
+ return cls ._context_class
265
+
266
+ # Inherit context class from parent
267
+ def __init_subclass__ (cls , ** kwargs ):
268
+ super ().__init_subclass__ (** kwargs )
269
+ cls .context_class = super ().context_class
270
+
271
+ # Initialize object in its own context...
272
+ # Merged from InitContextMeta in the original.
273
+ def __call__ (cls , * args , ** kwargs ):
274
+ instance = cls .__new__ (cls , * args , ** kwargs )
275
+ with instance : # appends context
276
+ instance .__init__ (* args , ** kwargs )
277
+ return instance
200
278
201
279
202
280
def modelcontext (model : Optional ['Model' ]) -> 'Model' :
203
- """return the given model or try to find it in the context if there was
204
- none supplied.
281
+ """
282
+ Return the given model or, if none was supplied, try to find one in
283
+ the context stack.
205
284
"""
206
285
if model is None :
207
- return Model .get_context ()
286
+ model = Model .get_context (error_if_none = False )
287
+ if model is None :
288
+ raise ValueError ("No model on context stack." )
208
289
return model
209
290
210
291
@@ -292,15 +373,6 @@ def logp_nojact(self):
292
373
return logp
293
374
294
375
295
- class InitContextMeta (type ):
296
- """Metaclass that executes `__init__` of instance in it's context"""
297
- def __call__ (cls , * args , ** kwargs ):
298
- instance = cls .__new__ (cls , * args , ** kwargs )
299
- with instance : # appends context
300
- instance .__init__ (* args , ** kwargs )
301
- return instance
302
-
303
-
304
376
def withparent (meth ):
305
377
"""Helper wrapper that passes calls to parent's instance"""
306
378
def wrapped (self , * args , ** kwargs ):
@@ -346,11 +418,18 @@ def __setitem__(self, key, value):
346
418
' able to determine '
347
419
'appropriate logic for it' )
348
420
349
- def __imul__ (self , other ):
421
+ # Added this because mypy didn't like having __imul__ without __mul__
422
+ # This is my best guess about what this should do. I might be happier
423
+ # to kill both of these if they are not used.
424
+ def __mul__ (self , other ) -> 'treelist' :
425
+ return cast ('treelist' , list .__mul__ (self , other ))
426
+
427
+ def __imul__ (self , other ) -> 'treelist' :
350
428
t0 = len (self )
351
429
list .__imul__ (self , other )
352
430
if self .parent is not None :
353
431
self .parent .extend (self [t0 :])
432
+ return self # python spec says should return the result.
354
433
355
434
356
435
class treedict (dict ):
@@ -555,7 +634,7 @@ def _build_joined(self, cost, args, vmap):
555
634
return args_joined , theano .clone (cost , replace = replace )
556
635
557
636
558
- class Model (Context , Factor , WithMemoization , metaclass = InitContextMeta ):
637
+ class Model (Factor , WithMemoization , metaclass = ContextMeta , context_class = 'Model' ):
559
638
"""Encapsulates the variables and likelihood factors of a model.
560
639
561
640
Model class can be used for creating class based models. To create
@@ -643,15 +722,18 @@ def __init__(self, mean=0, sigma=1, name='', model=None):
643
722
CustomModel(mean=1, name='first')
644
723
CustomModel(mean=2, name='second')
645
724
"""
725
+
726
+ if TYPE_CHECKING :
727
+ def __enter__ (self : 'Model' ) -> 'Model' : ...
728
+ def __exit__ (self : 'Model' , * exc : Any ) -> bool : ...
729
+
646
730
def __new__ (cls , * args , ** kwargs ):
647
731
# resolves the parent instance
648
732
instance = super ().__new__ (cls )
649
733
if kwargs .get ('model' ) is not None :
650
734
instance ._parent = kwargs .get ('model' )
651
- elif cls .get_contexts ():
652
- instance ._parent = cls .get_contexts ()[- 1 ]
653
735
else :
654
- instance ._parent = None
736
+ instance ._parent = cls . get_context ( error_if_none = False )
655
737
theano_config = kwargs .get ('theano_config' , None )
656
738
if theano_config is None or 'compute_test_value' not in theano_config :
657
739
theano_config = {'compute_test_value' : 'raise' }
@@ -694,7 +776,7 @@ def root(self):
694
776
def isroot (self ):
695
777
return self .parent is None
696
778
697
- @property
779
+ @property # type: ignore -- mypy can't handle decorated types.
698
780
@memoize (bound = True )
699
781
def bijection (self ):
700
782
vars = inputvars (self .vars )
0 commit comments