Skip to content

Commit 9f33248

Browse files
committed
fix: let rust sampler decide on default num chains
1 parent fcf4f7b commit 9f33248

File tree

1 file changed

+58
-27
lines changed

1 file changed

+58
-27
lines changed

python/nutpie/sample.py

Lines changed: 58 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -455,18 +455,42 @@ def _repr_html_(self):
455455
def sample(
456456
compiled_model: CompiledModel,
457457
*,
458-
draws: int | None,
459-
tune: int | None,
460-
chains: int,
461-
cores: Optional[int],
462-
seed: Optional[int],
463-
save_warmup: bool,
464-
progress_bar: bool,
458+
draws: int | None = None,
459+
tune: int | None = None,
460+
chains: int | None = None,
461+
cores: int | None = None,
462+
seed: int | None = None,
463+
save_warmup: bool = True,
464+
progress_bar: bool = True,
465+
low_rank_modified_mass_matrix: bool = False,
466+
transform_adapt: bool = False,
467+
init_mean: np.ndarray | None = None,
468+
return_raw_trace: bool = False,
469+
progress_template: str | None = None,
470+
progress_style: str | None = None,
471+
progress_rate: int = 100,
472+
) -> arviz.InferenceData: ...
473+
474+
475+
@overload
476+
def sample(
477+
compiled_model: CompiledModel,
478+
*,
479+
draws: int | None = None,
480+
tune: int | None = None,
481+
chains: int | None = None,
482+
cores: int | None = None,
483+
seed: int | None = None,
484+
save_warmup: bool = True,
485+
progress_bar: bool = True,
465486
low_rank_modified_mass_matrix: bool = False,
466487
transform_adapt: bool = False,
467-
init_mean: Optional[np.ndarray],
468-
return_raw_trace: bool,
488+
init_mean: np.ndarray | None = None,
489+
return_raw_trace: bool = False,
469490
blocking: Literal[True],
491+
progress_template: str | None = None,
492+
progress_style: str | None = None,
493+
progress_rate: int = 100,
470494
**kwargs,
471495
) -> arviz.InferenceData: ...
472496

@@ -475,18 +499,21 @@ def sample(
475499
def sample(
476500
compiled_model: CompiledModel,
477501
*,
478-
draws: int | None,
479-
tune: int | None,
480-
chains: int,
481-
cores: Optional[int],
482-
seed: Optional[int],
483-
save_warmup: bool,
484-
progress_bar: bool,
502+
draws: int | None = None,
503+
tune: int | None = None,
504+
chains: int | None = None,
505+
cores: int | None = None,
506+
seed: int | None = None,
507+
save_warmup: bool = True,
508+
progress_bar: bool = True,
485509
low_rank_modified_mass_matrix: bool = False,
486510
transform_adapt: bool = False,
487-
init_mean: Optional[np.ndarray],
488-
return_raw_trace: bool,
511+
init_mean: np.ndarray | None = None,
512+
return_raw_trace: bool = False,
489513
blocking: Literal[False],
514+
progress_template: str | None = None,
515+
progress_style: str | None = None,
516+
progress_rate: int = 100,
490517
**kwargs,
491518
) -> _BackgroundSampler: ...
492519

@@ -496,21 +523,21 @@ def sample(
496523
*,
497524
draws: int | None = None,
498525
tune: int | None = None,
499-
chains: int = 6,
500-
cores: Optional[int] = None,
501-
seed: Optional[int] = None,
526+
chains: int | None = None,
527+
cores: int | None = None,
528+
seed: int | None = None,
502529
save_warmup: bool = True,
503530
progress_bar: bool = True,
504531
low_rank_modified_mass_matrix: bool = False,
505532
transform_adapt: bool = False,
506-
init_mean: Optional[np.ndarray] = None,
533+
init_mean: np.ndarray | None = None,
507534
return_raw_trace: bool = False,
508535
blocking: bool = True,
509-
progress_template: Optional[str] = None,
510-
progress_style: Optional[str] = None,
536+
progress_template: str | None = None,
537+
progress_style: str | None = None,
511538
progress_rate: int = 100,
512539
**kwargs,
513-
) -> arviz.InferenceData:
540+
) -> arviz.InferenceData | _BackgroundSampler:
514541
"""Sample the posterior distribution for a compiled model.
515542
516543
Parameters
@@ -618,7 +645,8 @@ def sample(
618645
settings.num_tune = tune
619646
if draws is not None:
620647
settings.num_draws = draws
621-
settings.num_chains = chains
648+
if chains is not None:
649+
settings.num_chains = chains
622650

623651
for name, val in kwargs.items():
624652
setattr(settings, name, val)
@@ -629,7 +657,10 @@ def sample(
629657
available = os.process_cpu_count() # type: ignore
630658
except AttributeError:
631659
available = os.cpu_count()
632-
cores = min(chains, cast(int, available))
660+
if chains is None:
661+
cores = available
662+
else:
663+
cores = min(chains, cast(int, available))
633664

634665
if init_mean is None:
635666
init_mean = np.zeros(compiled_model.n_dim)

0 commit comments

Comments
 (0)