Skip to content

Commit 8e89828

Browse files
Prevent unbound trace due to type hints (#6809)
* Prevent unbound trace * move enter and exit into metaclass * simpler, a bit more correct approach * add comment
1 parent 659f177 commit 8e89828

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

pymc/model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from pytensor.tensor.random.type import RandomType
5151
from pytensor.tensor.sharedvar import ScalarSharedVariable
5252
from pytensor.tensor.var import TensorConstant, TensorVariable
53+
from typing_extensions import Self
5354

5455
from pymc.blocking import DictToArrayBijection, RaveledVars
5556
from pymc.data import GenTensorVariable, is_minibatch
@@ -213,7 +214,9 @@ def __init_subclass__(cls, **kwargs):
213214
# Initialize object in its own context...
214215
# Merged from InitContextMeta in the original.
215216
def __call__(cls, *args, **kwargs):
216-
instance = cls.__new__(cls, *args, **kwargs)
217+
# We type hint Model here so type checkers understand that Model is a context manager.
218+
# This metaclass is only used for Model, so this is safe to do. See #6809 for more info.
219+
instance: "Model" = cls.__new__(cls, *args, **kwargs)
217220
with instance: # appends context
218221
instance.__init__(*args, **kwargs)
219222
return instance
@@ -478,10 +481,10 @@ def __init__(self, mean=0, sigma=1, name=''):
478481

479482
if TYPE_CHECKING:
480483

481-
def __enter__(self: "Model") -> "Model":
484+
def __enter__(self: Self) -> Self:
482485
...
483486

484-
def __exit__(self: "Model", *exc: Any) -> bool:
487+
def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
485488
...
486489

487490
def __new__(cls, *args, **kwargs):

0 commit comments

Comments
 (0)