Skip to content

Conversation

WardBrian
Copy link
Contributor

Closes #204.

Note: I believe the output-processing code in stan.rs is probably still mishandling tuples. I see it has some logic to treat complex variables as 2-arrays, which seems fine if a bit unfortunate.

@WardBrian
Copy link
Contributor Author

@aseyboldt the failures here seem to be unrelated to the Stan code -- thoughts?

@aseyboldt
Copy link
Member

Thanks!
The test failures are unrelated. In the latest PR for the normalizing flows I introduce a problem on ARM somehow, that I haven't figured out yet.

Complex values are represented as an array with two elements right now. I'll open an issue to change that to a proper complex type.

What do you mean about the tuples? This for instance seems to work as expected (apart from the complex variable)?

import nutpie

code = """
parameters {
    tuple(complex, tuple(tuple(real, real), tuple(real, real)), array[2, 3] real) xi;
}
transformed parameters {
    real re = get_real(xi.1);
    real im = get_imag(xi.1);
}
model {
    re ~ normal(0, 0.1);
    im ~ normal(-1, 0.1);

    xi.2.1.1 ~ normal(-2, 0.1);
    xi.2.1.2 ~ normal(-3, 0.1);
    xi.2.2.1 ~ normal(-4, 0.1);
    xi.2.2.2 ~ normal(-5, 0.1);

    xi.3[1, 1] ~ normal(1, 0.1);
    xi.3[1, 2] ~ normal(2, 0.1);
    xi.3[1, 3] ~ normal(3, 0.1);
    xi.3[2, 1] ~ normal(4, 0.1);
    xi.3[2, 2] ~ normal(5, 0.1);
    xi.3[2, 3] ~ normal(6, 0.1);
}
"""

compiled = nutpie.compile_stan_model(code=code)

tr = nutpie.sample(compiled)

means = tr.posterior.mean(["draw", "chain"])

image

There is no way to represent unnamed hierarchical structure in xarray, so I don't think we can do better than flattening the tuple?

@aseyboldt aseyboldt merged commit 35c2508 into pymc-devs:main May 6, 2025
6 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Improve JSON dumping in CompiledStanModel.with_data

2 participants