Skip to content

Commit cfcce58

Browse files
authored
enhance one-hot documentation (#2521)
Signed-off-by: Wenqi Li <[email protected]>
1 parent 42a1125 commit cfcce58

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

monai/networks/utils.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,37 @@
3737

3838
def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
3939
"""
40-
For a tensor `labels` of dimensions [B]1[spatial_dims], return a tensor of dimensions `[B]N[spatial_dims]`
41-
for `num_classes` N number of classes.
40+
For every value v in `labels`, the value in the output will be either 1 or 0. Each vector along the `dim`-th
41+
dimension has the "one-hot" format, i.e., it has a total length of `num_classes`,
42+
with a one and `num_class-1` zeros.
43+
Note that this will include the background label, thus a binary mask should be treated as having two classes.
44+
45+
Args:
46+
labels: input tensor of integers to be converted into the 'one-hot' format. Internally `labels` will be
47+
converted into integers `labels.long()`.
48+
num_classes: number of output channels, the corresponding length of `labels[dim]` will be converted to
49+
`num_classes` from `1`.
50+
dtype: the data type of the output one_hot label.
51+
dim: the dimension to be converted to `num_classes` channels from `1` channel.
4252
4353
Example:
4454
45-
For every value v = labels[b,1,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0.
46-
Note that this will include the background label, thus a binary mask should be treated as having 2 classes.
55+
For a tensor `labels` of dimensions [B]1[spatial_dims], return a tensor of dimensions `[B]N[spatial_dims]`
56+
when `num_classes=N` number of classes and `dim=1`.
57+
58+
.. code-block:: python
59+
60+
from monai.networks.utils import one_hot
61+
import torch
62+
63+
a = torch.randint(0, 2, size=(1, 2, 2, 2))
64+
out = one_hot(a, num_classes=2, dim=0)
65+
print(out.shape) # torch.Size([2, 2, 2, 2])
66+
67+
a = torch.randint(0, 2, size=(2, 1, 2, 2, 2))
68+
out = one_hot(a, num_classes=2, dim=1)
69+
print(out.shape) # torch.Size([2, 2, 2, 2, 2])
70+
4771
"""
4872
if labels.dim() == 0:
4973
# if no channel dim, add it

0 commit comments

Comments
 (0)