|
23 | 23 | import numpy as np |
24 | 24 | import torch |
25 | 25 |
|
26 | | -from pytorch_lightning.utilities.exceptions import MisconfigurationException |
27 | 26 | from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE |
28 | 27 |
|
29 | 28 | if _TORCHTEXT_AVAILABLE: |
|
36 | 35 |
|
37 | 36 |
|
38 | 37 | def to_dtype_tensor( |
39 | | - value: Union[int, float, List[Union[int, float]]], |
40 | | - dtype: Optional[torch.dtype] = None, |
41 | | - device: Union[str, torch.device] = None, |
| 38 | + value: Union[int, float, List[Union[int, float]]], dtype: torch.dtype, device: Union[str, torch.device] |
42 | 39 | ) -> torch.Tensor: |
43 | | - if device is None: |
44 | | - raise MisconfigurationException("device (torch.device) should be provided.") |
45 | 40 | return torch.tensor(value, dtype=dtype, device=device) |
46 | 41 |
|
47 | 42 |
|
48 | | -def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor: |
49 | | - if device is None: |
50 | | - raise MisconfigurationException("device (torch.device) should be provided.") |
| 43 | +def from_numpy(value: np.ndarray, device: Union[str, torch.device]) -> torch.Tensor: |
51 | 44 | return torch.from_numpy(value).to(device) |
52 | 45 |
|
53 | 46 |
|
54 | | -CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any], torch.Tensor]]] = [ |
| 47 | +CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any, Any], torch.Tensor]]] = [ |
55 | 48 | # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group |
56 | 49 | (bool, partial(to_dtype_tensor, dtype=torch.uint8)), |
57 | 50 | (int, partial(to_dtype_tensor, dtype=torch.int)), |
@@ -276,9 +269,6 @@ def batch_to(data: Any) -> Any: |
276 | 269 |
|
277 | 270 |
|
278 | 271 | def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any: |
279 | | - if device is None: |
280 | | - raise MisconfigurationException("`torch.device` should be provided.") |
281 | | - |
282 | 272 | for src_dtype, conversion_func in CONVERSION_DTYPES: |
283 | 273 | data = apply_to_collection(data, src_dtype, conversion_func, device=device) |
284 | 274 |
|
|
0 commit comments