|
50 | 50 | from pytensor.tensor.random.type import RandomType
|
51 | 51 | from pytensor.tensor.sharedvar import ScalarSharedVariable
|
52 | 52 | from pytensor.tensor.var import TensorConstant, TensorVariable
|
| 53 | +from typing_extensions import Self |
53 | 54 |
|
54 | 55 | from pymc.blocking import DictToArrayBijection, RaveledVars
|
55 | 56 | from pymc.data import GenTensorVariable, is_minibatch
|
@@ -213,7 +214,9 @@ def __init_subclass__(cls, **kwargs):
|
213 | 214 | # Initialize object in its own context...
|
214 | 215 | # Merged from InitContextMeta in the original.
|
215 | 216 | def __call__(cls, *args, **kwargs):
|
216 |
| - instance = cls.__new__(cls, *args, **kwargs) |
| 217 | + # We type hint Model here so type checkers understand that Model is a context manager. |
| 218 | + # This metaclass is only used for Model, so this is safe to do. See #6809 for more info. |
| 219 | + instance: "Model" = cls.__new__(cls, *args, **kwargs) |
217 | 220 | with instance: # appends context
|
218 | 221 | instance.__init__(*args, **kwargs)
|
219 | 222 | return instance
|
@@ -478,10 +481,10 @@ def __init__(self, mean=0, sigma=1, name=''):
|
478 | 481 |
|
479 | 482 | if TYPE_CHECKING:
|
480 | 483 |
|
481 |
| - def __enter__(self: "Model") -> "Model": |
| 484 | + def __enter__(self: Self) -> Self: |
482 | 485 | ...
|
483 | 486 |
|
484 |
| - def __exit__(self: "Model", *exc: Any) -> bool: |
| 487 | + def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: |
485 | 488 | ...
|
486 | 489 |
|
487 | 490 | def __new__(cls, *args, **kwargs):
|
|
0 commit comments