File tree Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Original file line number Diff line number Diff line change @@ -300,7 +300,7 @@ def __init__(self, zerosum_axes):
300300
301301 @staticmethod
302302 def extend_axis (array , axis ):
303- n = (array .shape [axis ] + 1 ). astype ( "floatX" )
303+ n = pt . cast (array .shape [axis ] + 1 , "floatX" )
304304 sum_vals = array .sum (axis , keepdims = True )
305305 norm = sum_vals / (pt .sqrt (n ) + n )
306306 fill_val = norm - sum_vals / pt .sqrt (n )
@@ -312,7 +312,7 @@ def extend_axis(array, axis):
312312 def extend_axis_rev (array , axis ):
313313 normalized_axis = normalize_axis_tuple (axis , array .ndim )[0 ]
314314
315- n = array .shape [normalized_axis ]. astype ( "floatX" )
315+ n = pt . cast ( array .shape [normalized_axis ], "floatX" )
316316 last = pt .take (array , [- 1 ], axis = normalized_axis )
317317
318318 sum_vals = - last * pt .sqrt (n )
Original file line number Diff line number Diff line change @@ -170,6 +170,17 @@ def test_sum_to_1():
170170 )
171171
172172
173+ def test_zerosumtransform ():
174+ zst = tr .ZeroSumTransform ([0 ])
175+
176+ # Check numpy input works, as it is not always converted to pytensor before
177+ # Case where it failed was when setting initvals in model
178+ val = np .array ([1 , 2 , 3 , 4 ])
179+ zval = zst .backward (val )
180+ assert np .allclose (zval .eval ().sum (), 0.0 )
181+ assert np .allclose (zst .forward (zval ).eval (), val )
182+
183+
173184def test_log ():
174185 check_transform (tr .log , Rplusbig )
175186
You can’t perform that action at this time.
0 commit comments