@@ -3,14 +3,13 @@
 import numpy as np
 import torch
 
-__all__ = ["to_tensor", "to_device", "tensor_one_hot"]
+__all__ = ["to_tensor", "to_device"]
 
 
 def to_tensor(x: np.ndarray) -> torch.Tensor:
     """Convert numpy array to torch tensor. Expects HW(C) format."""
     if x.ndim == 2:
-        x = x[:, :, None]
-
+        return torch.from_numpy(x).contiguous()
     return torch.from_numpy(x.transpose((2, 0, 1))).contiguous()
 
 
@@ -34,42 +33,3 @@ def to_device(tensor: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
         tensor = tensor.cuda()
 
     return tensor
-
-
-def tensor_one_hot(type_map: torch.Tensor, n_classes: int) -> torch.Tensor:
-    """Convert a segmentation mask into one-hot-format.
-
-    I.e. Takes in a segmentation mask of shape (B, H, W) and reshapes it
-    into a tensor of shape (B, C, H, W).
-
-    Parameters
-    ----------
-    type_map : torch.Tensor
-        Multi-label Segmentation mask. Shape (B, H, W).
-    n_classes : int
-        Number of classes. (Zero-class included.)
-
-    Returns
-    -------
-    torch.Tensor:
-        A one hot tensor. Shape: (B, C, H, W). Dtype: torch.FloatTensor.
-
-    Raises
-    ------
-    TypeError: If input is not torch.int64.
-    """
-    if not type_map.dtype == torch.int64:
-        raise TypeError(
-            f"""
-            Input `type_map` should have dtype: torch.int64. Got: {type_map.dtype}."""
-        )
-
-    one_hot = torch.zeros(
-        type_map.shape[0],
-        n_classes,
-        *type_map.shape[1:],
-        device=type_map.device,
-        dtype=type_map.dtype,
-    )
-
-    return one_hot.scatter_(dim=1, index=type_map.unsqueeze(1), value=1.0) + 1e-7
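
With this change, `to_tensor` no longer inserts a channel axis for 2D inputs; single-channel HW arrays come back as plain 2D tensors, while HWC arrays are still transposed to CHW. A minimal usage sketch of the new behaviour (assuming `to_tensor` from the module above is in scope; the array shapes are illustrative):

```python
import numpy as np

gray = np.zeros((256, 256), dtype=np.float32)    # HW, no channel axis
rgb = np.zeros((256, 256, 3), dtype=np.float32)  # HWC

# Previously a 2D input was expanded to (1, H, W); now it is returned as-is.
print(to_tensor(gray).shape)  # torch.Size([256, 256])

# 3D HWC inputs are still permuted to CHW.
print(to_tensor(rgb).shape)   # torch.Size([3, 256, 256])
```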
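Since `tensor_one_hot` is removed from the module (and from `__all__`), any caller that still needs the (B, H, W) -> (B, C, H, W) conversion has to provide it elsewhere. A hypothetical stand-in using PyTorch's built-in `torch.nn.functional.one_hot` is sketched below; the helper name `one_hot_chw` is an assumption, not part of this commit, and unlike the removed function it does not add the `1e-7` smoothing term:

```python
import torch
import torch.nn.functional as F

def one_hot_chw(type_map: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Hypothetical replacement: (B, H, W) int64 mask -> (B, C, H, W) float one-hot."""
    if type_map.dtype != torch.int64:
        raise TypeError(f"Expected dtype torch.int64, got: {type_map.dtype}.")
    # F.one_hot appends the class axis last: (B, H, W) -> (B, H, W, C).
    return F.one_hot(type_map, num_classes=n_classes).permute(0, 3, 1, 2).float()
```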