|
3 | 3 | import numpy as np |
4 | 4 | import torch |
5 | 5 |
|
| 6 | +__all__ = ["to_tensor", "to_device", "tensor_one_hot"] |
6 | 7 |
|
7 | | -def ndarray_to_tensor( |
8 | | - array: np.ndarray, in_dim_format: str, out_dim_format: str |
9 | | -) -> torch.Tensor: |
10 | | - """Convert img (H, W)|(H, W, C)|(B, H, W, C) to a tensor. |
11 | 8 |
|
12 | | - Parameters |
13 | | - ---------- |
14 | | - array : np.ndarray |
15 | | - Numpy matrix. Shape: (H, W)|(H, W, C)|(B, H, W, C) |
16 | | - in_dim_format : str |
17 | | - The order of the dimensions in the input array. |
18 | | - One of: "HW", "HWC", "BHWC", "BCHW", "BHW" |
19 | | - out_dim_format : str |
20 | | - The order of the dimensions in the output tensor. |
21 | | - One of: "HW", "HWC", "BHWC", "BCHW", "BHW" |
22 | | -
|
23 | | - Returns |
24 | | - ------- |
25 | | - torch.Tensor: |
26 | | - Input converted to a batched tensor. Shape . |
27 | | -
|
28 | | - Raises |
29 | | - ------ |
30 | | - TypeError: |
31 | | - If input array is not np.ndarray. |
32 | | - ValueError: |
33 | | - If input array has wrong number of dimensions |
34 | | - ValueError: |
35 | | - If `in_dim_format` param is illegal. |
36 | | - ValueError: |
37 | | - If `out_dim_format` param is illegal. |
38 | | -
|
39 | | - """ |
40 | | - if not isinstance(array, np.ndarray): |
41 | | - raise TypeError(f"Input type: {type(array)} is not np.ndarray") |
42 | | - |
43 | | - if not 1 < len(array.shape) <= 4: |
44 | | - raise ValueError( |
45 | | - f"ndarray.shape {array.shape}, rank needs to be between [2, 4]" |
46 | | - ) |
47 | | - |
48 | | - dim_types = ("HW", "HWC", "BHWC", "BCHW", "BHW") |
49 | | - if in_dim_format not in dim_types: |
50 | | - raise ValueError( |
51 | | - f"Illegal `in_dim_format`. Got {in_dim_format}. Allowed: {dim_types}" |
52 | | - ) |
53 | | - |
54 | | - if out_dim_format not in dim_types: |
55 | | - raise ValueError( |
56 | | - f"Illegal `out_dim_format`. Got {out_dim_format}. Allowed: {dim_types}" |
57 | | - ) |
58 | | - |
59 | | - if not len(array.shape) == len(in_dim_format): |
60 | | - raise ValueError( |
61 | | - f"Mismatching input dimensions. Input Shape: {array.shape}. while " |
62 | | - f"`in_dim_format` is set to: {in_dim_format}." |
63 | | - ) |
64 | | - |
65 | | - if in_dim_format in ("HW", "BHW"): |
66 | | - if out_dim_format in ("HWC", "BHWC", "BCHW"): |
67 | | - array = array[..., None] |
68 | | - |
69 | | - if in_dim_format in ("HW", "HWC"): |
70 | | - if out_dim_format in ("BHWC", "BCHW", "BHW"): |
71 | | - array = array[None, ...] |
72 | | - |
73 | | - if ( |
74 | | - len(array.shape) == 4 |
75 | | - and in_dim_format in ("BHWC", "HWC", "HW") |
76 | | - and out_dim_format == "BCHW" |
77 | | - ): |
78 | | - array = array.transpose(0, 3, 1, 2) |
79 | | - |
80 | | - if len(array.shape) == 4 and in_dim_format == "BCHW" and out_dim_format == "BHWC": |
81 | | - array = array.transpose(0, 2, 3, 1) |
82 | | - |
83 | | - return torch.from_numpy(array) |
84 | | - |
85 | | - |
86 | | -def tensor_to_ndarray( |
87 | | - tensor: torch.Tensor, in_dim_format: str, out_dim_format: str |
88 | | -) -> np.ndarray: |
89 | | - """Convert a tensor into a numpy ndarray. |
90 | | -
|
91 | | - Parameters |
92 | | - ---------- |
93 | | - tensor : torch.Tensor |
94 | | - The input tensor. Shape: (B, H, W)|(B, C, H, W) |
95 | | - in_dim_format : str |
96 | | - The order of the dimensions in the input array. |
97 | | - One of: "BCHW", "BHW" |
98 | | - out_dim_format : str |
99 | | - The order of the dimensions in the output tensor. |
100 | | - One of: "HW", "HWC", "BHWC", "BHW" |
101 | | -
|
102 | | - Returns |
103 | | - ------- |
104 | | - np.ndarray: |
105 | | - An ndarray. Shape(B, H, W, C)|(B, H, W)|(H, W, C)|(H, W) |
106 | | -
|
107 | | - Raises |
108 | | - ------ |
109 | | - TypeError: |
110 | | - If input array is not torch.Tensor. |
111 | | - ValueError: |
112 | | - If input array has wrong number of dimensions |
113 | | - ValueError: |
114 | | - If `in_dim_format` param is illegal. |
115 | | - ValueError: |
116 | | - If `out_dim_format` param is illegal. |
117 | | - """ |
118 | | - if not isinstance(tensor, torch.Tensor): |
119 | | - raise TypeError(f"Input type: {type(tensor)} is not torch.Tensor") |
120 | | - |
121 | | - if not 3 <= tensor.dim() <= 4: |
122 | | - raise ValueError( |
123 | | - "The input tensor needs to have shape (B, H, W) or (B, C, H, W). " |
124 | | - f"Got: {tensor.shape}", |
125 | | - ) |
126 | | - |
127 | | - in_dim_types = ("BCHW", "BHW") |
128 | | - if in_dim_format not in in_dim_types: |
129 | | - raise ValueError( |
130 | | - f"Illegal `in_dim_format`. Got {in_dim_format}. Allowed: {in_dim_types}" |
131 | | - ) |
132 | | - |
133 | | - out_dim_types = ("BCHW", "BHWC", "BHW", "HWC", "HW") |
134 | | - if out_dim_format not in out_dim_types: |
135 | | - raise ValueError( |
136 | | - f"Illegal `out_dim_format`. Got {out_dim_format}. Allowed: {out_dim_types}" |
137 | | - ) |
| 9 | +def to_tensor(x: np.ndarray) -> torch.Tensor: |
| 10 | + """Convert numpy array to torch tensor. Expects HW(C) format.""" |
| 11 | + if x.ndim == 2: |
| 12 | + x = x[:, :, None] |
138 | 13 |
|
139 | | - # detach and bring to cpu |
140 | | - array = tensor.detach() |
141 | | - if tensor.is_cuda: |
142 | | - array = array.cpu() |
143 | | - |
144 | | - array = array.numpy() |
145 | | - if array.ndim == 4 and out_dim_format != "BCHW": |
146 | | - array = array.transpose(0, 2, 3, 1) # (B, H, W, C) |
147 | | - |
148 | | - if out_dim_format == "HW" and array.ndim == 4: |
149 | | - array = array.squeeze() |
150 | | - |
151 | | - if out_dim_format == "HWC" and array.ndim == 4: |
152 | | - try: |
153 | | - array = array.squeeze(axis=0) |
154 | | - except Exception: |
155 | | - pass |
156 | | - |
157 | | - if out_dim_format == "BHW" and array.ndim == 4: |
158 | | - try: |
159 | | - array = array.squeeze(axis=-1) |
160 | | - except Exception: |
161 | | - pass |
162 | | - |
163 | | - if in_dim_format == "BHW": |
164 | | - if out_dim_format == "BHWC": |
165 | | - array = array[..., None] |
166 | | - elif out_dim_format == "HW": |
167 | | - array = array.squeeze() |
168 | | - |
169 | | - return array |
| 14 | + return torch.from_numpy(x.transpose((2, 0, 1))).contiguous() |
170 | 15 |
|
171 | 16 |
|
172 | 17 | def to_device(tensor: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: |
|
0 commit comments