@@ -68,8 +68,16 @@ def _getitem(self, row_index, col_index, *batch_indices):
         # Now we know that x1 and x2 are slices
         # Let's make sure that the slice dimensions perfectly correspond with the number of
         # outputs per input that we have
-        row_start, row_end, row_step = row_index.start, row_index.stop, row_index.step
-        col_start, col_end, col_step = col_index.start, col_index.stop, col_index.step
+        row_start, row_end, row_step = (
+            row_index.start,
+            row_index.stop,
+            row_index.step,
+        )
+        col_start, col_end, col_step = (
+            col_index.start,
+            col_index.stop,
+            col_index.step,
+        )
         if row_step is not None or col_step is not None:
             return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices)
         if (
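For context on the slicing logic above: step-free slices keep the tensor lazy, while any slice with a step falls back to `evaluate_kernel()`. A minimal sketch of that behavior, assuming gpytorch's `RBFKernel` and the `lazily_evaluate_kernels` setting (variable names are illustrative, not part of this PR):

```python
import torch
import gpytorch

x = torch.randn(10, 3)
kernel = gpytorch.kernels.RBFKernel()

with gpytorch.settings.lazily_evaluate_kernels(True):
    lazy_K = kernel(x, x)       # LazyEvaluatedKernelTensor; nothing evaluated yet
    block = lazy_K[2:6, 2:6]    # step-free slices: stays lazy
    strided = lazy_K[::2, ::2]  # a step triggers the evaluate_kernel() fallback above
```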
@@ -127,7 +135,7 @@ def _getitem(self, row_index, col_index, *batch_indices):
         new_kernel = self.kernel.__getitem__(batch_indices)
 
         # Now construct a kernel with those indices
-        return self.__class__(x1, x2, kernel=new_kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params)
+        return self.__class__(x1, x2, kernel=new_kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
 
     def _matmul(self, rhs):
         # This _matmul computes the kernel in chunks
@@ -148,7 +156,7 @@ def _matmul(self, rhs):
         res = []
         for sub_x1 in sub_x1s:
             sub_kernel_matrix = lazify(
-                self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params)
+                self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
             )
             res.append(sub_kernel_matrix._matmul(rhs))
 
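The loop above implements the chunked matmul: `x1` is split into row blocks, each block's kernel matrix is formed and multiplied against `rhs`, and the partial results are stacked. A standalone sketch of the same idea (hypothetical helper, assuming the older `.evaluate()` LazyTensor API; not code from this file):

```python
import torch

def chunked_kernel_matmul(kernel, x1, x2, rhs, chunk_rows):
    # Compute K(x1, x2) @ rhs without materializing the full kernel matrix:
    # each row block of K is evaluated, used, and then freed.
    partial = []
    for sub_x1 in x1.split(chunk_rows, dim=-2):
        sub_K = kernel(sub_x1, x2).evaluate()  # chunk_rows x n block
        partial.append(sub_K @ rhs)
    return torch.cat(partial, dim=-2)
```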
@@ -177,7 +185,7 @@ def _quad_form_derivative(self, left_vecs, right_vecs):
             sub_x1.requires_grad_(True)
             with torch.enable_grad(), settings.lazily_evaluate_kernels(False):
                 sub_kernel_matrix = lazify(
-                    self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params)
+                    self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
                 )
             sub_grad_outputs = tuple(sub_kernel_matrix._quad_form_derivative(sub_left_vecs, right_vecs))
             sub_kernel_outputs = tuple(sub_kernel_matrix.representation())
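The chunked `_quad_form_derivative` leans on the same autograd identity as the unchunked path: once the kernel block is built inside an `enable_grad` region, the gradient of `trace(left_vecs^T K(x) right_vecs)` with respect to `x` is recoverable via `torch.autograd.grad`. A self-contained sketch of that identity (assumes an `RBFKernel`; not this PR's code):

```python
import torch
import gpytorch

x = torch.randn(5, 2, requires_grad=True)
left_vecs = torch.randn(5, 3)
right_vecs = torch.randn(5, 3)

with torch.enable_grad():
    K = gpytorch.kernels.RBFKernel()(x, x).evaluate()
    # trace(left_vecs^T K right_vecs), written as an elementwise sum
    quad_form = (left_vecs * (K @ right_vecs)).sum()

(grad_x,) = torch.autograd.grad(quad_form, x)  # gradient of the quad form w.r.t. x
```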
@@ -230,14 +238,16 @@ def _size(self):
 
     def _transpose_nonbatch(self):
         return self.__class__(
-            self.x2, self.x1, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params
+            self.x2, self.x1, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,
         )
 
     def add_jitter(self, jitter_val=1e-3):
         return self.evaluate_kernel().add_jitter(jitter_val)
 
     def _unsqueeze_batch(self, dim):
-        return self[(slice(None),) * dim + (None,)]
+        x1 = self.x1.unsqueeze(dim)
+        x2 = self.x2.unsqueeze(dim)
+        return self.__class__(x1, x2, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
 
     @cached(name="kernel_diag")
     def diag(self):
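The rewritten `_unsqueeze_batch` unsqueezes the stored inputs directly, rather than round-tripping through `self[(slice(None),) * dim + (None,)]`, so the result remains a `LazyEvaluatedKernelTensor`. A quick shape check (illustrative only; `_unsqueeze_batch` is an internal method):

```python
import torch
import gpytorch

x = torch.randn(10, 3)
lazy_K = gpytorch.kernels.RBFKernel()(x, x)  # kernel shape: 10 x 10

batched = lazy_K._unsqueeze_batch(0)         # inputs become 1 x 10 x 3
assert batched.shape == torch.Size([1, 10, 10])
```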
@@ -281,7 +291,7 @@ def evaluate_kernel(self):
         with settings.lazily_evaluate_kernels(False):
             temp_active_dims = self.kernel.active_dims
             self.kernel.active_dims = None
-            res = self.kernel(x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params)
+            res = self.kernel(x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
             self.kernel.active_dims = temp_active_dims
 
         # Check the size of the output
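The `active_dims` toggle above exists because the stored `x1`/`x2` were already restricted to the kernel's active dimensions when the lazy tensor was created; letting the kernel re-apply `active_dims` during evaluation would index the wrong columns. An illustrative setup (assuming `active_dims` is applied in `Kernel.__call__`, as in gpytorch):

```python
import torch
import gpytorch

kernel = gpytorch.kernels.RBFKernel(active_dims=torch.tensor([0, 2]))
x = torch.randn(8, 4)

lazy_K = kernel(x, x)         # x1/x2 are stored with columns [0, 2] already selected
K = lazy_K.evaluate_kernel()  # active_dims is disabled here to avoid double slicing
```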
@@ -305,7 +315,7 @@ def repeat(self, *repeats):
 
         x1 = self.x1.repeat(*batch_repeat, row_repeat, 1)
         x2 = self.x2.repeat(*batch_repeat, col_repeat, 1)
-        return self.__class__(x1, x2, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params)
+        return self.__class__(x1, x2, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params,)
 
     def representation(self):
         # If we're checkpointing the kernel, we'll use chunked _matmuls defined in LazyEvaluatedKernelTensor
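`repeat` likewise stays lazy by repeating the raw inputs rather than an evaluated matrix; for a non-batch kernel, leading repeat counts become batch dimensions. Illustrative shapes (assuming an `RBFKernel`):

```python
import torch
import gpytorch

x = torch.randn(4, 2)
lazy_K = gpytorch.kernels.RBFKernel()(x, x)  # kernel shape: 4 x 4

tiled = lazy_K.repeat(3, 1, 1)               # batch repeat -> shape: 3 x 4 x 4
assert tiled.shape == torch.Size([3, 4, 4])
```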