@@ -455,18 +455,42 @@ def _repr_html_(self):
455
455
def sample (
456
456
compiled_model : CompiledModel ,
457
457
* ,
458
- draws : int | None ,
459
- tune : int | None ,
460
- chains : int ,
461
- cores : Optional [int ],
462
- seed : Optional [int ],
463
- save_warmup : bool ,
464
- progress_bar : bool ,
458
+ draws : int | None = None ,
459
+ tune : int | None = None ,
460
+ chains : int | None = None ,
461
+ cores : int | None = None ,
462
+ seed : int | None = None ,
463
+ save_warmup : bool = True ,
464
+ progress_bar : bool = True ,
465
+ low_rank_modified_mass_matrix : bool = False ,
466
+ transform_adapt : bool = False ,
467
+ init_mean : np .ndarray | None = None ,
468
+ return_raw_trace : bool = False ,
469
+ progress_template : str | None = None ,
470
+ progress_style : str | None = None ,
471
+ progress_rate : int = 100 ,
472
+ ) -> arviz .InferenceData : ...
473
+
474
+
475
+ @overload
476
+ def sample (
477
+ compiled_model : CompiledModel ,
478
+ * ,
479
+ draws : int | None = None ,
480
+ tune : int | None = None ,
481
+ chains : int | None = None ,
482
+ cores : int | None = None ,
483
+ seed : int | None = None ,
484
+ save_warmup : bool = True ,
485
+ progress_bar : bool = True ,
465
486
low_rank_modified_mass_matrix : bool = False ,
466
487
transform_adapt : bool = False ,
467
- init_mean : Optional [ np .ndarray ] ,
468
- return_raw_trace : bool ,
488
+ init_mean : np .ndarray | None = None ,
489
+ return_raw_trace : bool = False ,
469
490
blocking : Literal [True ],
491
+ progress_template : str | None = None ,
492
+ progress_style : str | None = None ,
493
+ progress_rate : int = 100 ,
470
494
** kwargs ,
471
495
) -> arviz .InferenceData : ...
472
496
@@ -475,18 +499,21 @@ def sample(
475
499
def sample (
476
500
compiled_model : CompiledModel ,
477
501
* ,
478
- draws : int | None ,
479
- tune : int | None ,
480
- chains : int ,
481
- cores : Optional [ int ] ,
482
- seed : Optional [ int ] ,
483
- save_warmup : bool ,
484
- progress_bar : bool ,
502
+ draws : int | None = None ,
503
+ tune : int | None = None ,
504
+ chains : int | None = None ,
505
+ cores : int | None = None ,
506
+ seed : int | None = None ,
507
+ save_warmup : bool = True ,
508
+ progress_bar : bool = True ,
485
509
low_rank_modified_mass_matrix : bool = False ,
486
510
transform_adapt : bool = False ,
487
- init_mean : Optional [ np .ndarray ] ,
488
- return_raw_trace : bool ,
511
+ init_mean : np .ndarray | None = None ,
512
+ return_raw_trace : bool = False ,
489
513
blocking : Literal [False ],
514
+ progress_template : str | None = None ,
515
+ progress_style : str | None = None ,
516
+ progress_rate : int = 100 ,
490
517
** kwargs ,
491
518
) -> _BackgroundSampler : ...
492
519
@@ -496,21 +523,21 @@ def sample(
496
523
* ,
497
524
draws : int | None = None ,
498
525
tune : int | None = None ,
499
- chains : int = 6 ,
500
- cores : Optional [ int ] = None ,
501
- seed : Optional [ int ] = None ,
526
+ chains : int | None = None ,
527
+ cores : int | None = None ,
528
+ seed : int | None = None ,
502
529
save_warmup : bool = True ,
503
530
progress_bar : bool = True ,
504
531
low_rank_modified_mass_matrix : bool = False ,
505
532
transform_adapt : bool = False ,
506
- init_mean : Optional [ np .ndarray ] = None ,
533
+ init_mean : np .ndarray | None = None ,
507
534
return_raw_trace : bool = False ,
508
535
blocking : bool = True ,
509
- progress_template : Optional [ str ] = None ,
510
- progress_style : Optional [ str ] = None ,
536
+ progress_template : str | None = None ,
537
+ progress_style : str | None = None ,
511
538
progress_rate : int = 100 ,
512
539
** kwargs ,
513
- ) -> arviz .InferenceData :
540
+ ) -> arviz .InferenceData | _BackgroundSampler :
514
541
"""Sample the posterior distribution for a compiled model.
515
542
516
543
Parameters
@@ -618,7 +645,8 @@ def sample(
618
645
settings .num_tune = tune
619
646
if draws is not None :
620
647
settings .num_draws = draws
621
- settings .num_chains = chains
648
+ if chains is not None :
649
+ settings .num_chains = chains
622
650
623
651
for name , val in kwargs .items ():
624
652
setattr (settings , name , val )
@@ -629,7 +657,10 @@ def sample(
629
657
available = os .process_cpu_count () # type: ignore
630
658
except AttributeError :
631
659
available = os .cpu_count ()
632
- cores = min (chains , cast (int , available ))
660
+ if chains is None :
661
+ cores = available
662
+ else :
663
+ cores = min (chains , cast (int , available ))
633
664
634
665
if init_mean is None :
635
666
init_mean = np .zeros (compiled_model .n_dim )
0 commit comments