Skip to content

Commit 468397e

Browse files
committed
bugfix #66
1 parent cf093ac commit 468397e

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

fdtd/backend.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -322,21 +322,21 @@ def numpy(self, arr):
322322
class TorchCudaBackend(TorchBackend):
323323
"""Torch Cuda Backend"""
324324

325-
def ones(self, shape):
325+
def ones(self, shape, **kwargs):
326326
"""create an array filled with ones"""
327-
return torch.ones(shape, device="cuda")
327+
return torch.ones(shape, device="cuda", **kwargs)
328328

329-
def zeros(self, shape):
329+
def zeros(self, shape, **kwargs):
330330
"""create an array filled with zeros"""
331-
return torch.zeros(shape, device="cuda")
331+
return torch.zeros(shape, device="cuda", **kwargs)
332332

333-
def array(self, arr, dtype=None):
333+
def array(self, arr, dtype=None, **kwargs):
334334
"""create an array from an array-like sequence"""
335335
if dtype is None:
336336
dtype = torch.get_default_dtype()
337337
if torch.is_tensor(arr):
338-
return arr.clone().to(device="cuda", dtype=dtype)
339-
return torch.tensor(arr, device="cuda", dtype=dtype)
338+
return arr.clone().to(device="cuda", dtype=dtype, **kwargs)
339+
return torch.tensor(arr, device="cuda", dtype=dtype, **kwargs)
340340

341341
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
342342
# The same warning applies here.

0 commit comments

Comments
 (0)