Skip to content

Commit df018d0

Browse files
valtrong and gpleiss authored
Fix LazyEvaluatedKernelTensor._unsqueeze_batch (#1828)
Different implementation of `LazyEvaluatedKernelTensor._unsqueeze_batch` Co-authored-by: Geoff Pleiss <[email protected]>
1 parent 19e67f3 commit df018d0

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

gpytorch/lazy/lazy_evaluated_kernel_tensor.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,16 @@ def _getitem(self, row_index, col_index, *batch_indices):
6868
# Now we know that x1 and x2 are slices
6969
# Let's make sure that the slice dimensions perfectly correspond with the number of
7070
# outputs per input that we have
71-
row_start, row_end, row_step = row_index.start, row_index.stop, row_index.step
72-
col_start, col_end, col_step = col_index.start, col_index.stop, col_index.step
71+
row_start, row_end, row_step = (
72+
row_index.start,
73+
row_index.stop,
74+
row_index.step,
75+
)
76+
col_start, col_end, col_step = (
77+
col_index.start,
78+
col_index.stop,
79+
col_index.step,
80+
)
7381
if row_step is not None or col_step is not None:
7482
return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)
7583
if (
@@ -127,7 +135,7 @@ def _getitem(self, row_index, col_index, *batch_indices):
127135
new_kernel = self.kernel.__getitem__(batch_indices)
128136

129137
# Now construct a kernel with those indices
130-
return self.__class__(x1, x2, kernel=new_kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params)
138+
return self.__class__(x1, x2, kernel=new_kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
131139

132140
def _matmul(self, rhs):
133141
# This _matmul is defined computes the kernel in chunks
@@ -148,7 +156,7 @@ def _matmul(self, rhs):
148156
res = []
149157
for sub_x1 in sub_x1s:
150158
sub_kernel_matrix = lazify(
151-
self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params)
159+
self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
152160
)
153161
res.append(sub_kernel_matrix._matmul(rhs))
154162

@@ -177,7 +185,7 @@ def _quad_form_derivative(self, left_vecs, right_vecs):
177185
sub_x1.requires_grad_(True)
178186
with torch.enable_grad(), settings.lazily_evaluate_kernels(False):
179187
sub_kernel_matrix = lazify(
180-
self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params)
188+
self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
181189
)
182190
sub_grad_outputs = tuple(sub_kernel_matrix._quad_form_derivative(sub_left_vecs, right_vecs))
183191
sub_kernel_outputs = tuple(sub_kernel_matrix.representation())
@@ -230,14 +238,16 @@ def _size(self):
230238

231239
def _transpose_nonbatch(self):
232240
return self.__class__(
233-
self.x2, self.x1, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params
241+
self.x2, self.x1, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,
234242
)
235243

236244
def add_jitter(self, jitter_val=1e-3):
237245
return self.evaluate_kernel().add_jitter(jitter_val)
238246

239247
def _unsqueeze_batch(self, dim):
240-
return self[(slice(None),) * dim + (None,)]
248+
x1 = self.x1.unsqueeze(dim)
249+
x2 = self.x2.unsqueeze(dim)
250+
return self.__class__(x1, x2, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
241251

242252
@cached(name="kernel_diag")
243253
def diag(self):
@@ -281,7 +291,7 @@ def evaluate_kernel(self):
281291
with settings.lazily_evaluate_kernels(False):
282292
temp_active_dims = self.kernel.active_dims
283293
self.kernel.active_dims = None
284-
res = self.kernel(x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params)
294+
res = self.kernel(x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
285295
self.kernel.active_dims = temp_active_dims
286296

287297
# Check the size of the output
@@ -305,7 +315,7 @@ def repeat(self, *repeats):
305315

306316
x1 = self.x1.repeat(*batch_repeat, row_repeat, 1)
307317
x2 = self.x2.repeat(*batch_repeat, col_repeat, 1)
308-
return self.__class__(x1, x2, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params)
318+
return self.__class__(x1, x2, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
309319

310320
def representation(self):
311321
# If we're checkpointing the kernel, we'll use chunked _matmuls defined in LazyEvaluatedKernelTensor

0 commit comments

Comments
 (0)