
Commit e58618f

ricardoV94 authored and jessegrabowski committed
Fix Elemwise and Blockwise gradient for Ops with mixed discrete and continuous output types
1 parent 676296c commit e58618f
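
Before this change, Elemwise.L_op and Blockwise.L_op replaced the gradient of every input with zeros whenever any output had a non-continuous dtype, which silently discarded legitimate gradients for Ops that produce both discrete and continuous outputs. The following minimal sketch reproduces that failure mode; it is adapted from the test added in this commit, and MixedDtypeCoreOp is that test's dummy Op, not part of the public PyTensor API.

import numpy as np

from pytensor import config, grad
from pytensor.graph import Op
from pytensor.tensor import ones_like, scalar, vector
from pytensor.tensor.blockwise import Blockwise


class MixedDtypeCoreOp(Op):
    # One continuous (float) output and one discrete (int) output per scalar input
    gufunc_signature = "()->(),()"
    itypes = [scalar().type]
    otypes = [scalar().type, scalar(dtype=int).type]

    def perform(self, node, inputs, outputs):
        # Never executed here: only the gradient graph is evaluated
        raise NotImplementedError()

    def L_op(self, inputs, outputs, output_gradients):
        # The gradient flows only through the continuous output
        return [ones_like(inputs[0]) * output_gradients[0]]


x = vector("x")
y, _ = Blockwise(MixedDtypeCoreOp())(x)

# With this fix the result is ones(3); previously the integer output caused
# Blockwise.L_op to replace it with zeros. The NaN inputs show the gradient
# does not depend on the Op's computed values.
print(grad(y.sum(), x).eval({x: np.full(3, np.nan, dtype=config.floatX)}))
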

File tree

3 files changed (+32, -44 lines)

pytensor/tensor/blockwise.py

Lines changed: 5 additions & 22 deletions
@@ -18,7 +18,7 @@
 from pytensor.scalar import ScalarType
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.shape import shape_padleft
-from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor
+from pytensor.tensor.type import TensorType, tensor
 from pytensor.tensor.utils import (
     _parse_gufunc_signature,
     broadcast_static_dim_lengths,
@@ -256,6 +256,10 @@ def as_core(t, core_t):
                 as_core(ograd, core_ograd)
                 for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
             ]
+            # FIXME: These core_outputs do not depend on core_inputs, not pretty
+            # It's not necessarily a problem because if they are referenced by the gradient,
+            # they get replaced later in vectorize. But if the Op was to make any decision
+            # by introspecting the dependencies of output on inputs it would fail badly!
             core_outputs = core_node.outputs

             core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
@@ -283,27 +287,6 @@ def L_op(self, inputs, outs, ograds):
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)

-        # TODO: (Borrowed from Elemwise) make sure that zeros are clearly identifiable
-        # to the gradient.grad method when the outputs have
-        # some integer and some floating point outputs
-        if any(out.type.dtype not in continuous_dtypes for out in outs):
-            # For integer output, return value may only be zero or undefined
-            # We don't bother with trying to check that the scalar ops
-            # correctly returned something that evaluates to 0, we just make
-            # the return value obviously zero so that gradient.grad can tell
-            # this op did the right thing.
-            new_rval = []
-            for elem, inp in zip(rval, inputs, strict=True):
-                if isinstance(elem.type, NullType | DisconnectedType):
-                    new_rval.append(elem)
-                else:
-                    elem = inp.zeros_like()
-                    if str(elem.type.dtype) not in continuous_dtypes:
-                        elem = elem.astype(config.floatX)
-                    assert str(elem.type.dtype) not in discrete_dtypes
-                    new_rval.append(elem)
-            return new_rval
-
         # Sum out the broadcasted dimensions
         batch_ndims = self.batch_ndim(outs[0].owner)
         batch_shape = outs[0].type.shape[:batch_ndims]
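
For reference, the branch removed here (and below, from Elemwise.L_op) amounts to the following standalone paraphrase: as soon as any output dtype fell outside continuous_dtypes, the gradients computed by _bgrad were discarded and every connected input gradient was forced to an explicit floatX zero. That is only appropriate when all outputs are discrete; with mixed discrete and continuous outputs it zeroed a legitimate gradient, which is the bug this commit fixes. The helper name old_mixed_output_override is invented for illustration and does not exist in PyTensor.

from pytensor import config
from pytensor.gradient import DisconnectedType, NullType
from pytensor.tensor.type import continuous_dtypes


def old_mixed_output_override(inputs, outs, rval):
    # Paraphrase of the removed logic: discard rval if any output is discrete
    if any(out.type.dtype not in continuous_dtypes for out in outs):
        new_rval = []
        for elem, inp in zip(rval, inputs, strict=True):
            if isinstance(elem.type, NullType | DisconnectedType):
                # Null/disconnected gradients are passed through unchanged
                new_rval.append(elem)
            else:
                # Every other input gradient is replaced by zeros of floatX dtype
                elem = inp.zeros_like()
                if str(elem.type.dtype) not in continuous_dtypes:
                    elem = elem.astype(config.floatX)
                new_rval.append(elem)
        return new_rval
    return rval
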

pytensor/tensor/elemwise.py

Lines changed: 0 additions & 21 deletions
@@ -515,27 +515,6 @@ def L_op(self, inputs, outs, ograds):
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)

-        # TODO: make sure that zeros are clearly identifiable
-        # to the gradient.grad method when the outputs have
-        # some integer and some floating point outputs
-        if any(out.type.dtype not in continuous_dtypes for out in outs):
-            # For integer output, return value may only be zero or undefined
-            # We don't bother with trying to check that the scalar ops
-            # correctly returned something that evaluates to 0, we just make
-            # the return value obviously zero so that gradient.grad can tell
-            # this op did the right thing.
-            new_rval = []
-            for elem, ipt in zip(rval, inputs, strict=True):
-                if isinstance(elem.type, NullType | DisconnectedType):
-                    new_rval.append(elem)
-                else:
-                    elem = ipt.zeros_like()
-                    if str(elem.type.dtype) not in continuous_dtypes:
-                        elem = elem.astype(config.floatX)
-                    assert str(elem.type.dtype) not in discrete_dtypes
-                    new_rval.append(elem)
-            return new_rval
-
         # sum out the broadcasted dimensions
         for i, ipt in enumerate(inputs):
             if isinstance(rval[i].type, NullType | DisconnectedType):

tests/tensor/test_blockwise.py

Lines changed: 27 additions & 1 deletion
@@ -12,7 +12,7 @@
 from pytensor.graph import Apply, Op
 from pytensor.graph.replace import vectorize_node
 from pytensor.raise_op import assert_op
-from pytensor.tensor import diagonal, log, tensor
+from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.nlinalg import MatrixInverse
 from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
@@ -28,6 +28,9 @@
 from pytensor.tensor.utils import _parse_gufunc_signature


+config.floatX = "float32"
+
+
 def test_perform_method_per_node():
     """Confirm that Blockwise uses one perform method per node.

@@ -603,3 +606,26 @@ def core_scipy_fn(A, b):
     # Confirm input was destroyed
     assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
     assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
+
+
+def test_gradient_mixed_discrete_output_core_op():
+    class MixedDtypeCoreOp(Op):
+        gufunc_signature = "()->(),()"
+        itypes = [scalar().type]
+        otypes = [scalar().type, scalar(dtype=int).type]
+
+        def perform(self, node, inputs, outputs):
+            raise NotImplementedError()
+
+        def L_op(self, inputs, outputs, output_gradients):
+            return [ones_like(inputs[0]) * output_gradients[0]]
+
+    op = Blockwise(MixedDtypeCoreOp())
+    x = vector("x")
+    y, _ = op(x)
+
+    np.testing.assert_array_equal(
+        grad(y.sum(), x).eval({x: np.full(12, np.nan, dtype=config.floatX)}),
+        np.ones(12, dtype=config.floatX),
+        strict=True,
+    )
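
The new test builds only the gradient graph: perform raises NotImplementedError and is never executed, and the NaN input values confirm that the resulting gradient is the constant ones_like(x) returned by the core Op's L_op rather than the zeros the removed branch would have produced. Assuming a standard pytest setup for the repository, it can be run with, for example:

python -m pytest tests/tensor/test_blockwise.py::test_gradient_mixed_discrete_output_core_op
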
