Turing and AdvancedHMC give different adaptors for NUTS #2717

@ysfoo

Minimal working example

using Turing, AdvancedHMC, AbstractMCMC, Random

@model function gdemo(x, y)
    m ~ Normal(0, sqrt(1))
    x ~ Normal(m, sqrt(1))
    y ~ Normal(m, sqrt(1))
end

# Sample with Turing and inspect the adaptor stored in the final sampler state.
turing_chain = sample(gdemo(1.5, 2), Turing.NUTS(0.8), 1000; save_state=true, n_adapts=500);
display(turing_chain.info.samplerstate.adaptor)

# Take a single AdvancedHMC step on the same model and inspect its adaptor.
ahmc_transition, ahmc_state = AbstractMCMC.step(
    Random.default_rng(),
    AbstractMCMC._model(LogDensityFunction(gdemo(1.5, 2))),
    AdvancedHMC.NUTS(0.8); n_adapts = 500
);
display(ahmc_state.adaptor)

Description

Running NUTS through Turing and through AdvancedHMC both produce a StanHMCAdaptor, but the adaptor state differs. In particular, the Turing version has a negative endpoint for the mass matrix adaptation window (window(76, -51), with no window splits), whereas the AdvancedHMC version has window(76, 450). Consequently, Turing does not perform mass matrix adaptation.

Output of MWE:

┌ Info: Found initial step size
└   ϵ = 1.6
Sampling 100%|████████████████████████████████████████| Time: 0:00:07
StanHMCAdaptor(
    pc=WelfordVar{Float64} adaptor,
    ssa=NesterovDualAveraging(0.05, 10.0, 0.75, 0.8, 0.7804915138075489),
    init_buffer=75, term_buffer=50, window_size=25,
    state=window(76, -51), window_splits()
)
[ Info: Found initial step size 3.2
StanHMCAdaptor(
    pc=WelfordVar{Float64} adaptor,
    ssa=NesterovDualAveraging(0.05, 10.0, 0.75, 0.8, 7.472207330909228),
    init_buffer=75, term_buffer=50, window_size=25,
    state=window(76, 450), window_splits(100, 150, 250, 450)
)
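
To make the difference between the two states concrete, here is a minimal sketch of a Stan-style window schedule. It is not AdvancedHMC's code: `window_schedule` is a hypothetical helper, and the rule that the window endpoint equals `n_adapts - term_buffer` is an assumption chosen because it reproduces both outputs above. Under that assumption, the Turing state `window(76, -51)` is what you would get if `n_adapts = -1` were used when the schedule was built, rather than the requested 500.

```julia
# Hypothetical helper (not part of AdvancedHMC or Turing): rebuilds a
# Stan-style windowed adaptation schedule from the buffer sizes that the
# adaptor prints.
function window_schedule(n_adapts::Int; init_buffer::Int=75,
                         term_buffer::Int=50, window_size::Int=25)
    window_start = init_buffer + 1
    # Assumption: the endpoint is the adaptation length minus the final buffer.
    window_end = n_adapts - term_buffer
    splits = Int[]
    next_split = init_buffer + window_size
    while next_split <= window_end
        # If the window after this one would overshoot the endpoint,
        # stretch the current window to the endpoint instead.
        if next_split + 2 * window_size > window_end
            next_split = window_end
        end
        push!(splits, next_split)
        window_size *= 2          # each window is twice as long as the last
        next_split += window_size
    end
    return (window_start=window_start, window_end=window_end, splits=splits)
end

ahmc_like = window_schedule(500)   # the n_adapts requested in the MWE
turing_like = window_schedule(-1)  # assumed value that would explain -51

println(ahmc_like)
println(turing_like)
```

With `n_adapts = 500` the sketch gives the window (76, 450) with splits at 100, 150, 250 and 450, matching the AdvancedHMC output. With `n_adapts = -1` it gives the window (76, -51) and no splits, matching the Turing output. An empty split list means the loop that would trigger mass matrix updates never runs.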

Julia version info

Julia Version 1.11.6

Manifest

  [80f14c24] AbstractMCMC v5.10.0
  [0bf59076] AdvancedHMC v0.8.3
  [fce5fe82] Turing v0.41.1
