Skip to content

Commit 8547d6d

Browse files
Fix several typing gaps
1 parent d6f1aad commit 8547d6d

File tree

6 files changed

+39
-29
lines changed

6 files changed

+39
-29
lines changed

pymc/sampling/mcmc.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
import warnings
2222

2323
from collections import defaultdict
24-
from typing import Iterator, List, Optional, Sequence, Tuple, Union
24+
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
2525

2626
import numpy as np
2727
import pytensor.gradient as tg
2828

2929
from arviz import InferenceData
3030
from fastprogress.fastprogress import progress_bar
31-
from typing_extensions import TypeAlias
31+
from typing_extensions import Protocol, TypeAlias
3232

3333
import pymc as pm
3434

@@ -64,6 +64,13 @@
6464
Step: TypeAlias = Union[BlockedStep, CompoundStep]
6565

6666

67+
class SamplingIteratorCallback(Protocol):
68+
"""Signature of the callable that may be passed to `pm.sample(callable=...)`."""
69+
70+
def __call__(self, trace: BaseTrace, draw: Draw):
71+
pass
72+
73+
6774
_log = logging.getLogger("pymc")
6875

6976

@@ -221,7 +228,7 @@ def sample(
221228
cores: Optional[int] = None,
222229
tune: int = 1000,
223230
progressbar: bool = True,
224-
model=None,
231+
model: Optional[Model] = None,
225232
random_seed: RandomState = None,
226233
discard_tuned_samples: bool = True,
227234
compute_convergence_checks: bool = True,
@@ -599,7 +606,7 @@ def sample(
599606

600607
idata = None
601608
if compute_convergence_checks or return_inferencedata:
602-
ikwargs = dict(model=model, save_warmup=not discard_tuned_samples)
609+
ikwargs: Dict[str, Any] = dict(model=model, save_warmup=not discard_tuned_samples)
603610
if idata_kwargs:
604611
ikwargs.update(idata_kwargs)
605612
idata = pm.to_inference_data(mtrace, **ikwargs)
@@ -655,8 +662,8 @@ def _sample_many(
655662
traces: Sequence[BaseTrace],
656663
start: Sequence[PointType],
657664
random_seed: Optional[Sequence[RandomSeed]],
658-
step,
659-
callback=None,
665+
step: Step,
666+
callback: Optional[SamplingIteratorCallback] = None,
660667
**kwargs,
661668
):
662669
"""Samples all chains sequentially.
@@ -695,7 +702,7 @@ def _sample(
695702
random_seed: RandomSeed,
696703
start: PointType,
697704
draws: int,
698-
step=None,
705+
step: Step,
699706
trace: BaseTrace,
700707
tune: int,
701708
model: Optional[Model] = None,
@@ -760,14 +767,14 @@ def _sample(
760767
def _iter_sample(
761768
*,
762769
draws: int,
763-
step,
770+
step: Step,
764771
start: PointType,
765772
trace: BaseTrace,
766773
chain: int = 0,
767774
tune: int = 0,
768-
model=None,
775+
model: Optional[Model] = None,
769776
random_seed: RandomSeed = None,
770-
callback=None,
777+
callback: Optional[SamplingIteratorCallback] = None,
771778
) -> Iterator[bool]:
772779
"""Generator for sampling one chain. (Used in singleprocess sampling.)
773780
@@ -803,19 +810,13 @@ def _iter_sample(
803810
if random_seed is not None:
804811
np.random.seed(random_seed)
805812

806-
try:
807-
step = CompoundStep(step)
808-
except TypeError:
809-
pass
810-
811813
point = start
812814

813815
try:
814816
step.tune = bool(tune)
815817
if hasattr(step, "reset_tuning"):
816818
step.reset_tuning()
817819
for i in range(draws):
818-
stats = None
819820
diverging = False
820821

821822
if i == 0 and hasattr(step, "iter_count"):
@@ -825,7 +826,7 @@ def _iter_sample(
825826
point, stats = step.step(point)
826827
trace.record(point, stats)
827828
log_warning_stats(stats)
828-
diverging = i > tune and stats and stats[0].get("diverging")
829+
diverging = i > tune and len(stats) > 0 and (stats[0].get("diverging") == True)
829830
if callback is not None:
830831
callback(
831832
trace=trace,
@@ -854,8 +855,8 @@ def _mp_sample(
854855
start: Sequence[PointType],
855856
progressbar: bool = True,
856857
traces: Sequence[BaseTrace],
857-
model=None,
858-
callback=None,
858+
model: Optional[Model] = None,
859+
callback: Optional[SamplingIteratorCallback] = None,
859860
mp_ctx=None,
860861
**kwargs,
861862
) -> None:
@@ -884,7 +885,7 @@ def _mp_sample(
884885
A backend instance, or None.
885886
If None, the NDArray backend is used.
886887
model : Model (optional if in ``with`` context)
887-
callback : Callable
888+
callback
888889
A function which gets called for every sample from the trace of a chain. The function is
889890
called with the trace and the current draw and will contain all samples for a single trace.
890891
the ``draw.chain`` argument can be used to determine which of the active chains the sample
@@ -994,7 +995,7 @@ def init_nuts(
994995
init: str = "auto",
995996
chains: int = 1,
996997
n_init: int = 500_000,
997-
model=None,
998+
model: Optional[Model] = None,
998999
random_seed: RandomSeed = None,
9991000
progressbar=True,
10001001
jitter_max_retries: int = 10,

pymc/sampling/population.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from pymc.backends.base import BaseTrace
3030
from pymc.initial_point import PointType
31-
from pymc.model import modelcontext
31+
from pymc.model import Model, modelcontext
3232
from pymc.stats.convergence import log_warning_stats
3333
from pymc.step_methods import CompoundStep
3434
from pymc.step_methods.arraystep import (
@@ -50,13 +50,13 @@
5050

5151
def _sample_population(
5252
*,
53-
initial_points,
53+
initial_points: Sequence[PointType],
5454
draws: int,
5555
start: Sequence[PointType],
5656
random_seed: RandomSeed,
57-
step,
57+
step: Union[BlockedStep, CompoundStep],
5858
tune: int,
59-
model,
59+
model: Model,
6060
progressbar: bool = True,
6161
parallelize: bool = False,
6262
traces: Sequence[BaseTrace],
@@ -108,7 +108,14 @@ def _sample_population(
108108
return
109109

110110

111-
def warn_population_size(*, step: CompoundStep, initial_points, model, chains: int):
111+
def warn_population_size(
112+
*,
113+
step: Union[BlockedStep, CompoundStep],
114+
initial_points: Sequence[PointType],
115+
model: Model,
116+
chains: int,
117+
):
118+
"""Emit informative errors/warnings for dangerously small population size."""
112119
has_demcmc = np.any(
113120
[
114121
isinstance(m, DEMetropolis)

pymc/stats/log_likelihood.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Optional, Sequence
14+
from typing import Optional, Sequence, cast
1515

1616
import numpy as np
1717

@@ -22,6 +22,7 @@
2222

2323
from pymc.backends.arviz import _DefaultTrace
2424
from pymc.model import Model, modelcontext
25+
from pymc.pytensorf import PointFunc
2526
from pymc.util import dataset_to_point_list
2627

2728
__all__ = ("compute_log_likelihood",)
@@ -86,6 +87,7 @@ def compute_log_likelihood(
8687
outs=model.logp(vars=observed_vars, sum=False),
8788
on_unused_input="ignore",
8889
)
90+
elemwise_loglike_fn = cast(PointFunc, elemwise_loglike_fn)
8991
finally:
9092
model.rvs_to_values = original_rvs_to_values
9193
model.rvs_to_transforms = original_rvs_to_transforms

pymc/step_methods/compound.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(self, methods):
3636
self.name = (
3737
f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]"
3838
)
39+
self.tune = True
3940

4041
def step(self, point) -> Tuple[PointType, StatsType]:
4142
stats = []

pymc/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def enhanced(*args, **kwargs):
235235

236236

237237
def dataset_to_point_list(
238-
ds: xarray.Dataset, sample_dims: List
238+
ds: xarray.Dataset, sample_dims: Sequence[str]
239239
) -> Tuple[List[Dict[str, np.ndarray]], Dict[str, Any]]:
240240
# All keys of the dataset must be a str
241241
var_names = list(ds.keys())

scripts/run_mypy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
pymc/printing.py
4343
pymc/pytensorf.py
4444
pymc/sampling/jax.py
45-
pymc/stats/log_likelihood.py
4645
pymc/variational/approximations.py
4746
pymc/variational/opvi.py
4847
"""

0 commit comments

Comments
 (0)