Skip to content

Commit ccbec76

Browse files
committed
Test correct behavior of context stack.
1 parent 3518ba9 commit ccbec76

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

pymc3/tests/test_modelcontext.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import threading
2+
from pytest import raises
23
from pymc3 import Model, Normal
4+
from pymc3.distributions.distribution import _DrawValuesContext, _DrawValuesContextBlocker
5+
from pymc3.model import modelcontext
36

47

58
class TestModelContext:
@@ -42,3 +45,34 @@ def make_model_b():
4245
list(modelA.named_vars),
4346
list(modelB.named_vars),
4447
) == (['a'],['b'])
48+
49+
def test_mixed_contexts():
50+
modelA = Model()
51+
modelB = Model()
52+
with raises((ValueError, TypeError)):
53+
modelcontext(None)
54+
with modelA:
55+
with modelB:
56+
assert Model.get_context() == modelB
57+
assert modelcontext(None) == modelB
58+
dvc = _DrawValuesContext()
59+
with dvc:
60+
assert Model.get_context() == modelB
61+
assert modelcontext(None) == modelB
62+
assert _DrawValuesContext.get_context() == dvc
63+
dvcb = _DrawValuesContextBlocker()
64+
with dvcb:
65+
assert _DrawValuesContext.get_context() == dvcb
66+
assert _DrawValuesContextBlocker.get_context() == dvcb
67+
assert _DrawValuesContext.get_context() == dvc
68+
assert _DrawValuesContextBlocker.get_context() is None
69+
assert Model.get_context() == modelB
70+
assert modelcontext(None) == modelB
71+
assert _DrawValuesContext.get_context() is None
72+
assert Model.get_context() == modelB
73+
assert modelcontext(None) == modelB
74+
assert Model.get_context() == modelA
75+
assert modelcontext(None) == modelA
76+
assert Model.get_context() is None
77+
with raises((ValueError, TypeError)):
78+
modelcontext(None)

0 commit comments

Comments
 (0)