diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 8b5eac7b16..018d94940e 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -35,7 +35,7 @@ # SOFTWARE. import abc -from collections.abc import Callable +from collections.abc import Callable, Sequence import numpy as np import pytensor.tensor as pt @@ -141,6 +141,10 @@ def backward( Multiple values may be returned when the transformation is not 1-to-1. """ + @abc.abstractmethod + def transform_coords(self, coords: Sequence[str]) -> Sequence[str]: + """Mutate user-provided coordinates associated with the variable to label transformed values returned by this class.""" + def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable: """Construct the log of the absolute value of the Jacobian determinant.""" if self.ndim_supp not in (0, 1): @@ -614,6 +618,9 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.arcsinh(value) + def transform_coords(self, coords): + return coords + class CoshTransform(Transform): name = "cosh" @@ -633,6 +640,9 @@ def log_jac_det(self, value, *inputs): -pt.log(pt.sqrt(value**2 - 1)), ) + def transform_coords(self, coords): + return coords + class TanhTransform(Transform): name = "tanh" @@ -644,6 +654,9 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.arctanh(value) + def transform_coords(self, coords): + return coords + class ArcsinhTransform(Transform): name = "arcsinh" @@ -655,6 +668,9 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.sinh(value) + def transform_coords(self, coords): + return coords + class ArccoshTransform(Transform): name = "arccosh" @@ -666,6 +682,9 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.cosh(value) + def transform_coords(self, coords): + return coords + class ArctanhTransform(Transform): name = "arctanh" @@ -677,6 +696,9 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.tanh(value) + def transform_coords(self, coords): + return coords + class ErfTransform(Transform): name = "erf" @@ -688,6 +710,9 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.erfinv(value) + def transform_coords(self, coords): + return coords + class ErfcTransform(Transform): name = "erfc" @@ -699,6 +724,9 @@ def forward(self, value, *inputs): def backward(self, value, *inputs): return pt.erfcinv(value) + def transform_coords(self, coords): + return coords + class ErfcxTransform(Transform): name = "erfcx" @@ -725,6 +753,9 @@ def calc_delta_x(value, prior_result): ) return result[-1] + def transform_coords(self, coords): + return coords + class LocTransform(Transform): name = "loc" @@ -743,6 +774,9 @@ def backward(self, value, *inputs): def log_jac_det(self, value, *inputs): return pt.zeros_like(value) + def transform_coords(self, coords): + return coords + class ScaleTransform(Transform): name = "scale" @@ -762,6 +796,9 @@ def log_jac_det(self, value, *inputs): scale = self.transform_args_fn(*inputs) return -pt.log(pt.abs(pt.broadcast_to(scale, value.shape))) + def transform_coords(self, coords): + return coords + class LogTransform(Transform): name = "log" @@ -775,6 +812,9 @@ def backward(self, value, *inputs): def log_jac_det(self, value, *inputs): return value + def transform_coords(self, coords): + return coords + class ExpTransform(Transform): name = "exp" @@ -788,6 +828,9 @@ def backward(self, value, *inputs): def log_jac_det(self, value, *inputs): return -pt.log(value) + def transform_coords(self, coords): + return coords + class AbsTransform(Transform): name = "abs" @@ -802,6 +845,9 @@ def backward(self, value, *inputs): def log_jac_det(self, value, *inputs): return pt.switch(value >= 0, 0, np.nan) + def transform_coords(self, coords): + return coords + class PowerTransform(Transform): name = "power" @@ -845,6 +891,9 @@ def log_jac_det(self, value, *inputs): return res + def transform_coords(self, coords): + return coords + class IntervalTransform(Transform): name = "interval" @@ -953,6 +1002,9 @@ def log_jac_det(self, value, *inputs): else: return pt.zeros_like(value) + def transform_coords(self, coords): + return coords + class LogOddsTransform(Transform): name = "logodds" @@ -967,6 +1019,9 @@ def log_jac_det(self, value, *inputs): sigmoid_value = pt.sigmoid(value) return pt.log(sigmoid_value) + pt.log1p(-sigmoid_value) + def transform_coords(self, coords): + return coords + class SimplexTransform(Transform): name = "simplex" @@ -994,6 +1049,11 @@ def log_jac_det(self, value, *inputs): res = pt.log(N) + (N * sum_value) - (N * logsumexp_value_expanded) return pt.sum(res, -1) + def transform_coords(self, coords): + if len(coords) == 0: + return coords + return coords[:-1] + class CircularTransform(Transform): name = "circular" @@ -1007,6 +1067,9 @@ def forward(self, value, *inputs): def log_jac_det(self, value, *inputs): return pt.zeros_like(value) + def transform_coords(self, coords): + return coords + class ChainedTransform(Transform): name = "chain" @@ -1042,3 +1105,8 @@ def log_jac_det(self, value, *inputs): else: det += det_ return det + + def transform_coords(self, coords): + for transform in self.transform_list: + coords = transform.transform_coords(coords) + return coords