Skip to content

Commit f478772

Browse files
authored
Raise error when using CircularReparam at observed site (#1856)
* raise error when using circular reparam at observed site * clean up
1 parent d52209c commit f478772

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

numpyro/infer/reparam.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ def __call__(self, name, fn, obs):
338338
if isinstance(support, constraints.independent):
339339
support = fn.support.base_constraint
340340
assert support is constraints.circular
341+
assert obs is None, "CircularReparam does not support observe statements"
341342

342343
# Draw parameter-free noise.
343344
new_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)

test/infer/test_reparam.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,16 @@ def get_actual_probe(loc, concentration):
378378
assert_allclose(actual_probe, expected_probe, atol=0.1)
379379

380380

381+
def test_circular_reparam_no_observe():
382+
def model():
383+
numpyro.sample("x", dist.VonMises(0, 1), obs=0.5)
384+
385+
with numpyro.handlers.seed(rng_seed=0):
386+
with numpyro.handlers.reparam(config={"x": CircularReparam()}):
387+
with pytest.raises(AssertionError, match="not support observe"):
388+
model()
389+
390+
381391
_unconstrain_reparam = numpyro.infer.util._unconstrain_reparam
382392

383393

0 commit comments

Comments
 (0)