
Commit 7790a0c

Do not redefine DisconnectedType every time
1 parent d8b51df

16 files changed: +72 -61 lines
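
The change is mechanical: every call site that built a throwaway DisconnectedType instance and immediately called it now reuses the module-level singleton disconnected_type that pytensor.gradient already defines (the extra_ops.py hunk below shows it was already imported there before this commit). A minimal sketch of the two spellings; the asserts are illustrative, not part of the diff:

    from pytensor.gradient import DisconnectedType, disconnected_type

    # Old spelling: construct a new DisconnectedType, then call it to
    # create a variable of that type.
    g_old = DisconnectedType()()

    # New spelling: call the shared singleton, so one type instance is
    # reused across every gradient call site.
    g_new = disconnected_type()

    # Both variables carry the same type, so isinstance checks such as
    # isinstance(g.type, DisconnectedType) behave identically.
    assert isinstance(g_old.type, DisconnectedType)
    assert isinstance(g_new.type, DisconnectedType)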

doc/extending/creating_an_op.rst (+3, -3)

@@ -495,7 +495,7 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_op`
 
     from pytensor.graph.op import Op
     from pytensor.graph.basic import Apply
-    from pytensor.gradient import DisconnectedType
+    from pytensor.gradient import DisconnectedType, disconnected_type
 
     class TransposeAndSumOp(Op):
         __props__ = ()
@@ -539,13 +539,13 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_op`
             out1_grad, out2_grad = output_grads
 
             if isinstance(out1_grad.type, DisconnectedType):
-                x_grad = DisconnectedType()()
+                x_grad = disconnected_type()
             else:
                 # Transpose the last two dimensions of the output gradient
                 x_grad = pt.swapaxes(out1_grad, -1, -2)
 
             if isinstance(out2_grad.type, DisconnectedType):
-                y_grad = DisconnectedType()()
+                y_grad = disconnected_type()
             else:
                 # Broadcast the output gradient to the same shape as y
                 y_grad = pt.broadcast_to(pt.expand_dims(out2_grad, -1), y.shape)

pytensor/breakpoint.py (+2, -2)

@@ -1,6 +1,6 @@
 import numpy as np
 
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import disconnected_type
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.op import Op
 from pytensor.tensor.basic import as_tensor_variable
@@ -142,7 +142,7 @@ def perform(self, node, inputs, output_storage):
             output_storage[i][0] = inputs[i + 1]
 
     def grad(self, inputs, output_gradients):
-        return [DisconnectedType()(), *output_gradients]
+        return [disconnected_type(), *output_gradients]
 
     def infer_shape(self, fgraph, inputs, input_shapes):
         # Return the shape of every input but the condition (first input)

pytensor/raise_op.py (+5, -2)

@@ -2,7 +2,7 @@
 
 from textwrap import indent
 
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import disconnected_type
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
@@ -89,7 +89,10 @@ def perform(self, node, inputs, outputs):
             raise self.exc_type(self.msg)
 
     def grad(self, input, output_gradients):
-        return output_gradients + [DisconnectedType()()] * (len(input) - 1)
+        return [
+            *output_gradients,
+            *(disconnected_type() for _ in range(len(input) - 1)),
+        ]
 
     def connection_pattern(self, node):
         return [[1]] + [[0]] * (len(node.inputs) - 1)
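
Note how grad and connection_pattern stay in sync in this Op: connection_pattern returns one row per input with one entry per output, and every input flagged 0 must get a disconnected gradient from grad. A hedged sketch of that convention; the class below is illustrative, not the real CheckAndRaise implementation:

    from pytensor.gradient import disconnected_type

    class SketchOp:
        """Illustrative skeleton, not a complete PyTensor Op."""

        def grad(self, inputs, output_gradients):
            # One gradient per input: the value input passes the output
            # gradient straight through; the condition inputs do not
            # influence the output value, so they are disconnected.
            return [
                *output_gradients,
                *(disconnected_type() for _ in inputs[1:]),
            ]

        def connection_pattern(self, node):
            # Rows index inputs, columns index outputs: only the first
            # input is connected to the single output.
            return [[1]] + [[0]] * (len(node.inputs) - 1)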

pytensor/scalar/basic.py (+3, -3)

@@ -22,7 +22,7 @@
 import pytensor
 from pytensor import printing
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType, grad_undefined
+from pytensor.gradient import disconnected_type, grad_undefined
 from pytensor.graph.basic import Apply, Constant, Variable, clone
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import HasInnerGraph
@@ -2426,13 +2426,13 @@ def grad(self, inputs, gout):
         (gz,) = gout
         if y.type in continuous_types:
             # x is disconnected because the elements of x are not used
-            return DisconnectedType()(), gz
+            return disconnected_type(), gz
         else:
             # when y is discrete, we assume the function can be extended
             # to deal with real-valued inputs by rounding them to the
             # nearest integer. f(x+eps) thus equals f(x) so the gradient
             # is zero, not disconnected or undefined
-            return DisconnectedType()(), y.zeros_like(dtype=config.floatX)
+            return disconnected_type(), y.zeros_like(dtype=config.floatX)
 
 
 second = Second(name="second")
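
The comments in this hunk distinguish a disconnected gradient (the elements of x are never used) from a zero gradient (y is discrete, so an epsilon perturbation rounds away). At the user level the distinction surfaces through the disconnected_inputs argument of pytensor's grad; a minimal sketch using only the public API:

    import pytensor.tensor as pt

    x = pt.scalar("x")
    y = pt.scalar("y")
    cost = 2 * y  # the cost never uses x, so x is disconnected

    # By default pt.grad raises DisconnectedInputError for x; with
    # disconnected_inputs="ignore" it returns a zero gradient instead.
    gx = pt.grad(cost, x, disconnected_inputs="ignore")
    gy = pt.grad(cost, y)  # an ordinary, connected gradient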

pytensor/scan/op.py (+15, -8)

@@ -63,7 +63,14 @@
 from pytensor.compile.mode import Mode, get_mode
 from pytensor.compile.profiling import register_profiler_printer
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
+from pytensor.gradient import (
+    DisconnectedType,
+    NullType,
+    Rop,
+    disconnected_type,
+    grad,
+    grad_undefined,
+)
 from pytensor.graph.basic import (
     Apply,
     Variable,
@@ -3073,7 +3080,7 @@ def compute_all_gradients(known_grads):
         )
         outputs = local_op(*outer_inputs, return_list=True)
         # Re-order the gradients correctly
-        gradients = [DisconnectedType()()]
+        gradients = [disconnected_type()]  # n_steps is disconnected
 
         offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + n_sitsot_outs
         for p, (x, t) in enumerate(
@@ -3098,7 +3105,7 @@ def compute_all_gradients(known_grads):
                 else:
                     gradients.append(x[::-1])
             elif t == "disconnected":
-                gradients.append(DisconnectedType()())
+                gradients.append(disconnected_type())
             elif t == "through_untraced":
                 gradients.append(
                     grad_undefined(
@@ -3126,7 +3133,7 @@ def compute_all_gradients(known_grads):
                 else:
                     gradients.append(x[::-1])
             elif t == "disconnected":
-                gradients.append(DisconnectedType()())
+                gradients.append(disconnected_type())
             elif t == "through_untraced":
                 gradients.append(
                     grad_undefined(
@@ -3149,15 +3156,15 @@ def compute_all_gradients(known_grads):
                 if not isinstance(dC_dout.type, DisconnectedType) and connected:
                     disconnected = False
             if disconnected:
-                gradients.append(DisconnectedType()())
+                gradients.append(disconnected_type())
             else:
                 gradients.append(
                     grad_undefined(
                         self, idx, inputs[idx], "Shared Variable with update"
                     )
                 )
 
-        gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)]
+        gradients.extend(disconnected_type() for _ in range(info.n_nit_sot))
         begin = end
 
         end = begin + n_sitsot_outs
@@ -3167,7 +3174,7 @@ def compute_all_gradients(known_grads):
             if t == "connected":
                 gradients.append(x[-1])
             elif t == "disconnected":
-                gradients.append(DisconnectedType()())
+                gradients.append(disconnected_type())
             elif t == "through_untraced":
                 gradients.append(
                     grad_undefined(
@@ -3195,7 +3202,7 @@ def compute_all_gradients(known_grads):
             ):
                 disconnected = False
             if disconnected:
-                gradients[idx] = DisconnectedType()()
+                gradients[idx] = disconnected_type()
         return gradients
 
     def R_op(self, inputs, eval_points):

pytensor/sparse/basic.py (+5, -5)

@@ -18,7 +18,7 @@
 from pytensor import _as_symbolic, as_symbolic
 from pytensor import scalar as ps
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType, grad_undefined
+from pytensor.gradient import DisconnectedType, disconnected_type, grad_undefined
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.op import Op
 from pytensor.link.c.type import generic
@@ -480,9 +480,9 @@ def grad(self, inputs, gout):
         )
         return [
             g_data,
-            DisconnectedType()(),
-            DisconnectedType()(),
-            DisconnectedType()(),
+            disconnected_type(),
+            disconnected_type(),
+            disconnected_type(),
         ]
 
     def infer_shape(self, fgraph, node, shapes):
@@ -1940,7 +1940,7 @@ def grad(self, inputs, grads):
         gx = g_output
         gy = pytensor.tensor.subtensor.advanced_subtensor1(g_output, *idx_list)
 
-        return [gx, gy] + [DisconnectedType()()] * len(idx_list)
+        return [gx, gy, *(disconnected_type() for _ in range(len(idx_list)))]
 
 
 construct_sparse_from_list = ConstructSparseFromList()

pytensor/tensor/basic.py (+6, -6)

@@ -22,7 +22,7 @@
 from pytensor import config, printing
 from pytensor import scalar as ps
 from pytensor.compile.builders import OpFromGraph
-from pytensor.gradient import DisconnectedType, grad_undefined
+from pytensor.gradient import DisconnectedType, disconnected_type, grad_undefined
 from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
 from pytensor.graph.fg import FunctionGraph, Output
@@ -1738,7 +1738,7 @@ def grad(self, inputs, grads):
         # the inputs that specify the shape. If you grow the
         # shape by epsilon, the existing elements do not
         # change.
-        return [gx] + [DisconnectedType()() for i in inputs[1:]]
+        return [gx, *(disconnected_type() for _ in range(len(inputs) - 1))]
 
     def R_op(self, inputs, eval_points):
         if eval_points[0] is None:
@@ -2277,7 +2277,7 @@ def L_op(self, inputs, outputs, g_outputs):
         return [
             join(axis, *new_g_outputs),
             grad_undefined(self, 1, axis),
-            DisconnectedType()(),
+            disconnected_type(),
         ]
 
     def R_op(self, inputs, eval_points):
@@ -3340,14 +3340,14 @@ def L_op(self, inputs, outputs, grads):
         if self.dtype in discrete_dtypes:
             return [
                 start.zeros_like(dtype=config.floatX),
-                DisconnectedType()(),
+                disconnected_type(),
                 step.zeros_like(dtype=config.floatX),
             ]
         else:
             num_steps_taken = outputs[0].shape[0]
             return [
                 gz.sum(),
-                DisconnectedType()(),
+                disconnected_type(),
                 (gz * arange(num_steps_taken, dtype=self.dtype)).sum(),
             ]
 
@@ -4374,7 +4374,7 @@ def connection_pattern(self, node):
         return [[False] for i in node.inputs]
 
     def grad(self, inputs, grads):
-        return [DisconnectedType()() for i in inputs]
+        return [disconnected_type() for _ in range(len(inputs))]
 
     def R_op(self, inputs, eval_points):
         return [zeros(inputs, self.dtype)]

pytensor/tensor/extra_ops.py (+1, -2)

@@ -8,7 +8,6 @@
 import pytensor
 import pytensor.scalar.basic as ps
 from pytensor.gradient import (
-    DisconnectedType,
     _float_zeros_like,
     disconnected_type,
     grad_undefined,
@@ -716,7 +715,7 @@ def grad(self, inputs, gout):
         gx_transpose = ptb.zeros_like(x_transpose)[repeated_arange].inc(gz_transpose)
         gx = ptb.moveaxis(gx_transpose, 0, axis)
 
-        return [gx, DisconnectedType()()]
+        return [gx, disconnected_type()]
 
     def infer_shape(self, fgraph, node, ins_shapes):
         i0_shapes = ins_shapes[0]

pytensor/tensor/fft.py (+3, -3)

@@ -1,6 +1,6 @@
 import numpy as np
 
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import disconnected_type
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor.basic import as_tensor_variable
@@ -59,7 +59,7 @@ def grad(self, inputs, output_grads):
             + [slice(None)]
         )
         gout = set_subtensor(gout[idx], gout[idx] * 0.5)
-        return [irfft_op(gout, s), DisconnectedType()()]
+        return [irfft_op(gout, s), disconnected_type()]
 
     def connection_pattern(self, node):
         # Specify that shape input parameter has no connection to graph and gradients.
@@ -121,7 +121,7 @@ def grad(self, inputs, output_grads):
             + [slice(None)]
         )
         gf = set_subtensor(gf[idx], gf[idx] * 2)
-        return [gf, DisconnectedType()()]
+        return [gf, disconnected_type()]
 
     def connection_pattern(self, node):
         # Specify that shape input parameter has no connection to graph and gradients.

pytensor/tensor/nlinalg.py (+3, -3)

@@ -8,7 +8,7 @@
 
 from pytensor import scalar as ps
 from pytensor.compile.builders import OpFromGraph
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import DisconnectedType, disconnected_type
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor import TensorLike
@@ -652,8 +652,8 @@ def s_grad_only(
         ]
         if all(is_disconnected):
             # This should never actually be reached by Pytensor -- the SVD Op should be pruned from the gradient
-            # graph if its fully disconnected. It is included for completeness.
-            return [DisconnectedType()()]  # pragma: no cover
+            # graph if it's fully disconnected. It is included for completeness.
+            return [disconnected_type()]  # pragma: no cover
 
         elif is_disconnected == [True, False, True]:
             # This is the same as the compute_uv = False, so we can drop back to that simpler computation, without