File tree Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -542,7 +542,11 @@ def __init__(
542
542
name = "" ,
543
543
coords = None ,
544
544
check_bounds = True ,
545
+ * ,
546
+ aesara_config = None ,
547
+ model = None ,
545
548
):
549
+ del aesara_config , model # used in __new__
546
550
self .name = self ._validate_name (name )
547
551
self .check_bounds = check_bounds
548
552
Original file line number Diff line number Diff line change @@ -1014,3 +1014,20 @@ def test_compile_fn():
1014
1014
result_expect = func (state )
1015
1015
1016
1016
np .testing .assert_allclose (result_compute , result_expect )
1017
+
1018
+
1019
+ def test_model_aesara_config ():
1020
+ assert aesara .config .mode != "JAX"
1021
+ with pm .Model (aesara_config = dict (mode = "JAX" )) as model :
1022
+ assert aesara .config .mode == "JAX"
1023
+ assert aesara .config .mode != "JAX"
1024
+
1025
+
1026
+ def test_model_parent_set_programmatically ():
1027
+ with pm .Model () as model :
1028
+ x = pm .Normal ("x" )
1029
+
1030
+ with pm .Model (model = model ):
1031
+ y = pm .Normal ("y" )
1032
+
1033
+ assert "y" in model .named_vars
You can’t perform that action at this time.
0 commit comments