Skip to content

Commit 03965c1

Browse files
committed
ZeroSumNormal: Fix dims variant with batch sigma
1 parent 2dd5c04 commit 03965c1

File tree

4 files changed

+41
-13
lines changed

4 files changed

+41
-13
lines changed

pymc/dims/distributions/transforms.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -203,6 +203,4 @@ def backward(self, value, *rv_inputs):
203203
return value
204204

205205
def log_jac_det(self, value, *rv_inputs):
206-
# Use following once broadcast_like is implemented
207-
# as_xtensor(0).broadcast_like(value, exclude=self.dims)`
208-
return value.sum(self.dims) * 0
206+
return as_xtensor(0.0).broadcast_like(value, exclude=self.dims)

pymc/dims/distributions/vector.py

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -229,7 +229,8 @@ def dist(cls, sigma=1.0, *, core_dims=None, dim_lengths, **kwargs):
229229
raise ValueError("ZeroSumNormal requires atleast 1 core_dims")
230230

231231
support_dims = as_xtensor(
232-
as_tensor([dim_lengths[core_dim] for core_dim in core_dims]), dims=("_",)
232+
as_tensor([dim_lengths[core_dim] for core_dim in core_dims]),
233+
dims=("__support_shape__",),
233234
)
234235
sigma = cls._as_xtensor(sigma)
235236

@@ -238,16 +239,25 @@ def dist(cls, sigma=1.0, *, core_dims=None, dim_lengths, **kwargs):
238239
)
239240

240241
@classmethod
241-
def xrv_op(cls, sigma, support_dims, core_dims, extra_dims=None, rng=None):
242-
sigma = cls._as_xtensor(sigma)
243-
support_dims = as_xtensor(support_dims, dims=("_",))
244-
support_shape = support_dims.values
245-
core_rv = ZeroSumNormalRV.rv_op(sigma=sigma.values, support_shape=support_shape).owner.op
242+
def xrv_op(cls, sigma, support_shape, core_dims, extra_dims=None, rng=None):
243+
# ZeroSumNormal expects dummy dimensions on sigma for the support_shape
244+
sigma = cls._as_xtensor(sigma).expand_dims(core_dims)
245+
support_shape = as_xtensor(support_shape, dims=("__support_shape__",))
246+
core_rv = ZeroSumNormalRV.rv_op(
247+
sigma=sigma.values, support_shape=support_shape.values
248+
).owner.op
249+
core_dims_map = tuple(range(1, len(core_dims) + 1))
246250
xop = pxr.as_xrv(
247251
core_rv,
248-
core_inps_dims_map=[(), (0,)],
249-
core_out_dims_map=tuple(range(1, len(core_dims) + 1)),
252+
core_inps_dims_map=[core_dims_map, (0,)],
253+
core_out_dims_map=core_dims_map,
250254
)
251255
# Dummy "_" core dim to absorb the support_shape vector
252256
# If ZeroSumNormal expected a scalar per support dim, this wouldn't be needed
253-
return xop(sigma, support_dims, core_dims=("_", *core_dims), extra_dims=extra_dims, rng=rng)
257+
return xop(
258+
sigma,
259+
support_shape,
260+
core_dims=("__support_shape__", *core_dims),
261+
extra_dims=extra_dims,
262+
rng=rng,
263+
)

pymc/distributions/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -392,7 +392,7 @@ def make_node(self, *inputs):
392392
)
393393
if size_arg_idx is not None and len(rng_arg_idxs) == 1:
394394
new_size_type = normalize_size_param(inputs[size_arg_idx]).type
395-
if not self.input_types[size_arg_idx].in_same_class(new_size_type):
395+
if not self.input_types[size_arg_idx].is_super(new_size_type):
396396
params = [inputs[idx] for idx in param_idxs]
397397
size = inputs[size_arg_idx]
398398
rng = inputs[rng_arg_idxs[0]]

tests/dims/distributions/test_vector.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -99,3 +99,23 @@ def test_zerosumnormal():
9999
# Logp is correct, but we have join(..., -1) and join(..., 1), that don't get canonicalized to the same
100100
# Should work once https://github.com/pymc-devs/pytensor/issues/1505 is fixed
101101
# assert_equivalent_logp_graph(model, reference_model)
102+
103+
104+
def test_zerosumnormal_batch_sigma():
105+
coords = {"a": range(3), "b": range(5)}
106+
sigma = np.array([1, 2, 3.0])
107+
with Model(coords=coords) as model:
108+
ZeroSumNormal(
109+
"x",
110+
sigma=as_xtensor(sigma, dims=("a",)),
111+
core_dims=("b",),
112+
)
113+
114+
with Model(coords=coords) as ref_model:
115+
regular_distributions.ZeroSumNormal("x", sigma=sigma[:, None], dims=("a", "b"))
116+
117+
ip = model.initial_point()
118+
np.testing.assert_allclose(
119+
model.compile_logp(sum=False)(ip),
120+
ref_model.compile_logp(sum=False)(ip),
121+
)

0 commit comments

Comments
 (0)