Skip to content

Commit 17ed99a

Browse files
authored
Merge pull request #1712 from wjmaddox/cat_rows_fix
Add device settings to cat_rows
2 parents c074c2f + a8e197b commit 17ed99a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

gpytorch/lazy/lazy_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -780,9 +780,9 @@ def cat_rows(self, cross_mat, new_mat, generate_roots=True, **root_decomp_kwargs
780780
A = self
781781

782782
# form matrix C = [A B; B^T D], where A = self, B = cross_mat, D = new_mat
783-
upper_row = CatLazyTensor(A, B, dim=-2)
784-
lower_row = CatLazyTensor(B.transpose(-1, -2), D, dim=-2)
785-
new_lazy_tensor = CatLazyTensor(upper_row, lower_row, dim=-1)
783+
upper_row = CatLazyTensor(A, B, dim=-2, output_device=A.device)
784+
lower_row = CatLazyTensor(B.transpose(-1, -2), D, dim=-2, output_device=A.device)
785+
new_lazy_tensor = CatLazyTensor(upper_row, lower_row, dim=-1, output_device=A.device)
786786

787787
# if the old lazy tensor does not have either a root decomposition or a root inverse decomposition
788788
# don't create one

0 commit comments

Comments
 (0)