Skip to content

Commit eec70f9

Browse files
authored
Fix dtype/device mismatch in _get_indices() (#90)
1 parent 91523ec commit eec70f9

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

linear_operator/operators/diag_linear_operator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,12 @@ def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "..
7070

7171
def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
7272
res = self._diag[(*batch_indices, row_index)]
73+
# Unify device and dtype prior to comparison
74+
row_index = row_index.to(device=res.device, dtype=res.dtype)
75+
col_index = col_index.to(device=res.device, dtype=res.dtype)
7376
# If row and col index don't agree, then we have off diagonal elements
7477
# Those should be zero'd out
75-
res = res * torch.eq(row_index, col_index).to(device=res.device, dtype=res.dtype)
78+
res = res * torch.eq(row_index, col_index)
7679
return res
7780

7881
def _mul_constant(

0 commit comments

Comments
 (0)