Skip to content

Commit bf13e7a

Browse files
jacobrgardnerMisha Padidar
andauthored
Kernels can return tuples from num_outputs_per_input (#1849)
* Kernels can now return tuples from num_outputs_per_input * caught a couple fixes that we missed earlier Co-authored-by: Misha Padidar <[email protected]>
1 parent 0e74f2b commit bf13e7a

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

gpytorch/lazy/lazy_evaluated_kernel_tensor.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,18 @@ def _getitem(self, row_index, col_index, *batch_indices):
5454
x1 = self.x1
5555
x2 = self.x2
5656
num_outs_per_in = self.kernel.num_outputs_per_input(x1, x2)
57+
if isinstance(num_outs_per_in, tuple):
58+
num_outs_per_in_rows, num_outs_per_in_cols = num_outs_per_in
59+
else:
60+
num_outs_per_in_rows = num_outs_per_in
61+
num_outs_per_in_cols = num_outs_per_in
5762

5863
# The row index and col index should exactly correspond to which entries of x1 and x2 we need
5964
# So we'll basically call x1[*batch_indices, row_index, :], x2[*batch_indices, col_index, :]
6065

6166
# However - if we have multiple outputs per input, then the indices won't directly
6267
# correspond to the entries of row/col. We'll have to do a little pre-processing
63-
if num_outs_per_in != 1:
68+
if num_outs_per_in_rows != 1 or num_outs_per_in_cols != 1:
6469
if not isinstance(x1, slice) or not isinstance(x2, slice):
6570
# It's too complicated to deal with tensor indices in this case - we'll use the super method
6671
return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)
@@ -81,16 +86,16 @@ def _getitem(self, row_index, col_index, *batch_indices):
8186
if row_step is not None or col_step is not None:
8287
return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)
8388
if (
84-
(row_start % num_outs_per_in)
85-
or (col_start % num_outs_per_in)
86-
or (row_end % num_outs_per_in)
87-
or (col_end % num_outs_per_in)
89+
(row_start % num_outs_per_in_rows)
90+
or (col_start % num_outs_per_in_cols)
91+
or (row_end % num_outs_per_in_rows)
92+
or (col_end % num_outs_per_in_cols)
8893
):
8994
return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)
9095

9196
# Otherwise - let's divide the slices by the number of outputs per input
92-
row_index = slice(row_start // num_outs_per_in, row_end // num_outs_per_in, None)
93-
col_index = slice(col_start // num_outs_per_in, col_end // num_outs_per_in, None)
97+
row_index = slice(row_start // num_outs_per_in_rows, row_end // num_outs_per_in_rows, None)
98+
col_index = slice(col_start // num_outs_per_in_cols, col_end // num_outs_per_in_cols, None)
9499

95100
# Define the index we're using for the last index
96101
# If the last index corresponds to a batch, then we'll use the appropriate batch_index
@@ -220,9 +225,14 @@ def _size(self):
220225

221226
x1 = self.x1
222227
x2 = self.x2
223-
num_outputs_per_input = self.kernel.num_outputs_per_input(x1, x2)
224-
num_rows = x1.size(-2) * num_outputs_per_input
225-
num_cols = x2.size(-2) * num_outputs_per_input
228+
num_outs_per_in = self.kernel.num_outputs_per_input(x1, x2)
229+
if isinstance(num_outs_per_in, tuple):
230+
num_outs_per_in_rows, num_outs_per_in_cols = num_outs_per_in
231+
else:
232+
num_outs_per_in_rows = num_outs_per_in
233+
num_outs_per_in_cols = num_outs_per_in
234+
num_rows = x1.size(-2) * num_outs_per_in_rows
235+
num_cols = x2.size(-2) * num_outs_per_in_cols
226236

227237
# Default case - when we're not using broadcasting
228238
# We write this case special for efficiency

0 commit comments

Comments
 (0)