Skip to content

Commit ceaf4bd

Browse files
committed
fix: rm unnecessary tensor_one_hot. update to_tensor
1 parent 95089c2 commit ceaf4bd

File tree

1 file changed

+2
-42
lines changed

1 file changed

+2
-42
lines changed

cellseg_models_pytorch/utils/tensor_utils.py

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
import numpy as np
44
import torch
55

6-
__all__ = ["to_tensor", "to_device", "tensor_one_hot"]
6+
__all__ = ["to_tensor", "to_device"]
77

88

99
def to_tensor(x: np.ndarray) -> torch.Tensor:
1010
"""Convert numpy array to torch tensor. Expects HW(C) format."""
1111
if x.ndim == 2:
12-
x = x[:, :, None]
13-
12+
return torch.from_numpy(x).contiguous()
1413
return torch.from_numpy(x.transpose((2, 0, 1))).contiguous()
1514

1615

@@ -34,42 +33,3 @@ def to_device(tensor: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
3433
tensor = tensor.cuda()
3534

3635
return tensor
37-
38-
39-
def tensor_one_hot(type_map: torch.Tensor, n_classes: int) -> torch.Tensor:
40-
"""Convert a segmentation mask into one-hot-format.
41-
42-
I.e. Takes in a segmentation mask of shape (B, H, W) and reshapes it
43-
into a tensor of shape (B, C, H, W).
44-
45-
Parameters
46-
----------
47-
type_map : torch.Tensor
48-
Multi-label Segmentation mask. Shape (B, H, W).
49-
n_classes : int
50-
Number of classes. (Zero-class included.)
51-
52-
Returns
53-
-------
54-
torch.Tensor:
55-
A one hot tensor. Shape: (B, C, H, W). Dtype: torch.FloatTensor.
56-
57-
Raises
58-
------
59-
TypeError: If input is not torch.int64.
60-
"""
61-
if not type_map.dtype == torch.int64:
62-
raise TypeError(
63-
f"""
64-
Input `type_map` should have dtype: torch.int64. Got: {type_map.dtype}."""
65-
)
66-
67-
one_hot = torch.zeros(
68-
type_map.shape[0],
69-
n_classes,
70-
*type_map.shape[1:],
71-
device=type_map.device,
72-
dtype=type_map.dtype,
73-
)
74-
75-
return one_hot.scatter_(dim=1, index=type_map.unsqueeze(1), value=1.0) + 1e-7

0 commit comments

Comments
 (0)