1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import pytensor .tensor as pt
1415import pytensor .xtensor as ptx
1516
1617from pymc .logprob .transforms import Transform
1718
1819
19- class LogTransform (Transform ):
20+ class DimTransform (Transform ):
21+ """Base class for transforms that are applied to dim distriubtions."""
22+
23+
24+ class LogTransform (DimTransform ):
2025 name = "log"
2126
2227 def forward (self , value , * inputs ):
@@ -32,7 +37,7 @@ def log_jac_det(self, value, *inputs):
3237log_transform = LogTransform ()
3338
3439
35- class LogOddsTransform (Transform ):
40+ class LogOddsTransform (DimTransform ):
3641 name = "logodds"
3742
3843 def backward (self , value , * inputs ):
@@ -47,3 +52,44 @@ def log_jac_det(self, value, *inputs):
4752
4853
4954log_odds_transform = LogOddsTransform ()
55+
56+
57+ class ZeroSumTransform (DimTransform ):
58+ name = "zerosum"
59+
60+ def __init__ (self , dims : tuple [str , ...]):
61+ self .dims = dims
62+
63+ @staticmethod
64+ def extend_dim (array , dim ):
65+ n = (array .sizes [dim ] + 1 ).astype ("floatX" )
66+ sum_vals = array .sum (dim )
67+ norm = sum_vals / (pt .sqrt (n ) + n )
68+ fill_val = norm - sum_vals / pt .sqrt (n )
69+
70+ out = ptx .concat ([array , fill_val ], dim = dim )
71+ return out - norm
72+
73+ @staticmethod
74+ def reduct_dim (array , dim ):
75+ n = array .sizes [dim ].astype ("floatX" )
76+ last = array .isel ({dim : - 1 })
77+
78+ sum_vals = - last * pt .sqrt (n )
79+ norm = sum_vals / (pt .sqrt (n ) + n )
80+ return array .isel ({dim : slice (None , - 1 )}) + norm
81+
82+ def forward (self , value , * rv_inputs ):
83+ for dim in self .dims :
84+ value = self .reduct_dim (value , dim = dim )
85+ return value
86+
87+ def backward (self , value , * rv_inputs ):
88+ for dim in self .dims :
89+ value = self .extend_dim (value , dim = dim )
90+ return value
91+
92+ def log_jac_det (self , value , * rv_inputs ):
93+ # Use following once broadcast_like is implemented
94+ # as_xtensor(0).broadcast_like(value, exclude=self.dims)`
95+ return (value * 0 ).sum (self .dims )
0 commit comments