|
18 | 18 | from pytensor.scalar import ScalarType |
19 | 19 | from pytensor.tensor import as_tensor_variable |
20 | 20 | from pytensor.tensor.shape import shape_padleft |
21 | | -from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor |
| 21 | +from pytensor.tensor.type import TensorType, tensor |
22 | 22 | from pytensor.tensor.utils import ( |
23 | 23 | _parse_gufunc_signature, |
24 | 24 | broadcast_static_dim_lengths, |
@@ -256,6 +256,10 @@ def as_core(t, core_t): |
256 | 256 | as_core(ograd, core_ograd) |
257 | 257 | for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True) |
258 | 258 | ] |
| 259 | + # FIXME: These core_outputs do not depend on core_inputs, which is not pretty.
| 260 | + # It is not necessarily a problem, because if they are referenced by the gradient
| 261 | + # they get replaced later in vectorize. But if the Op were to make any decision
| 262 | + # by introspecting the dependencies of the outputs on the inputs, it would fail badly!
259 | 263 | core_outputs = core_node.outputs |
260 | 264 |
|
261 | 265 | core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds) |
@@ -283,27 +287,6 @@ def L_op(self, inputs, outs, ograds): |
283 | 287 | # Compute grad with respect to broadcasted input |
284 | 288 | rval = self._bgrad(inputs, outs, ograds) |
285 | 289 |
|
286 | | - # TODO: (Borrowed from Elemwise) make sure that zeros are clearly identifiable |
287 | | - # to the gradient.grad method when the outputs have |
288 | | - # some integer and some floating point outputs |
289 | | - if any(out.type.dtype not in continuous_dtypes for out in outs): |
290 | | - # For integer output, return value may only be zero or undefined |
291 | | - # We don't bother with trying to check that the scalar ops |
292 | | - # correctly returned something that evaluates to 0, we just make |
293 | | - # the return value obviously zero so that gradient.grad can tell |
294 | | - # this op did the right thing. |
295 | | - new_rval = [] |
296 | | - for elem, inp in zip(rval, inputs, strict=True): |
297 | | - if isinstance(elem.type, NullType | DisconnectedType): |
298 | | - new_rval.append(elem) |
299 | | - else: |
300 | | - elem = inp.zeros_like() |
301 | | - if str(elem.type.dtype) not in continuous_dtypes: |
302 | | - elem = elem.astype(config.floatX) |
303 | | - assert str(elem.type.dtype) not in discrete_dtypes |
304 | | - new_rval.append(elem) |
305 | | - return new_rval |
306 | | - |
307 | 290 | # Sum out the broadcasted dimensions |
308 | 291 | batch_ndims = self.batch_ndim(outs[0].owner) |
309 | 292 | batch_shape = outs[0].type.shape[:batch_ndims] |
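
Note on the FIXME above: the core gradient graph is built from `core_node.outputs`, which do not depend on the dummy `core_inputs`; the connection is only restored when the core gradient graph is vectorized back onto the batched inputs. The sketch below is not the Blockwise internals, just a minimal illustration, assuming PyTensor's public `pytensor.graph.replace.vectorize_graph` helper, of how a graph built on core (unbatched) variables can be rewired onto batched replacements afterwards. The variable names are only for illustration.

```python
# Minimal sketch (assumes pytensor.graph.replace.vectorize_graph is available):
# a graph built on core (unbatched) variables is later rebuilt on batched ones.
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Core (unbatched) graph: a matrix-vector product
core_A = pt.matrix("A")     # shape (m, n)
core_x = pt.vector("x")     # shape (n,)
core_y = core_A @ core_x    # shape (m,)

# Batched counterparts with a leading batch dimension
batch_A = pt.tensor3("A_batched")  # shape (b, m, n)
batch_x = pt.matrix("x_batched")   # shape (b, n)

# Rebuild the core graph with the batched replacements; only after this step
# does the output genuinely depend on the batched inputs.
batch_y = vectorize_graph(core_y, replace={core_A: batch_A, core_x: batch_x})
```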
|