@@ -257,7 +257,9 @@ def _infmean(input_array):
257
257
)
258
258
)
259
259
else :
260
- if n < 10 :
260
+ if n == 0 :
261
+ logger .info (f"Initialization only" )
262
+ elif n < 10 :
261
263
logger .info (f"Finished [100%]: Loss = { scores [- 1 ]:,.5g} " )
262
264
else :
263
265
avg_loss = _infmean (scores [max (0 , i - 1000 ) : i + 1 ])
@@ -433,8 +435,10 @@ class ADVI(KLqp):
433
435
random_seed: None or int
434
436
leave None to use package global RandomStream or other
435
437
valid value to create instance specific one
436
- start: `Point `
438
+ start: `dict[str, np.ndarray]` or `StartDict `
437
439
starting point for inference
440
+ start_sigma: `dict[str, np.ndarray]`
441
+ starting standard deviation for inference, only available for method 'advi'
438
442
439
443
References
440
444
----------
@@ -464,7 +468,7 @@ class FullRankADVI(KLqp):
464
468
random_seed: None or int
465
469
leave None to use package global RandomStream or other
466
470
valid value to create instance specific one
467
- start: `Point `
471
+ start: `dict[str, np.ndarray]` or `StartDict `
468
472
starting point for inference
469
473
470
474
References
@@ -532,13 +536,11 @@ class SVGD(ImplicitGradient):
532
536
kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
533
537
temperature: float
534
538
parameter responsible for exploration, higher temperature gives more broad posterior estimate
535
- start: `dict`
539
+ start: `dict[str, np.ndarray]` or `StartDict `
536
540
initial point for inference
537
541
random_seed: None or int
538
542
leave None to use package global RandomStream or other
539
543
valid value to create instance specific one
540
- start: `Point`
541
- starting point for inference
542
544
kwargs: other keyword arguments passed to estimator
543
545
544
546
References
@@ -660,6 +662,7 @@ def fit(
660
662
model = None ,
661
663
random_seed = None ,
662
664
start = None ,
665
+ start_sigma = None ,
663
666
inf_kwargs = None ,
664
667
** kwargs ,
665
668
):
@@ -684,8 +687,10 @@ def fit(
684
687
valid value to create instance specific one
685
688
inf_kwargs: dict
686
689
additional kwargs passed to :class:`Inference`
687
- start: `Point `
690
+ start: `dict[str, np.ndarray]` or `StartDict `
688
691
starting point for inference
692
+ start_sigma: `dict[str, np.ndarray]`
693
+ starting standard deviation for inference, only available for method 'advi'
689
694
690
695
Other Parameters
691
696
----------------
@@ -728,6 +733,10 @@ def fit(
728
733
inf_kwargs ["random_seed" ] = random_seed
729
734
if start is not None :
730
735
inf_kwargs ["start" ] = start
736
+ if start_sigma is not None :
737
+ if method != "advi" :
738
+ raise NotImplementedError ("start_sigma is only available for method advi" )
739
+ inf_kwargs ["start_sigma" ] = start_sigma
731
740
if model is None :
732
741
model = pm .modelcontext (model )
733
742
_select = dict (advi = ADVI , fullrank_advi = FullRankADVI , svgd = SVGD , asvgd = ASVGD )
0 commit comments