Skip to content

Commit 46b00a7

Browse files
authored
Make device and dtype required (#9168)
1 parent 39dd3a6 commit 46b00a7

File tree

1 file changed

+3
-13
lines changed

1 file changed

+3
-13
lines changed

pytorch_lightning/utilities/apply_func.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import numpy as np
2424
import torch
2525

26-
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2726
from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE
2827

2928
if _TORCHTEXT_AVAILABLE:
@@ -36,22 +35,16 @@
3635

3736

3837
def to_dtype_tensor(
39-
value: Union[int, float, List[Union[int, float]]],
40-
dtype: Optional[torch.dtype] = None,
41-
device: Union[str, torch.device] = None,
38+
value: Union[int, float, List[Union[int, float]]], dtype: torch.dtype, device: Union[str, torch.device]
4239
) -> torch.Tensor:
43-
if device is None:
44-
raise MisconfigurationException("device (torch.device) should be provided.")
4540
return torch.tensor(value, dtype=dtype, device=device)
4641

4742

48-
def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor:
49-
if device is None:
50-
raise MisconfigurationException("device (torch.device) should be provided.")
43+
def from_numpy(value: np.ndarray, device: Union[str, torch.device]) -> torch.Tensor:
5144
return torch.from_numpy(value).to(device)
5245

5346

54-
CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any], torch.Tensor]]] = [
47+
CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any, Any], torch.Tensor]]] = [
5548
# bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
5649
(bool, partial(to_dtype_tensor, dtype=torch.uint8)),
5750
(int, partial(to_dtype_tensor, dtype=torch.int)),
@@ -276,9 +269,6 @@ def batch_to(data: Any) -> Any:
276269

277270

278271
def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
279-
if device is None:
280-
raise MisconfigurationException("`torch.device` should be provided.")
281-
282272
for src_dtype, conversion_func in CONVERSION_DTYPES:
283273
data = apply_to_collection(data, src_dtype, conversion_func, device=device)
284274

0 commit comments

Comments
 (0)