@@ -5,7 +5,7 @@
 
 from pytensor import Variable
 from pytensor.compile import optdb
-from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
 from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
@@ -119,21 +119,48 @@ def local_subtensor_of_dot(fgraph, node):
     the remaining entries of ``idxs`` (if any), modified to skip the
     second-to-last dimension of ``B`` (because dot sums over this dimension).
     """
-    if not isinstance(node.op, Subtensor):
-        return
-    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):
+    x, *idx_vars = node.inputs
+    if not (
+        x.owner is not None
+        and (
+            isinstance(x.owner.op, Dot)
+            or (
+                isinstance(x.owner.op, Blockwise)
+                and isinstance(x.owner.op.core_op, Dot)
+            )
+        )
+    ):
         return
     # If there is other node that use the outputs of the dot
     # We don't want to compute twice the sub part.
-    if len(fgraph.clients[node.inputs[0]]) > 1:
+    if len(fgraph.clients[x]) > 1:
         return
 
-    a = node.inputs[0].owner.inputs[0]
-    b = node.inputs[0].owner.inputs[1]
+    a = x.owner.inputs[0]
+    b = x.owner.inputs[1]
+    idx_list = indices_from_subtensor(idx_vars, node.op.idx_list)
 
-    idx_list = get_idx_list(node.inputs, node.op.idx_list)
+    if not idx_list:
+        # Nothing to do, `local_useless_slice` will handle this case
+        return None
 
-    num_a_indices = min(a.ndim - 1, len(idx_list))
+    batch_ndim = (
+        x.owner.op.batch_ndim(x.owner) if isinstance(x.owner.op, Blockwise) else 0
+    )
+
+    if batch_ndim:
+        batch_idx_list, idx_list = idx_list[:batch_ndim], idx_list[batch_ndim:]
+        if not idx_list:
+            # Indexing only over batch dimensions of Blockwise, nothing to do here
+            # This will be handled by `local_subtensor_of_batch_dims`
+            return None
+        # We perform the rest of the rewrite on dummy a, b that correspond to the core case
+        a = a.type.clone(shape=a.type.shape[batch_ndim:])()
+        b = b.type.clone(shape=b.type.shape[batch_ndim:])()
+
+    a_ndim = a.ndim
+    b_ndim = b.ndim
+    num_a_indices = min(a_ndim - 1, len(idx_list))
     a_indices = idx_list[:num_a_indices]
     b_indices = idx_list[num_a_indices:]
 
@@ -142,26 +169,22 @@ def local_subtensor_of_dot(fgraph, node):
     # This wasn't necessary for a, because we just omitted the last index.
     # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
     # (dot also handles b.ndim < 2 as a special case)
-    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
+    if b_ndim > 1 and len(b_indices) >= b_ndim - 1:
         b_indices = (
-            b_indices[: b.ndim - 2]
+            b_indices[: b_ndim - 2]
             + (slice(None, None, None),)
-            + b_indices[b.ndim - 2 :]
+            + b_indices[b_ndim - 2 :]
         )
 
-    a_sub = a.__getitem__(tuple(a_indices))
-    b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
+    a_sub = a[tuple(a_indices)]
+    b_sub = b[tuple(b_indices)] if b_indices else b
+    r = dot(a_sub, b_sub)
 
-    # Copy over previous output stacktrace to a_sub and b_sub,
-    # because an error in the subtensor operation (e.g. an index error)
-    # on either a or b must correspond to an error in the
-    # subtensor operation on their dot product.
-    copy_stack_trace(node.outputs[0], [a_sub, b_sub])
+    if batch_ndim:
+        # Replace dummy inputs by the original batch ones
+        r = vectorize_graph(r, replace={a: x.owner.inputs[0], b: x.owner.inputs[1]})
+        r = r[tuple(batch_idx_list)]
 
-    # Copy over previous output stacktrace and previous dot product stacktrace,
-    # because an error here may correspond to an either in either the original
-    # dot product, or in the dot product after the subtensor operation.
-    r = dot(a_sub, b_sub)
     copy_stack_trace([node.outputs[0], node.inputs[0]], r)
 
     return [r]
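
For context, a minimal sketch (not part of the diff) of the graph pattern this rewrite targets. It assumes a PyTensor build with the default rewrite pipeline; the variable names `a`, `b`, and `out` are just for illustration.

# Hypothetical usage sketch: indexing into a dot product.
# The rewrite should turn dot(a, b)[1:] into dot(a[1:], b), so only the
# needed rows of the product are ever computed.
import pytensor
import pytensor.tensor as pt

a = pt.matrix("a")  # shape (m, k)
b = pt.matrix("b")  # shape (k, n)

out = pt.dot(a, b)[1:]

fn = pytensor.function([a, b], out)
pytensor.dprint(fn)  # the Subtensor should have been pushed onto `a`

With the change above, the same rewrite can also fire when the dot is wrapped in a Blockwise (a batched dot): the core-case result is built on dummy core inputs, vectorize_graph re-applies it to the original batched inputs, and any leading batch indices are applied afterwards.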