@@ -6,6 +6,6 @@
 from pytensor import Variable
 from pytensor.compile import optdb
-from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
 from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
@@ -119,21 +119,53 @@ def local_subtensor_of_dot(fgraph, node):
     the remaining entries of ``idxs`` (if any), modified to skip the
     second-to-last dimension of ``B`` (because dot sums over this dimension).
     """
-    if not isinstance(node.op, Subtensor):
-        return
-    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):
+    x, *idx_vars = node.inputs
+    if not (
+        x.owner is not None
+        and (
+            isinstance(x.owner.op, Dot)
+            or (
+                isinstance(x.owner.op, Blockwise)
+                and isinstance(x.owner.op.core_op, Dot)
+            )
+        )
+    ):
         return
     # If another node uses the output of the dot,
     # we don't want to compute the subtensor part twice.
-    if len(fgraph.clients[node.inputs[0]]) > 1:
+    if len(fgraph.clients[x]) > 1:
         return
 
-    a = node.inputs[0].owner.inputs[0]
-    b = node.inputs[0].owner.inputs[1]
+    a = x.owner.inputs[0]
+    b = x.owner.inputs[1]
+    idx_list = indices_from_subtensor(idx_vars, node.op.idx_list)
 
-    idx_list = get_idx_list(node.inputs, node.op.idx_list)
+    if not idx_list:
+        # Nothing to do; `local_useless_slice` will handle this case
+        return None
 
-    num_a_indices = min(a.ndim - 1, len(idx_list))
+    batch_ndim = (
+        x.owner.op.batch_ndim(x.owner) if isinstance(x.owner.op, Blockwise) else 0
+    )
+
+    if batch_ndim:
+        batch_idx_list, idx_list = idx_list[:batch_ndim], idx_list[batch_ndim:]
+        # TODO: We want to replace the two ifs below, but due to rewrite ordering and the location of blas rewrites
+        # between canonicalization and specialization, we cannot do it yet. Otherwise we would miss some optimizations,
+        # such as the one tested in `test_benchmark_partial_jacobian`.
+        # if batch_idx_list:
+        #     # Indexing over batch dimensions of Blockwise. Allow `local_subtensor_of_batch_dims` to handle those first
+        #     return None
+        if not idx_list:
+            # Indexing only over batch dimensions of Blockwise, which can be handled by `local_subtensor_of_batch_dims`
+            return None
+        # We perform the rest of the rewrite on dummy a, b that correspond to the core case
+        a = a.type.clone(shape=a.type.shape[batch_ndim:])()
+        b = b.type.clone(shape=b.type.shape[batch_ndim:])()
+
+    a_ndim = a.ndim
+    b_ndim = b.ndim
+    num_a_indices = min(a_ndim - 1, len(idx_list))
     a_indices = idx_list[:num_a_indices]
     b_indices = idx_list[num_a_indices:]
 
@@ -142,26 +174,22 @@ def local_subtensor_of_dot(fgraph, node):
     # This wasn't necessary for a, because we just omitted the last index.
     # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
     # (dot also handles b.ndim < 2 as a special case)
-    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
+    if b_ndim > 1 and len(b_indices) >= b_ndim - 1:
         b_indices = (
-            b_indices[: b.ndim - 2]
+            b_indices[: b_ndim - 2]
             + (slice(None, None, None),)
-            + b_indices[b.ndim - 2 :]
+            + b_indices[b_ndim - 2 :]
         )
 
-    a_sub = a.__getitem__(tuple(a_indices))
-    b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
+    a_sub = a[tuple(a_indices)]
+    b_sub = b[tuple(b_indices)] if b_indices else b
+    r = dot(a_sub, b_sub)
 
-    # Copy over the previous output stacktrace to a_sub and b_sub,
-    # because an error in the subtensor operation (e.g. an index error)
-    # on either a or b must correspond to an error in the
-    # subtensor operation on their dot product.
-    copy_stack_trace(node.outputs[0], [a_sub, b_sub])
+    if batch_ndim:
+        # Replace the dummy inputs with the original batched ones
+        r = vectorize_graph(r, replace={a: x.owner.inputs[0], b: x.owner.inputs[1]})
+        r = r[tuple(batch_idx_list)]
 
-    # Copy over the previous output stacktrace and the previous dot product stacktrace,
-    # because an error here may correspond to an error in either the original
-    # dot product, or in the dot product after the subtensor operation.
-    r = dot(a_sub, b_sub)
     copy_stack_trace([node.outputs[0], node.inputs[0]], r)
 
     return [r]
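
For intuition, the identity this rewrite exploits is easy to check in NumPy: indices over the first ``a.ndim - 1`` dimensions of a dot product can be applied to ``a`` directly, while the remaining indices apply to ``b`` once a full slice is inserted at ``b``'s second-to-last axis (the axis that dot sums over). A minimal sketch:

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 4))
b = rng.normal(size=(4, 5))

# Index 1 goes to `a`; slice 2: goes to `b`'s last axis, with a full
# slice standing in for `b`'s summed (second-to-last) axis.
assert np.allclose(np.dot(a, b)[1, 2:], np.dot(a[1], b[:, 2:]))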
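
And a minimal end-to-end sketch of the rewrite in action, assuming a pytensor build that includes it; note that ``pt.matmul`` on 3-D inputs lowers to a ``Blockwise`` of ``Dot``, which is the batched case the new branch targets:

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.matrix("y")

# After rewriting, the compiled graph should compute dot(x[0], y)
# rather than slicing the full product.
fn = pytensor.function([x, y], pt.dot(x, y)[0])
pytensor.dprint(fn)  # inspect the rewritten graph

# Batched case: core-dimension indices are rewritten on dummy core
# inputs and re-batched with vectorize_graph, as in the diff above.
xb = pt.tensor3("xb")
yb = pt.tensor3("yb")
fn_b = pytensor.function([xb, yb], pt.matmul(xb, yb)[:, 0])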