@@ -386,7 +386,7 @@ def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts',
386
386
387
387
Parameteres
388
388
-----------
389
- init : str {'advi', 'map', 'metropolis', ' nuts'}
389
+ init : str {'advi', 'map', 'nuts'}
390
390
Initialization method to use.
391
391
n_init : int
392
392
Number of iterations of initializer
@@ -412,20 +412,15 @@ def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts',
412
412
if init == 'advi' :
413
413
v_params = pm .variational .advi (n = n_init )
414
414
start = v_params .means
415
- cov = np .diagflat ( np . power (model .dict_to_array (v_params .stds ), 2 ) )
415
+ cov = np .power (model .dict_to_array (v_params .stds ), 2 )
416
416
417
417
elif init == 'map' :
418
418
start = pm .find_MAP ()
419
419
cov = pm .find_hessian (point = start )
420
420
421
- elif init == 'metropolis' :
422
- init_trace = pm .sample (step = pm .Metropolis (), draws = n_init )
423
- cov = pm .trace_cov (init_trace )
424
-
425
- start = {varname : np .mean (init_trace [varname ]) for varname in init_trace .varnames }
426
421
elif init == 'nuts' :
427
422
init_trace = pm .sample (step = pm .NUTS (), draws = n_init )
428
- cov = pm .trace_cov (init_trace )
423
+ cov = pm .trace_cov (init_trace [ n_init // 2 :] )
429
424
430
425
start = {varname : np .mean (init_trace [varname ]) for varname in init_trace .varnames }
431
426
else :
@@ -436,9 +431,6 @@ def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts',
436
431
step = pm .NUTS (scaling = cov , is_cov = True )
437
432
elif sampler == 'hmc' :
438
433
step = pm .HamiltonianMC (scaling = cov , is_cov = True )
439
- elif sampler == 'metropolis' :
440
- step = pm .Metropolis (scaling = cov ,
441
- proposal = pm .step_methods .metropolis .MultivariateNormalProposal )
442
434
elif sampler != 'advi' :
443
435
raise NotImplemented ('Sampler {} is not supported.' .format (init ))
444
436
0 commit comments