Skip to content

Commit 224d263

Browse files
committed
fix: create simpler to_tensor function. rm tensor_to_ndarray. Not needed.
1 parent d65ada4 commit 224d263

File tree

1 file changed

+6
-161
lines changed

1 file changed

+6
-161
lines changed

cellseg_models_pytorch/utils/tensor_utils.py

Lines changed: 6 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -3,170 +3,15 @@
33
import numpy as np
44
import torch
55

6+
__all__ = ["to_tensor", "to_device", "tensor_one_hot"]
67

7-
def ndarray_to_tensor(
8-
array: np.ndarray, in_dim_format: str, out_dim_format: str
9-
) -> torch.Tensor:
10-
"""Convert img (H, W)|(H, W, C)|(B, H, W, C) to a tensor.
118

12-
Parameters
13-
----------
14-
array : np.ndarray
15-
Numpy matrix. Shape: (H, W)|(H, W, C)|(B, H, W, C)
16-
in_dim_format : str
17-
The order of the dimensions in the input array.
18-
One of: "HW", "HWC", "BHWC", "BCHW", "BHW"
19-
out_dim_format : str
20-
The order of the dimensions in the output tensor.
21-
One of: "HW", "HWC", "BHWC", "BCHW", "BHW"
22-
23-
Returns
24-
-------
25-
torch.Tensor:
26-
Input converted to a batched tensor. Shape .
27-
28-
Raises
29-
------
30-
TypeError:
31-
If input array is not np.ndarray.
32-
ValueError:
33-
If input array has wrong number of dimensions
34-
ValueError:
35-
If `in_dim_format` param is illegal.
36-
ValueError:
37-
If `out_dim_format` param is illegal.
38-
39-
"""
40-
if not isinstance(array, np.ndarray):
41-
raise TypeError(f"Input type: {type(array)} is not np.ndarray")
42-
43-
if not 1 < len(array.shape) <= 4:
44-
raise ValueError(
45-
f"ndarray.shape {array.shape}, rank needs to be between [2, 4]"
46-
)
47-
48-
dim_types = ("HW", "HWC", "BHWC", "BCHW", "BHW")
49-
if in_dim_format not in dim_types:
50-
raise ValueError(
51-
f"Illegal `in_dim_format`. Got {in_dim_format}. Allowed: {dim_types}"
52-
)
53-
54-
if out_dim_format not in dim_types:
55-
raise ValueError(
56-
f"Illegal `out_dim_format`. Got {out_dim_format}. Allowed: {dim_types}"
57-
)
58-
59-
if not len(array.shape) == len(in_dim_format):
60-
raise ValueError(
61-
f"Mismatching input dimensions. Input Shape: {array.shape}. while "
62-
f"`in_dim_format` is set to: {in_dim_format}."
63-
)
64-
65-
if in_dim_format in ("HW", "BHW"):
66-
if out_dim_format in ("HWC", "BHWC", "BCHW"):
67-
array = array[..., None]
68-
69-
if in_dim_format in ("HW", "HWC"):
70-
if out_dim_format in ("BHWC", "BCHW", "BHW"):
71-
array = array[None, ...]
72-
73-
if (
74-
len(array.shape) == 4
75-
and in_dim_format in ("BHWC", "HWC", "HW")
76-
and out_dim_format == "BCHW"
77-
):
78-
array = array.transpose(0, 3, 1, 2)
79-
80-
if len(array.shape) == 4 and in_dim_format == "BCHW" and out_dim_format == "BHWC":
81-
array = array.transpose(0, 2, 3, 1)
82-
83-
return torch.from_numpy(array)
84-
85-
86-
def tensor_to_ndarray(
87-
tensor: torch.Tensor, in_dim_format: str, out_dim_format: str
88-
) -> np.ndarray:
89-
"""Convert a tensor into a numpy ndarray.
90-
91-
Parameters
92-
----------
93-
tensor : torch.Tensor
94-
The input tensor. Shape: (B, H, W)|(B, C, H, W)
95-
in_dim_format : str
96-
The order of the dimensions in the input array.
97-
One of: "BCHW", "BHW"
98-
out_dim_format : str
99-
The order of the dimensions in the output tensor.
100-
One of: "HW", "HWC", "BHWC", "BHW"
101-
102-
Returns
103-
-------
104-
np.ndarray:
105-
An ndarray. Shape(B, H, W, C)|(B, H, W)|(H, W, C)|(H, W)
106-
107-
Raises
108-
------
109-
TypeError:
110-
If input array is not torch.Tensor.
111-
ValueError:
112-
If input array has wrong number of dimensions
113-
ValueError:
114-
If `in_dim_format` param is illegal.
115-
ValueError:
116-
If `out_dim_format` param is illegal.
117-
"""
118-
if not isinstance(tensor, torch.Tensor):
119-
raise TypeError(f"Input type: {type(tensor)} is not torch.Tensor")
120-
121-
if not 3 <= tensor.dim() <= 4:
122-
raise ValueError(
123-
"The input tensor needs to have shape (B, H, W) or (B, C, H, W). "
124-
f"Got: {tensor.shape}",
125-
)
126-
127-
in_dim_types = ("BCHW", "BHW")
128-
if in_dim_format not in in_dim_types:
129-
raise ValueError(
130-
f"Illegal `in_dim_format`. Got {in_dim_format}. Allowed: {in_dim_types}"
131-
)
132-
133-
out_dim_types = ("BCHW", "BHWC", "BHW", "HWC", "HW")
134-
if out_dim_format not in out_dim_types:
135-
raise ValueError(
136-
f"Illegal `out_dim_format`. Got {out_dim_format}. Allowed: {out_dim_types}"
137-
)
9+
def to_tensor(x: np.ndarray) -> torch.Tensor:
10+
"""Convert numpy array to torch tensor. Expects HW(C) format."""
11+
if x.ndim == 2:
12+
x = x[:, :, None]
13813

139-
# detach and bring to cpu
140-
array = tensor.detach()
141-
if tensor.is_cuda:
142-
array = array.cpu()
143-
144-
array = array.numpy()
145-
if array.ndim == 4 and out_dim_format != "BCHW":
146-
array = array.transpose(0, 2, 3, 1) # (B, H, W, C)
147-
148-
if out_dim_format == "HW" and array.ndim == 4:
149-
array = array.squeeze()
150-
151-
if out_dim_format == "HWC" and array.ndim == 4:
152-
try:
153-
array = array.squeeze(axis=0)
154-
except Exception:
155-
pass
156-
157-
if out_dim_format == "BHW" and array.ndim == 4:
158-
try:
159-
array = array.squeeze(axis=-1)
160-
except Exception:
161-
pass
162-
163-
if in_dim_format == "BHW":
164-
if out_dim_format == "BHWC":
165-
array = array[..., None]
166-
elif out_dim_format == "HW":
167-
array = array.squeeze()
168-
169-
return array
14+
return torch.from_numpy(x.transpose((2, 0, 1))).contiguous()
17015

17116

17217
def to_device(tensor: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:

0 commit comments

Comments
 (0)