@@ -54,13 +54,18 @@ def _getitem(self, row_index, col_index, *batch_indices):
5454 x1 = self .x1
5555 x2 = self .x2
5656 num_outs_per_in = self .kernel .num_outputs_per_input (x1 , x2 )
57+ if isinstance (num_outs_per_in , tuple ):
58+ num_outs_per_in_rows , num_outs_per_in_cols = num_outs_per_in
59+ else :
60+ num_outs_per_in_rows = num_outs_per_in
61+ num_outs_per_in_cols = num_outs_per_in
5762
5863 # The row index and col index should exactly correspond to which entries of x1 and x2 we need
5964 # So we'll basically call x1[*batch_indices, row_index, :], x2[*batch_indices, col_index, :]
6065
6166 # However - if we have multiple outputs per input, then the indices won't directly
6267 # correspond to the entries of row/col. We'll have to do a little pre-processing
63- if num_outs_per_in != 1 :
68+ if num_outs_per_in_rows != 1 or num_outs_per_in_cols != 1 :
6469 if not isinstance (x1 , slice ) or not isinstance (x2 , slice ):
6570 # It's too complicated to deal with tensor indices in this case - we'll use the super method
6671 return self .evaluate_kernel ()._getitem (row_index , col_index , * batch_indices )
@@ -81,16 +86,16 @@ def _getitem(self, row_index, col_index, *batch_indices):
8186 if row_step is not None or col_step is not None :
8287 return self .evaluate_kernel ()._getitem (row_index , col_index , * batch_indices )
8388 if (
84- (row_start % num_outs_per_in )
85- or (col_start % num_outs_per_in )
86- or (row_end % num_outs_per_in )
87- or (col_end % num_outs_per_in )
89+ (row_start % num_outs_per_in_rows )
90+ or (col_start % num_outs_per_in_cols )
91+ or (row_end % num_outs_per_in_rows )
92+ or (col_end % num_outs_per_in_cols )
8893 ):
8994 return self .evaluate_kernel ()._getitem (row_index , col_index , * batch_indices )
9095
9196 # Otherwise - let's divide the slices by the number of outputs per input
92- row_index = slice (row_start // num_outs_per_in , row_end // num_outs_per_in , None )
93- col_index = slice (col_start // num_outs_per_in , col_end // num_outs_per_in , None )
97+ row_index = slice (row_start // num_outs_per_in_rows , row_end // num_outs_per_in_rows , None )
98+ col_index = slice (col_start // num_outs_per_in_cols , col_end // num_outs_per_in_cols , None )
9499
95100 # Define the index we're using for the last index
96101 # If the last index corresponds to a batch, then we'll use the appropriate batch_index
@@ -220,9 +225,14 @@ def _size(self):
220225
221226 x1 = self .x1
222227 x2 = self .x2
223- num_outputs_per_input = self .kernel .num_outputs_per_input (x1 , x2 )
224- num_rows = x1 .size (- 2 ) * num_outputs_per_input
225- num_cols = x2 .size (- 2 ) * num_outputs_per_input
228+ num_outs_per_in = self .kernel .num_outputs_per_input (x1 , x2 )
229+ if isinstance (num_outs_per_in , tuple ):
230+ num_outs_per_in_rows , num_outs_per_in_cols = num_outs_per_in
231+ else :
232+ num_outs_per_in_rows = num_outs_per_in
233+ num_outs_per_in_cols = num_outs_per_in
234+ num_rows = x1 .size (- 2 ) * num_outs_per_in_rows
235+ num_cols = x2 .size (- 2 ) * num_outs_per_in_cols
226236
227237 # Default case - when we're not using broadcasting
228238 # We write this case special for efficiency
0 commit comments