Skip to content

Commit 2953c8b

Browse files
ricardoV94 and Luke LB
committed
Generalize handling of invalid values in measurable transforms
We rely on the jacobian returning `nan` for invalid values.

Co-authored-by: Luke LB <[email protected]>
1 parent 029d548 commit 2953c8b

File tree

2 files changed

+72
-14
lines changed

2 files changed

+72
-14
lines changed

pymc/logprob/transforms.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,14 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
374374
else:
375375
input_logprob = logprob(measurable_input, backward_value)
376376

377+
if input_logprob.ndim < value.ndim:
378+
# Do we just need to sum the jacobian terms across the support dims?
379+
raise NotImplementedError("Transform of multivariate RVs not implemented")
380+
377381
jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
378382

379-
return input_logprob + jacobian
383+
# The jacobian is used to ensure a value in the supported domain was provided
384+
return at.switch(at.isnan(jacobian), -np.inf, input_logprob + jacobian)
380385

381386

382387
@node_rewriter([reciprocal])
@@ -711,18 +716,32 @@ def forward(self, value, *inputs):
711716
at.power(value, self.power)
712717

713718
def backward(self, value, *inputs):
    """Invert the power transform, yielding ``nan`` where no real preimage exists.

    For even (and fractional) powers, negative values have no real root, so
    ``nan`` is returned there; for even powers the transform is not 1-to-1 and
    both roots are returned as a tuple.
    """
    root_exponent = 1 / self.power

    if (np.abs(self.power) < 1) or (self.power % 2 == 0):
        # These powers do not admit negative values: flag them with nan
        inverted = at.switch(value >= 0, at.power(value, root_exponent), np.nan)
    else:
        # Odd integer powers admit negative values, but PyTensor returns `nan`
        # for e.g. (-1)**(1/3); take the root of |value| and restore the sign.
        sign = at.switch(value >= 0, 1, -1)
        inverted = at.power(at.abs(value), root_exponent) * sign

    if self.power % 2 == 0:
        # Even powers are not 1-to-1: return both candidate roots
        return -inverted, inverted
    return inverted
721733

722734
def log_jac_det(self, value, *inputs):
    """Log absolute Jacobian determinant of the inverse power transform.

    Note: this expression fails for value == 0.
    Values outside the transform's image (negatives under even or fractional
    powers) evaluate to ``nan``, which downstream logic maps to an invalid logp.
    """
    root_exponent = 1 / self.power
    log_det = np.log(np.abs(root_exponent)) + (root_exponent - 1) * at.log(at.abs(value))

    if (np.abs(self.power) < 1) or (self.power % 2 == 0):
        # These powers do not admit negative values
        log_det = at.switch(value >= 0, log_det, np.nan)

    return log_det
726745

727746

728747
class IntervalTransform(RVTransform):

pymc/tests/logprob/test_transforms.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -632,15 +632,15 @@ def test_chained_transform():
632632

633633

634634
def test_exp_transform_rv():
635-
base_rv = at.random.normal(0, 1, size=2, name="base_rv")
635+
base_rv = at.random.normal(0, 1, size=3, name="base_rv")
636636
y_rv = at.exp(base_rv)
637637
y_rv.name = "y"
638638

639639
y_vv = y_rv.clone()
640640
logp = joint_logprob({y_rv: y_vv}, sum=False)
641641
logp_fn = pytensor.function([y_vv], logp)
642642

643-
y_val = [0.1, 0.3]
643+
y_val = [-2.0, 0.1, 0.3]
644644
np.testing.assert_allclose(
645645
logp_fn(y_val),
646646
sp.stats.lognorm(s=1).logpdf(y_val),
@@ -794,28 +794,28 @@ def test_invalid_broadcasted_transform_rv_fails():
794794
def test_reciprocal_rv_transform(numerator):
    # The reciprocal of a gamma RV follows an inverse-gamma distribution;
    # a negative test value should still evaluate (to -inf) rather than error.
    gamma_shape, gamma_scale = 3, 5
    reciprocal_rv = numerator / at.random.gamma(gamma_shape, gamma_scale, size=(2,))
    reciprocal_rv.name = "x"

    value_var = reciprocal_rv.clone()
    logp_graph = joint_logprob({reciprocal_rv: value_var}, sum=False)
    logp_fn = pytensor.function([value_var], logp_graph)

    test_values = np.r_[-0.5, 1.5]
    expected = sp.stats.invgamma(gamma_shape, scale=gamma_scale * numerator).logpdf(test_values)
    assert np.allclose(logp_fn(test_values), expected)
808808

809809

810810
def test_sqr_transform():
811811
# The square of a unit normal is a chi-square with 1 df
812-
x_rv = at.random.normal(0, 1, size=(3,)) ** 2
812+
x_rv = at.random.normal(0, 1, size=(4,)) ** 2
813813
x_rv.name = "x"
814814

815815
x_vv = x_rv.clone()
816816
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))
817817

818-
x_test_val = np.r_[0.5, 1, 2.5]
818+
x_test_val = np.r_[-0.5, 0.5, 1, 2.5]
819819
assert np.allclose(
820820
x_logp_fn(x_test_val),
821821
sp.stats.chi2(df=1).logpdf(x_test_val),
@@ -824,19 +824,58 @@ def test_sqr_transform():
824824

825825
def test_sqrt_transform():
    # The sqrt of a chisquare with n df is a chi distribution with n df;
    # the negative test value should evaluate (to -inf) rather than error.
    chi2_rv = at.random.chisquare(df=3, size=(4,))
    sqrt_rv = at.sqrt(chi2_rv)
    sqrt_rv.name = "x"

    value_var = sqrt_rv.clone()
    logp_fn = pytensor.function([value_var], joint_logprob({sqrt_rv: value_var}, sum=False))

    test_values = np.r_[-2.5, 0.5, 1, 2.5]
    expected = sp.stats.chi(df=3).logpdf(test_values)
    assert np.allclose(logp_fn(test_values), expected)
838838

839839

840+
@pytest.mark.parametrize("power", (-3, -1, 1, 5, 7))
def test_negative_value_odd_power_transform(power):
    # Odd powers admit negative values, so both signs should give a finite logp
    power_rv = at.random.normal() ** power
    power_rv.name = "x"

    value_var = power_rv.clone()
    logp_fn = pytensor.function([value_var], joint_logprob({power_rv: value_var}, sum=False))

    for test_value in (1, -1):
        assert np.isfinite(logp_fn(test_value))
852+
853+
@pytest.mark.parametrize("power", (-2, 2, 4, 6, 8))
def test_negative_value_even_power_transform(power):
    # check that negative values and even powers evaluate to -inf logp
    # (the original comment said "odd", but this is the even-power case:
    # x**power cannot be negative, so logp(-1) must be -inf)
    x_rv = at.random.normal() ** power
    x_rv.name = "x"

    x_vv = x_rv.clone()
    x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))

    assert np.isfinite(x_logp_fn(1))
    assert np.isneginf(x_logp_fn(-1))
865+
866+
@pytest.mark.parametrize("power", (-1 / 3, -1 / 2, 1 / 2, 1 / 3))
def test_negative_value_frac_power_transform(power):
    # Fractional powers have no real result for negative bases,
    # so negative values must evaluate to -inf logp
    power_rv = at.random.normal() ** power
    power_rv.name = "x"

    value_var = power_rv.clone()
    logp_fn = pytensor.function([value_var], joint_logprob({power_rv: value_var}, sum=False))

    assert np.isfinite(logp_fn(2.5))
    assert np.isneginf(logp_fn(-2.5))
878+
840879
def test_negated_rv_transform():
841880
x_rv = -at.random.halfnormal()
842881
x_rv.name = "x"

0 commit comments

Comments
 (0)