Skip to content

Commit 98ccc68

Browse files
committed
Implement power transforms
1 parent 90b6bec commit 98ccc68

File tree

2 files changed

+161
-20
lines changed

2 files changed

+161
-20
lines changed

pymc/logprob/transforms.py

Lines changed: 130 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from copy import copy
4040
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
4141

42+
import numpy as np
4243
import pytensor.tensor as at
4344

4445
from pytensor.gradient import DisconnectedType, jacobian
@@ -48,10 +49,22 @@
4849
from pytensor.graph.op import Op
4950
from pytensor.graph.replace import clone_replace
5051
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
51-
from pytensor.scalar import Add, Exp, Log, Mul, Reciprocal
52+
from pytensor.scalar import Add, Exp, Log, Mul, Pow, Sqr, Sqrt
5253
from pytensor.scan.op import Scan
5354
from pytensor.tensor.exceptions import NotScalarConstantError
54-
from pytensor.tensor.math import add, exp, log, mul, neg, reciprocal, sub, true_div
55+
from pytensor.tensor.math import (
56+
add,
57+
exp,
58+
log,
59+
mul,
60+
neg,
61+
pow,
62+
reciprocal,
63+
sqr,
64+
sqrt,
65+
sub,
66+
true_div,
67+
)
5568
from pytensor.tensor.rewriting.basic import (
5669
register_specialize,
5770
register_stabilize,
@@ -110,8 +123,11 @@ def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
110123
"""Apply the transformation."""
111124

112125
@abc.abstractmethod
def backward(
    self, value: TensorVariable, *inputs: Variable
) -> Union[TensorVariable, Tuple[TensorVariable, ...]]:
    """Invert the transformation.

    A tuple of several values may be returned when the transformation is
    not 1-to-1, i.e. when more than one input maps to the same output.
    """
115131

116132
def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
117133
"""Construct the log of the absolute value of the Jacobian determinant."""
@@ -320,7 +336,7 @@ def apply(self, fgraph: FunctionGraph):
320336
class MeasurableTransform(MeasurableElemwise):
321337
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""
322338

323-
valid_scalar_types = (Exp, Log, Add, Mul, Reciprocal)
339+
valid_scalar_types = (Exp, Log, Add, Mul, Pow)
324340

325341
# Cannot use `transform` as name because it would clash with the property added by
326342
# the `TransformValuesRewrite`
@@ -349,16 +365,64 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
349365
# The value variable must still be back-transformed to be on the natural support of
350366
# the respective measurable input.
351367
backward_value = op.transform_elemwise.backward(value, *other_inputs)
352-
input_logprob = logprob(measurable_input, backward_value, **kwargs)
368+
369+
# Some transformations, like squaring, may produce multiple backward values
370+
if isinstance(backward_value, tuple):
371+
input_logprob = at.logaddexp(
372+
*(logprob(measurable_input, backward_val, **kwargs) for backward_val in backward_value)
373+
)
374+
else:
375+
input_logprob = logprob(measurable_input, backward_value)
353376

354377
jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
355378

356379
return input_logprob + jacobian
357380

358381

382+
@node_rewriter([reciprocal])
def measurable_reciprocal_to_power(fgraph, node):
    """Rewrite the reciprocal of a `MeasurableVariable` as a power of -1.

    Returns a one-element list with the replacement output, or ``None`` when
    the rewrite does not apply.
    """
    base = node.inputs[0]

    # The operand must be the output of a measurable op
    producer = base.owner
    if not producer or not isinstance(producer.op, MeasurableVariable):
        return None

    mappings: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

    # Without the mappings feature we cannot tell whether the operand is
    # valued; the rewrite is only applied to unvalued variables
    if mappings is None or base in mappings.rv_values:
        return None  # pragma: no cover

    return [at.pow(base, -1.0)]
398+
399+
400+
@node_rewriter([sqr, sqrt])
def measurable_sqrt_sqr_to_power(fgraph, node):
    """Rewrite the square or square root of a `MeasurableVariable` as a power.

    A square becomes a power of 2 and a square root a power of 1/2. Returns a
    one-element list with the replacement output, or ``None`` when the rewrite
    does not apply.
    """
    base = node.inputs[0]

    # The operand must be the output of a measurable op
    producer = base.owner
    if not producer or not isinstance(producer.op, MeasurableVariable):
        return None

    mappings: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

    # Without the mappings feature we cannot tell whether the operand is
    # valued; the rewrite is only applied to unvalued variables
    if mappings is None or base in mappings.rv_values:
        return None  # pragma: no cover

    scalar_op = node.op.scalar_op
    if isinstance(scalar_op, Sqr):
        exponent = 2
    elif isinstance(scalar_op, Sqrt):
        exponent = 1 / 2
    else:
        return None

    return [at.pow(base, exponent)]
421+
422+
359423
@node_rewriter([true_div])
360-
def measurable_div_to_reciprocal_product(fgraph, node):
361-
"""Convert divisions involving `MeasurableVariable`s to product with reciprocal."""
424+
def measurable_div_to_product(fgraph, node):
425+
"""Convert divisions involving `MeasurableVariable`s to products."""
362426

363427
measurable_vars = [
364428
var for var in node.inputs if (var.owner and isinstance(var.owner.op, MeasurableVariable))
@@ -379,9 +443,13 @@ def measurable_div_to_reciprocal_product(fgraph, node):
379443
# Check if numerator is 1
380444
try:
381445
if at.get_scalar_constant_value(numerator) == 1:
382-
return [at.reciprocal(denominator)]
446+
# We convert the denominator directly to a power transform as this
447+
# must be the measurable input
448+
return [at.pow(denominator, -1)]
383449
except NotScalarConstantError:
384450
pass
451+
# We don't convert the denominator directly to a power transform as
452+
# it might not be measurable (and therefore not needed)
385453
return [at.mul(numerator, at.reciprocal(denominator))]
386454

387455

@@ -425,7 +493,7 @@ def measurable_sub_to_neg(fgraph, node):
425493
return [at.add(minuend, at.neg(subtrahend))]
426494

427495

428-
@node_rewriter([exp, log, add, mul, reciprocal])
496+
@node_rewriter([exp, log, add, mul, pow])
429497
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
430498
"""Find measurable transformations from Elemwise operators."""
431499

@@ -485,8 +553,18 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
485553
transform = ExpTransform()
486554
elif isinstance(scalar_op, Log):
487555
transform = LogTransform()
488-
elif isinstance(scalar_op, Reciprocal):
489-
transform = ReciprocalTransform()
556+
elif isinstance(scalar_op, Pow):
557+
# We only allow for the base to be measurable
558+
if measurable_input_idx != 0:
559+
return None
560+
try:
561+
(power,) = other_inputs
562+
power = at.get_scalar_constant_value(power).item()
563+
# Power needs to be a constant
564+
except NotScalarConstantError:
565+
return None
566+
transform_inputs = (measurable_input, power)
567+
transform = PowerTransform(power=power)
490568
elif isinstance(scalar_op, Add):
491569
transform_inputs = (measurable_input, at.add(*other_inputs))
492570
transform = LocTransform(
@@ -510,12 +588,29 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
510588

511589

512590
measurable_ir_rewrites_db.register(
513-
"measurable_div_to_reciprocal_product",
514-
measurable_div_to_reciprocal_product,
591+
"measurable_reciprocal_to_power",
592+
measurable_reciprocal_to_power,
515593
"basic",
516594
"transform",
517595
)
518596

597+
598+
measurable_ir_rewrites_db.register(
599+
"measurable_sqrt_sqr_to_power",
600+
measurable_sqrt_sqr_to_power,
601+
"basic",
602+
"transform",
603+
)
604+
605+
606+
measurable_ir_rewrites_db.register(
607+
"measurable_div_to_product",
608+
measurable_div_to_product,
609+
"basic",
610+
"transform",
611+
)
612+
613+
519614
measurable_ir_rewrites_db.register(
520615
"measurable_neg_to_product",
521616
measurable_neg_to_product,
@@ -601,17 +696,33 @@ def log_jac_det(self, value, *inputs):
601696
return -at.log(value)
602697

603698

604-
class ReciprocalTransform(RVTransform):
605-
name = "reciprocal"
699+
class PowerTransform(RVTransform):
    """Elementwise transform that raises a value to a fixed, non-zero power."""

    name = "power"

    def __init__(self, power=None):
        """Store the exponent of the transform.

        Parameters
        ----------
        power
            Exponent applied by `forward`. Must be a non-zero Python ``int``
            or ``float``.

        Raises
        ------
        TypeError
            If `power` is not an ``int`` or ``float`` (this includes the
            default ``None``).
        ValueError
            If `power` is 0.
        """
        if not isinstance(power, (int, float)):
            raise TypeError(f"Power must be integer or float, got {type(power)}")
        if power == 0:
            raise ValueError("Power cannot be 0")
        self.power = power
        super().__init__()

    def forward(self, value, *inputs):
        """Return `value` raised to ``self.power``."""
        return at.power(value, self.power)

    def backward(self, value, *inputs):
        """Return the pre-image(s) of `value` under the power transform.

        For even powers a tuple ``(-root, root)`` is returned because both
        roots map to the same `value`; otherwise a single root is returned.
        """
        backward_value = at.power(value, (1 / self.power))

        # Even powers are not 1-to-1, regardless of sign: x ** 2 and x ** -2
        # both send x and -x to the same value, so both roots are returned.
        if self.power % 2 == 0:
            return -backward_value, backward_value
        return backward_value

    def log_jac_det(self, value, *inputs):
        """Return the log absolute Jacobian determinant of `backward` at `value`."""
        inv_power = 1 / self.power
        # Derivative of value ** inv_power is inv_power * value ** (inv_power - 1)
        # Note: This fails for value==0
        # NOTE(review): `at.log(value)` is presumably not finite for negative
        # `value` either (e.g. odd powers applied to negative inputs) - confirm.
        return np.log(np.abs(inv_power)) + (inv_power - 1) * at.log(value)
615726

616727

617728
class IntervalTransform(RVTransform):

pymc/tests/logprob/test_transforms.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def test_transformed_logprob(at_dist, dist_params, sp_dist, size):
224224

225225
a = at_dist(*dist_params, size=size)
226226
a.name = "a"
227-
a_value_var = at.tensor(a.dtype, shape=(None,) * a.ndim)
227+
a_value_var = at.tensor(dtype=a.dtype, shape=(None,) * a.ndim)
228228
a_value_var.name = "a_value"
229229

230230
b = at.random.normal(a, 1.0)
@@ -807,6 +807,36 @@ def test_reciprocal_rv_transform(numerator):
807807
)
808808

809809

810+
def test_sqr_transform():
    """The square of a unit normal must have a chi-square(df=1) log-density."""
    squared_normal = at.random.normal(0, 1, size=(3,)) ** 2
    squared_normal.name = "x"

    value_var = squared_normal.clone()
    logp_graph = joint_logprob({squared_normal: value_var}, sum=False)
    compiled_logp = pytensor.function([value_var], logp_graph)

    points = np.r_[0.5, 1, 2.5]
    actual = compiled_logp(points)
    expected = sp.stats.chi2(df=1).logpdf(points)
    assert np.allclose(actual, expected)
823+
824+
825+
def test_sqrt_transform():
    """The sqrt of a chi-square with n df must have a chi(df=n) log-density."""
    root_chisquare = at.sqrt(at.random.chisquare(df=3, size=(3,)))
    root_chisquare.name = "x"

    value_var = root_chisquare.clone()
    logp_graph = joint_logprob({root_chisquare: value_var}, sum=False)
    compiled_logp = pytensor.function([value_var], logp_graph)

    points = np.r_[0.5, 1, 2.5]
    actual = compiled_logp(points)
    expected = sp.stats.chi(df=3).logpdf(points)
    assert np.allclose(actual, expected)
838+
839+
810840
def test_negated_rv_transform():
811841
x_rv = -at.random.halfnormal()
812842
x_rv.name = "x"

0 commit comments

Comments
 (0)