Commit afe908c

Fix Elemwise and Blockwise gradient for Ops with mixed discrete and continuous output types
1 parent: 51ea1a0 · commit: afe908c
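For context (an illustration added here, not part of the commit): before this change, Elemwise.L_op and Blockwise.L_op replaced every input gradient with zeros as soon as *any* output had a discrete dtype, even when a continuous output carried a perfectly good gradient. A minimal NumPy-level sketch of the before/after behaviour, using a made-up toy function:

import numpy as np

# Toy op with one continuous and one discrete output: f(x) = (2.0 * x, floor(x))
x = np.array([0.3, 1.7, 2.5])
d_continuous = np.full_like(x, 2.0)  # d(sum(2.0 * x)) / dx

# Old behaviour (the block removed below): any discrete output caused the
# whole input gradient to be replaced by zeros_like(x) cast to floatX.
grad_before = np.zeros_like(x)

# New behaviour: the core/scalar op's own L_op is trusted, so the gradient
# contributed by the continuous output flows through unchanged.
grad_after = d_continuous

print(grad_before, grad_after)  # [0. 0. 0.] [2. 2. 2.]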

4 files changed: +53 −44 lines changed


pytensor/tensor/blockwise.py

Lines changed: 5 additions & 22 deletions
@@ -18,7 +18,7 @@
 from pytensor.scalar import ScalarType
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.shape import shape_padleft
-from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor
+from pytensor.tensor.type import TensorType, tensor
 from pytensor.tensor.utils import (
     _parse_gufunc_signature,
     broadcast_static_dim_lengths,
@@ -255,6 +255,10 @@ def as_core(t, core_t):
             as_core(ograd, core_ograd)
             for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
         ]
+        # FIXME: These core_outputs do not depend on core_inputs, not pretty
+        # It's not necessarily a problem because if they are referenced by the gradient,
+        # they get replaced later in vectorize. But if the Op were to make any decision
+        # by introspecting the dependencies of output on inputs it would fail badly!
         core_outputs = core_node.outputs

         core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
@@ -282,27 +286,6 @@ def L_op(self, inputs, outs, ograds):
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)

-        # TODO: (Borrowed from Elemwise) make sure that zeros are clearly identifiable
-        # to the gradient.grad method when the outputs have
-        # some integer and some floating point outputs
-        if any(out.type.dtype not in continuous_dtypes for out in outs):
-            # For integer output, return value may only be zero or undefined
-            # We don't bother with trying to check that the scalar ops
-            # correctly returned something that evaluates to 0, we just make
-            # the return value obviously zero so that gradient.grad can tell
-            # this op did the right thing.
-            new_rval = []
-            for elem, inp in zip(rval, inputs, strict=True):
-                if isinstance(elem.type, NullType | DisconnectedType):
-                    new_rval.append(elem)
-                else:
-                    elem = inp.zeros_like()
-                    if str(elem.type.dtype) not in continuous_dtypes:
-                        elem = elem.astype(config.floatX)
-                    assert str(elem.type.dtype) not in discrete_dtypes
-                    new_rval.append(elem)
-            return new_rval
-
         # Sum out the broadcasted dimensions
         batch_ndims = self.batch_ndim(outs[0].owner)
         batch_shape = outs[0].type.shape[:batch_ndims]
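With the blanket zeroing removed from Blockwise.L_op, a core op whose outputs mix discrete and continuous dtypes is expected to deal with that in its own L_op, returning one gradient per input. A runnable sketch of such a core op (FloorAndFrac is hypothetical, modeled on the MixedDtypeCoreOp added in the new test below):

import numpy as np
from pytensor import grad
from pytensor.graph import Op
from pytensor.tensor import ones_like, scalar, vector
from pytensor.tensor.blockwise import Blockwise


class FloorAndFrac(Op):
    # Hypothetical core op: x -> (x - floor(x), floor(x)), i.e. one float
    # output and one integer output.
    gufunc_signature = "()->(),()"
    itypes = [scalar(dtype="float64").type]
    otypes = [scalar(dtype="float64").type, scalar(dtype="int64").type]

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        outputs[0][0] = np.asarray(x - np.floor(x), dtype="float64")
        outputs[1][0] = np.asarray(np.floor(x), dtype="int64")

    def L_op(self, inputs, outputs, output_gradients):
        # One gradient per *input*; only the continuous (first) output
        # contributes, the discrete output is simply not differentiated.
        return [ones_like(inputs[0]) * output_gradients[0]]


x = vector("x", dtype="float64")
frac, _ = Blockwise(FloorAndFrac())(x)
print(grad(frac.sum(), x).eval({x: np.array([0.3, 1.7, 2.5])}))  # [1. 1. 1.]

The Blockwise wrapper now just vectorizes whatever the core L_op returns, instead of second-guessing it whenever a discrete output is present.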

pytensor/tensor/elemwise.py

Lines changed: 0 additions & 21 deletions
@@ -520,27 +520,6 @@ def L_op(self, inputs, outs, ograds):
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)

-        # TODO: make sure that zeros are clearly identifiable
-        # to the gradient.grad method when the outputs have
-        # some integer and some floating point outputs
-        if any(out.type.dtype not in continuous_dtypes for out in outs):
-            # For integer output, return value may only be zero or undefined
-            # We don't bother with trying to check that the scalar ops
-            # correctly returned something that evaluates to 0, we just make
-            # the return value obviously zero so that gradient.grad can tell
-            # this op did the right thing.
-            new_rval = []
-            for elem, ipt in zip(rval, inputs, strict=True):
-                if isinstance(elem.type, NullType | DisconnectedType):
-                    new_rval.append(elem)
-                else:
-                    elem = ipt.zeros_like()
-                    if str(elem.type.dtype) not in continuous_dtypes:
-                        elem = elem.astype(config.floatX)
-                    assert str(elem.type.dtype) not in discrete_dtypes
-                    new_rval.append(elem)
-            return new_rval
-
         # sum out the broadcasted dimensions
         for i, ipt in enumerate(inputs):
             if isinstance(rval[i].type, NullType | DisconnectedType):

tests/tensor/test_blockwise.py

Lines changed: 23 additions & 1 deletion
@@ -12,7 +12,7 @@
 from pytensor.graph import Apply, Op
 from pytensor.graph.replace import vectorize_node
 from pytensor.raise_op import assert_op
-from pytensor.tensor import diagonal, log, tensor
+from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.nlinalg import MatrixInverse
 from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
@@ -596,3 +596,25 @@ def core_scipy_fn(A, b):
     # Confirm input was destroyed
     assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
     assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
+
+
+def test_gradient_mixed_discrete_output_core_op():
+    class MixedDtypeCoreOp(Op):
+        gufunc_signature = "()->(),()"
+        itypes = [scalar().type]
+        otypes = [scalar().type, scalar(dtype=int).type]
+
+        def perform(self, node, inputs, outputs):
+            raise NotImplementedError()
+
+        def L_op(self, inputs, outputs, output_gradients):
+            return [ones_like(inputs[0]) * output_gradients[0]]
+
+    op = Blockwise(MixedDtypeCoreOp())
+    x = vector("x")
+    y, _ = op(x)
+    np.testing.assert_array_equal(
+        grad(y.sum(), x).eval({x: np.full((12,), np.nan)}),
+        np.ones((12,)),
+        strict=True,
+    )
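A note on the test design (my reading, not stated in the commit): perform raises NotImplementedError and the input is all-NaN, so the assertion can only succeed if the compiled gradient graph depends on x alone and never has to evaluate the op's forward outputs. The same property in a tiny standalone example with an ordinary graph:

import numpy as np
from pytensor import grad
from pytensor.tensor import vector

x = vector("x", dtype="float64")
y = 3.0 * x                    # d(sum(3 * x)) / dx == 3 everywhere
g = grad(y.sum(), x)
# The NaN values never affect the result; the gradient is constant per element.
print(g.eval({x: np.full(4, np.nan)}))  # [3. 3. 3. 3.]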

tests/tensor/test_elemwise.py

Lines changed: 25 additions & 0 deletions
@@ -10,6 +10,7 @@
 import pytensor.scalar as ps
 import pytensor.tensor as pt
 import tests.unittest_tools as utt
+from pytensor import grad
 from pytensor.compile.function import function
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
@@ -19,6 +20,7 @@
 from pytensor.link.basic import PerformLinker
 from pytensor.link.c.basic import CLinker, OpWiseCLinker
 from pytensor.npy_2_compat import numpy_maxdims
+from pytensor.scalar import ScalarOp, float64, int64
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.basic import get_scalar_constant_value, second
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -1035,3 +1037,26 @@ def test_c_careduce_benchmark(axis, c_contiguous, benchmark):
     return careduce_benchmark_tester(
         axis, c_contiguous, mode="FAST_RUN", benchmark=benchmark
     )
+
+
+def test_gradient_mixed_discrete_output_scalar_op():
+    class MixedDtypeScalarOp(ScalarOp):
+        def make_node(self, *inputs):
+            inputs = [float64()]
+            outputs = [float64(), int64()]
+            return Apply(self, inputs, outputs)
+
+        def perform(self, node, inputs, outputs):
+            raise NotImplementedError()
+
+        def L_op(self, inputs, outputs, output_gradients):
+            return [inputs[0].ones_like() * output_gradients[0]]
+
+    op = Elemwise(MixedDtypeScalarOp())
+    x = vector("x")
+    y, _ = op(x)
+    np.testing.assert_array_equal(
+        grad(y.sum(), x).eval({x: np.full((12,), np.nan)}),
+        np.ones((12,)),
+        strict=True,
+    )
