Skip to content

Commit 55332f4

Browse files
fix: get device for operations in torch backend (#469)
1 parent 166177d commit 55332f4

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

optiland/backend/torch_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -835,12 +835,12 @@ def wrapped(x: Tensor) -> Tensor:
835835
# Conversion and Utilities
836836
# --------------------------
837837
def atleast_1d(x: ArrayLike) -> Tensor:
838-
t = torch.as_tensor(x, dtype=get_precision())
838+
t = torch.as_tensor(x, dtype=get_precision(), device=get_device())
839839
return t.unsqueeze(0) if t.ndim == 0 else t
840840

841841

842842
def atleast_2d(x: ArrayLike) -> Tensor:
843-
t = torch.as_tensor(x, dtype=get_precision())
843+
t = torch.as_tensor(x, dtype=get_precision(), device=get_device())
844844
if t.ndim == 0:
845845
return t.unsqueeze(0).unsqueeze(0)
846846
if t.ndim == 1:

0 commit comments

Comments
 (0)