Skip to content

Commit b500403

Browse files
committed
Lift constraint on dim lengths being constant in R2D2M2CP
A constraint can be introduced by using `freeze_data_and_dims`
1 parent 485dc21 commit b500403

File tree

2 files changed

+1
-27
lines changed

2 files changed

+1
-27
lines changed

pymc_experimental/distributions/multivariate/r2d2m2cp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,9 +418,7 @@ def R2D2M2CP(
418418
*broadcast_dims, dim = dims
419419
input_sigma = pt.as_tensor(input_sigma)
420420
output_sigma = pt.as_tensor(output_sigma)
421-
with pm.Model(name) as model:
422-
if not all(isinstance(model.dim_lengths[d], pt.TensorConstant) for d in dims):
423-
raise ValueError(f"{dims!r} should be constant length immutable dims")
421+
with pm.Model(name):
424422
if r2_std is not None:
425423
r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims)
426424
phi = _phi(

pymc_experimental/tests/distributions/test_multivariate.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -302,27 +302,3 @@ def test_zero_length_rvs_not_created(self, model: pm.Model):
302302
"b2", 1, [1, 1], r2=0.5, positive_probs=[1, 1], positive_probs_std=[0, 0], dims="a"
303303
)
304304
assert not model.free_RVs, model.free_RVs
305-
306-
def test_immutable_dims(self, model: pm.Model):
307-
model.add_coord("a", range(2), mutable=True)
308-
model.add_coord("b", range(2), mutable=False)
309-
with pytest.raises(ValueError, match="should be constant length immutable dims"):
310-
pmx.distributions.R2D2M2CP(
311-
"beta0",
312-
1,
313-
[1, 1],
314-
dims="a",
315-
r2=0.8,
316-
positive_probs=[0.5, 1],
317-
positive_probs_std=[0.3, 0],
318-
)
319-
with pytest.raises(ValueError, match="should be constant length immutable dims"):
320-
pmx.distributions.R2D2M2CP(
321-
"beta0",
322-
1,
323-
[1, 1],
324-
dims=("a", "b"),
325-
r2=0.8,
326-
positive_probs=[0.5, 1],
327-
positive_probs_std=[0.3, 0],
328-
)

0 commit comments

Comments
 (0)