
Commit 77f333a

Fix Elemwise and Blockwise gradient for Ops with mixed discrete and continuous output types

Authored and committed by ricardoV94 and jessegrabowski
Parent: a149f6c
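
For context, a minimal, self-contained sketch of the case this commit fixes, adapted directly from the new test_gradient_mixed_discrete_output_scalar_op test added below: an Elemwise over a scalar op that returns one continuous (float64) and one discrete (int64) output. Before this change, L_op replaced every input gradient with zeros as soon as any output had a discrete dtype; with the change, the gradient computed by the core op's own L_op is propagated for the continuous output.

import numpy as np

from pytensor import grad
from pytensor.graph import Apply
from pytensor.scalar import ScalarOp, float64, int64
from pytensor.tensor import vector
from pytensor.tensor.elemwise import Elemwise


class MixedDtypeScalarOp(ScalarOp):
    # One continuous (float64) and one discrete (int64) output.
    def make_node(self, *inputs):
        return Apply(self, [float64()], [float64(), int64()])

    def perform(self, node, inputs, outputs):
        raise NotImplementedError()

    def L_op(self, inputs, outputs, output_gradients):
        # Gradient is only defined with respect to the continuous output.
        return [inputs[0].ones_like() * output_gradients[0]]


x = vector("x")
y, _ = Elemwise(MixedDtypeScalarOp())(x)
# After this commit the evaluated gradient is an array of ones;
# previously it was silently replaced by zeros.
print(grad(y.sum(), x).eval({x: np.full((12,), np.nan)}))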

File tree: 4 files changed (+53 lines, -45 lines)

pytensor/tensor/blockwise.py (5 additions, 22 deletions)

@@ -18,7 +18,7 @@
 from pytensor.scalar import ScalarType
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.shape import shape_padleft
-from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor
+from pytensor.tensor.type import TensorType, tensor
 from pytensor.tensor.utils import (
     _parse_gufunc_signature,
     broadcast_static_dim_lengths,
@@ -255,6 +255,10 @@ def as_core(t, core_t):
                 as_core(ograd, core_ograd)
                 for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
             ]
+            # FIXME: These core_outputs do not depend on core_inputs, not pretty
+            # It's not necessarily a problem because if they are referenced by the gradient,
+            # they get replaced later in vectorize. But if the Op was to make any decision
+            # by introspecting the dependencies of output on inputs it would fail badly!
             core_outputs = core_node.outputs

             core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
@@ -282,27 +286,6 @@ def L_op(self, inputs, outs, ograds):
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)

-        # TODO: (Borrowed from Elemwise) make sure that zeros are clearly identifiable
-        # to the gradient.grad method when the outputs have
-        # some integer and some floating point outputs
-        if any(out.type.dtype not in continuous_dtypes for out in outs):
-            # For integer output, return value may only be zero or undefined
-            # We don't bother with trying to check that the scalar ops
-            # correctly returned something that evaluates to 0, we just make
-            # the return value obviously zero so that gradient.grad can tell
-            # this op did the right thing.
-            new_rval = []
-            for elem, inp in zip(rval, inputs, strict=True):
-                if isinstance(elem.type, NullType | DisconnectedType):
-                    new_rval.append(elem)
-                else:
-                    elem = inp.zeros_like()
-                    if str(elem.type.dtype) not in continuous_dtypes:
-                        elem = elem.astype(config.floatX)
-                    assert str(elem.type.dtype) not in discrete_dtypes
-                    new_rval.append(elem)
-            return new_rval
-
         # Sum out the broadcasted dimensions
         batch_ndims = self.batch_ndim(outs[0].owner)
         batch_shape = outs[0].type.shape[:batch_ndims]
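
The guard removed from blockwise.py (above) and elemwise.py (below) keyed on any output being discrete, so a single integer output was enough to discard every input gradient, even when another output was continuous and had a perfectly valid gradient. A plain-Python illustration of that predicate follows; the dtype set here is illustrative, not pytensor's actual continuous_dtypes tuple.

# Illustrative only: not pytensor code, just the shape of the removed check.
continuous_dtypes = {"float16", "float32", "float64", "complex64", "complex128"}


def old_guard_fires(output_dtypes):
    # The removed block zeroed ALL input gradients whenever this returned True.
    return any(dtype not in continuous_dtypes for dtype in output_dtypes)


assert old_guard_fires(["float64", "int64"])  # mixed outputs: gradients were zeroed
assert not old_guard_fires(["float64"])       # all-continuous: gradients kept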

pytensor/tensor/elemwise.py (0 additions, 21 deletions)

@@ -515,27 +515,6 @@ def L_op(self, inputs, outs, ograds):
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)

-        # TODO: make sure that zeros are clearly identifiable
-        # to the gradient.grad method when the outputs have
-        # some integer and some floating point outputs
-        if any(out.type.dtype not in continuous_dtypes for out in outs):
-            # For integer output, return value may only be zero or undefined
-            # We don't bother with trying to check that the scalar ops
-            # correctly returned something that evaluates to 0, we just make
-            # the return value obviously zero so that gradient.grad can tell
-            # this op did the right thing.
-            new_rval = []
-            for elem, ipt in zip(rval, inputs, strict=True):
-                if isinstance(elem.type, NullType | DisconnectedType):
-                    new_rval.append(elem)
-                else:
-                    elem = ipt.zeros_like()
-                    if str(elem.type.dtype) not in continuous_dtypes:
-                        elem = elem.astype(config.floatX)
-                    assert str(elem.type.dtype) not in discrete_dtypes
-                    new_rval.append(elem)
-            return new_rval
-
         # sum out the broadcasted dimensions
         for i, ipt in enumerate(inputs):
             if isinstance(rval[i].type, NullType | DisconnectedType):

tests/tensor/test_blockwise.py (23 additions, 1 deletion)

@@ -12,7 +12,7 @@
 from pytensor.graph import Apply, Op
 from pytensor.graph.replace import vectorize_node
 from pytensor.raise_op import assert_op
-from pytensor.tensor import diagonal, log, tensor
+from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.nlinalg import MatrixInverse
 from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
@@ -596,3 +596,25 @@ def core_scipy_fn(A, b):
     # Confirm input was destroyed
     assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
     assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
+
+
+def test_gradient_mixed_discrete_output_core_op():
+    class MixedDtypeCoreOp(Op):
+        gufunc_signature = "()->(),()"
+        itypes = [scalar().type]
+        otypes = [scalar().type, scalar(dtype=int).type]
+
+        def perform(self, node, inputs, outputs):
+            raise NotImplementedError()
+
+        def L_op(self, inputs, outputs, output_gradients):
+            return [ones_like(inputs[0]) * output_gradients[0]]
+
+    op = Blockwise(MixedDtypeCoreOp())
+    x = vector("x")
+    y, _ = op(x)
+    np.testing.assert_array_equal(
+        grad(y.sum(), x).eval({x: np.full((12,), np.nan)}),
+        np.ones((12,)),
+        strict=True,
+    )

tests/tensor/test_elemwise.py (25 additions, 1 deletion)

@@ -11,7 +11,7 @@
 import pytensor.scalar as ps
 import pytensor.tensor as pt
 import tests.unittest_tools as utt
-from pytensor import In, Out
+from pytensor import grad, In, Out
 from pytensor.compile.function import function
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
@@ -21,6 +21,7 @@
 from pytensor.link.basic import PerformLinker
 from pytensor.link.c.basic import CLinker, OpWiseCLinker
 from pytensor.npy_2_compat import numpy_maxdims
+from pytensor.scalar import ScalarOp, float64, int64
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.basic import get_scalar_constant_value, second
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -1068,3 +1069,26 @@ def test_c_careduce_benchmark(axis, c_contiguous, benchmark):
     return careduce_benchmark_tester(
         axis, c_contiguous, mode="FAST_RUN", benchmark=benchmark
     )
+
+
+def test_gradient_mixed_discrete_output_scalar_op():
+    class MixedDtypeScalarOp(ScalarOp):
+        def make_node(self, *inputs):
+            inputs = [float64()]
+            outputs = [float64(), int64()]
+            return Apply(self, inputs, outputs)
+
+        def perform(self, node, inputs, outputs):
+            raise NotImplementedError()
+
+        def L_op(self, inputs, outputs, output_gradients):
+            return [inputs[0].ones_like() * output_gradients[0]]
+
+    op = Elemwise(MixedDtypeScalarOp())
+    x = vector("x")
+    y, _ = op(x)
+    np.testing.assert_array_equal(
+        grad(y.sum(), x).eval({x: np.full((12,), np.nan)}),
+        np.ones((12,)),
+        strict=True,
+    )
