Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ jobs:
tests/backends/test_zarr.py
tests/distributions/test_truncated.py
tests/logprob/test_abstract.py
tests/logprob/test_arithmetic.py
tests/logprob/test_basic.py
tests/logprob/test_binary.py
tests/logprob/test_checks.py
Expand Down
2 changes: 2 additions & 0 deletions pymc/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,8 @@ def init_group_with_empty(
for i, shape_i in enumerate(shape):
dim = f"{name}_dim_{i}"
dims.append(dim)
if shape_i is None: # failing mypy here, shape_i is definitely an int
continue
group_coords[dim] = np.arange(shape_i, dtype="int")
dims = ("chain", "draw", *dims)
attrs = extra_var_attrs[name] if extra_var_attrs is not None else {}
Expand Down
1 change: 1 addition & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
# Add rewrites to the DBs
import pymc.logprob.binary
import pymc.logprob.censoring
import pymc.logprob.arithmetic
import pymc.logprob.cumsum
import pymc.logprob.checks
import pymc.logprob.linalg
Expand Down
86 changes: 86 additions & 0 deletions pymc/logprob/arithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2021-2022 aesara-devs
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Measurable rewrites for arithmetic operations."""

from pytensor import tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.extra_ops import broadcast_shape
from pytensor.tensor.math import Sum
from pytensor.tensor.random.basic import NormalRV
from pytensor.tensor.type_other import NoneConst, NoneTypeT
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.rewriting import measurable_ir_rewrites_db


@node_rewriter([Sum])
def sum_of_normals(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
[base_var] = node.inputs
if base_var.owner is None:
return None

latent_op = base_var.owner.op
if not isinstance(latent_op, NormalRV):
return None

mu, sigma = latent_op.dist_params(base_var.owner)

size = latent_op.size_param(base_var.owner)
if size is None or isinstance(size.type, NoneTypeT):
target_shape = broadcast_shape(mu, sigma) # type: ignore[arg-type]
else:
target_shape = size # type: ignore[assignment]

mu_b = pt.broadcast_to(mu, target_shape) # type: ignore[arg-type]
sigma_b = pt.broadcast_to(sigma, target_shape) # type: ignore[arg-type]

axis = node.op.axis
mu_sum = pt.sum(mu_b, axis=axis)
sigma_sum = pt.sqrt(pt.sum(pt.square(sigma_b), axis=axis))

rng = base_var.owner.inputs[0]
sum_rv = latent_op.make_node(rng, NoneConst, mu_sum, sigma_sum).outputs[1]
return [sum_rv]


measurable_ir_rewrites_db.register(
"sum_of_normals",
sum_of_normals,
"basic",
"arithmetic",
)
64 changes: 64 additions & 0 deletions tests/logprob/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2021-2022 aesara-devs
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np

from pytensor import tensor as pt

from pymc.logprob.basic import logp


def test_sum_of_normals_logprob():
mu = pt.constant([1.0, 2.0, 3.0])
sigma = pt.constant([1.0, 2.0, 3.0])

x_rv = pt.random.normal(mu, sigma, name="x")
x_sum = pt.sum(x_rv)
x_sum_vv = pt.scalar("x_sum")

sum_logp = logp(x_sum, x_sum_vv)

ref_mu = pt.sum(mu)
ref_sigma = pt.sqrt(pt.sum(pt.square(sigma)))
ref_rv = pt.random.normal(ref_mu, ref_sigma, name="ref")
ref_vv = pt.scalar("ref_vv")
ref_logp = logp(ref_rv, ref_vv)

test_val = 0.5
np.testing.assert_allclose(
sum_logp.eval({x_sum_vv: test_val}),
ref_logp.eval({ref_vv: test_val}),
)
19 changes: 19 additions & 0 deletions tests/model/transform/test_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,25 @@ def test_observe_deterministic():
pm.Censored("y_censored", pm.Normal.dist(x), lower=-1, upper=1, observed=y_censored_obs)


def test_observe_sum_normal():
with pm.Model() as m_old:
x = pm.Normal("x")
y = pm.Normal.dist(mu=x, sigma=1.0, shape=(5,))
y_sum = pm.Deterministic("y_sum", pm.math.sum(y))

m_new = observe(m_old, {y_sum: 2.0})

with pm.Model() as m_ref:
x = pm.Normal("x")
pm.Normal("y_sum", mu=5.0 * x, sigma=np.sqrt(5.0), observed=2.0)

test_point = {"x": 0.3}
np.testing.assert_allclose(
m_new.compile_logp()(test_point),
m_ref.compile_logp()(test_point),
)


def test_observe_dims():
with pm.Model(coords={"test_dim": range(5)}) as m_old:
x = pm.Normal("x", dims="test_dim")
Expand Down