@@ -322,21 +322,21 @@ def numpy(self, arr):
322322 class TorchCudaBackend (TorchBackend ):
323323 """Torch Cuda Backend"""
324324
325- def ones (self , shape ):
325+ def ones (self , shape , ** kwargs ):
326326 """create an array filled with ones"""
327- return torch .ones (shape , device = "cuda" )
327+ return torch .ones (shape , device = "cuda" , ** kwargs )
328328
329- def zeros (self , shape ):
329+ def zeros (self , shape , ** kwargs ):
330330 """create an array filled with zeros"""
331- return torch .zeros (shape , device = "cuda" )
331+ return torch .zeros (shape , device = "cuda" , ** kwargs )
332332
333- def array (self , arr , dtype = None ):
333+ def array (self , arr , dtype = None , ** kwargs ):
334334 """create an array from an array-like sequence"""
335335 if dtype is None :
336336 dtype = torch .get_default_dtype ()
337337 if torch .is_tensor (arr ):
338- return arr .clone ().to (device = "cuda" , dtype = dtype )
339- return torch .tensor (arr , device = "cuda" , dtype = dtype )
338+ return arr .clone ().to (device = "cuda" , dtype = dtype , ** kwargs )
339+ return torch .tensor (arr , device = "cuda" , dtype = dtype , ** kwargs )
340340
341341 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
342342 # The same warning applies here.
0 commit comments