Commit 777cdd3

WIP: Subclass DimTransform to avoid bugs
1 parent 56402aa commit 777cdd3

File tree

1 file changed (+48, -2 lines)


pymc/dims/transforms.py

Lines changed: 48 additions & 2 deletions
@@ -11,12 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytensor.tensor as pt
 import pytensor.xtensor as ptx
 
 from pymc.logprob.transforms import Transform
 
 
-class LogTransform(Transform):
+class DimTransform(Transform):
+    """Base class for transforms that are applied to dim distributions."""
+
+
+class LogTransform(DimTransform):
     name = "log"
 
     def forward(self, value, *inputs):
@@ -32,7 +37,7 @@ def log_jac_det(self, value, *inputs):
 log_transform = LogTransform()
 
 
-class LogOddsTransform(Transform):
+class LogOddsTransform(DimTransform):
     name = "logodds"
 
     def backward(self, value, *inputs):
@@ -47,3 +52,44 @@ def log_jac_det(self, value, *inputs):
 
 
 log_odds_transform = LogOddsTransform()
+
+
+class ZeroSumTransform(DimTransform):
+    name = "zerosum"
+
+    def __init__(self, dims: tuple[str, ...]):
+        self.dims = dims
+
+    @staticmethod
+    def extend_dim(array, dim):
+        n = (array.sizes[dim] + 1).astype("floatX")
+        sum_vals = array.sum(dim)
+        norm = sum_vals / (pt.sqrt(n) + n)
+        fill_val = norm - sum_vals / pt.sqrt(n)
+
+        out = ptx.concat([array, fill_val], dim=dim)
+        return out - norm
+
+    @staticmethod
+    def reduce_dim(array, dim):
+        n = array.sizes[dim].astype("floatX")
+        last = array.isel({dim: -1})
+
+        sum_vals = -last * pt.sqrt(n)
+        norm = sum_vals / (pt.sqrt(n) + n)
+        return array.isel({dim: slice(None, -1)}) + norm
+
+    def forward(self, value, *rv_inputs):
+        for dim in self.dims:
+            value = self.reduce_dim(value, dim=dim)
+        return value
+
+    def backward(self, value, *rv_inputs):
+        for dim in self.dims:
+            value = self.extend_dim(value, dim=dim)
+        return value
+
+    def log_jac_det(self, value, *rv_inputs):
+        # Use the following once broadcast_like is implemented:
+        # as_xtensor(0).broadcast_like(value, exclude=self.dims)
+        return (value * 0).sum(self.dims)
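
Why log_jac_det is just zeros: extend_dim is a linear map whose columns are orthonormal (a Householder-style construction), so it preserves volume and the log-Jacobian-determinant vanishes. A minimal NumPy sketch checking that arithmetic on plain arrays; extend_dim/reduce_dim here are illustrative re-implementations over the last axis, not the commit's API:

import numpy as np

def extend_dim(x):
    # Map n-1 free values to an n-vector constrained to sum to zero.
    n = x.shape[-1] + 1
    sum_vals = x.sum(-1, keepdims=True)
    norm = sum_vals / (np.sqrt(n) + n)
    fill_val = norm - sum_vals / np.sqrt(n)
    return np.concatenate([x, fill_val], axis=-1) - norm

def reduce_dim(x):
    # Invert extend_dim: drop the last entry and undo the shift.
    n = x.shape[-1]
    last = x[..., -1:]
    sum_vals = -last * np.sqrt(n)
    norm = sum_vals / (np.sqrt(n) + n)
    return x[..., :-1] + norm

raw = np.random.default_rng(0).normal(size=4)
zs = extend_dim(raw)
assert np.isclose(zs.sum(), 0.0)         # constrained vector sums to zero
assert np.allclose(reduce_dim(zs), raw)  # reduce_dim inverts extend_dim

# Columns of the Jacobian of extend_dim are its images of the basis vectors;
# J.T @ J == I confirms an isometry, hence log|det| == 0.
J = np.stack([extend_dim(e) for e in np.eye(4)], axis=1)
assert np.allclose(J.T @ J, np.eye(4))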

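For context, a hypothetical end-to-end sketch of the new transform (this assumes pytensor.xtensor.as_xtensor accepts a dims argument; none of it is part of the commit):

import numpy as np
import pytensor.xtensor as ptx
from pymc.dims.transforms import ZeroSumTransform

transform = ZeroSumTransform(dims=("group",))

# Four unconstrained values map to five values that sum to zero over "group".
free = ptx.as_xtensor(np.arange(4, dtype="float64"), dims=("group",))
constrained = transform.backward(free)
print(constrained.eval().sum())               # ~0.0
print(transform.forward(constrained).eval())  # recovers the values in `free`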