37 | 37 |
38 | 38 | def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor: |
39 | 39 | """ |
40 | | - For a tensor `labels` of dimensions [B]1[spatial_dims], return a tensor of dimensions `[B]N[spatial_dims]` |
41 | | - for `num_classes` N number of classes. |
| 40 | + For every value v in `labels`, the value in the output will be either 1 or 0. Each vector along the `dim`-th |
| 41 | + dimension has the "one-hot" format, i.e., it has a total length of `num_classes`, |
 | 42 | +        with a one and `num_classes - 1` zeros.
| 43 | + Note that this will include the background label, thus a binary mask should be treated as having two classes. |
| 44 | +
| 45 | + Args: |
 | 46 | +        labels: input tensor of integers to be converted into the 'one-hot' format. Internally `labels` will be
 | 47 | +            converted to integers via `labels.long()`.
 | 48 | +        num_classes: number of output channels; the size of `labels` along `dim` (which must be `1`)
 | 49 | +            will be expanded to `num_classes`.
| 50 | + dtype: the data type of the output one_hot label. |
| 51 | + dim: the dimension to be converted to `num_classes` channels from `1` channel. |
42 | 52 |
43 | 53 | Example: |
44 | 54 |
45 | | - For every value v = labels[b,1,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0. |
46 | | - Note that this will include the background label, thus a binary mask should be treated as having 2 classes. |
 | 55 | +        For a tensor `labels` of dimensions `[B]1[spatial_dims]`, this returns a tensor of dimensions
 | 56 | +        `[B]N[spatial_dims]` when `num_classes=N` and `dim=1`.
| 57 | +
| 58 | + .. code-block:: python |
| 59 | +
| 60 | + from monai.networks.utils import one_hot |
| 61 | + import torch |
| 62 | +
| 63 | + a = torch.randint(0, 2, size=(1, 2, 2, 2)) |
| 64 | + out = one_hot(a, num_classes=2, dim=0) |
| 65 | + print(out.shape) # torch.Size([2, 2, 2, 2]) |
| 66 | +
| 67 | + a = torch.randint(0, 2, size=(2, 1, 2, 2, 2)) |
| 68 | + out = one_hot(a, num_classes=2, dim=1) |
| 69 | + print(out.shape) # torch.Size([2, 2, 2, 2, 2]) |
| 70 | +
47 | 71 | """ |
48 | 72 | if labels.dim() == 0: |
49 | 73 | # if no channel dim, add it |
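The expansion the docstring describes can be sketched with `torch.zeros` plus `Tensor.scatter_`: a zero tensor whose `dim`-th axis has length `num_classes` receives a one at index `v` for every value `v` in `labels`. This is a minimal illustration of the behavior, not MONAI's actual implementation (the function name `one_hot_sketch` and the size-1-channel assertion are assumptions for this sketch):

```python
import torch

def one_hot_sketch(labels: torch.Tensor, num_classes: int,
                   dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
    # hypothetical sketch: assumes `labels` already has a size-1 channel at `dim`
    shape = list(labels.shape)
    assert shape[dim] == 1, "expected a size-1 channel dimension at `dim`"
    shape[dim] = num_classes
    out = torch.zeros(shape, dtype=dtype, device=labels.device)
    # place a 1 at index v along `dim` for every value v in `labels`
    return out.scatter_(dim, labels.long(), 1)

a = torch.randint(0, 2, size=(2, 1, 2, 2))
out = one_hot_sketch(a, num_classes=2, dim=1)
print(out.shape)  # torch.Size([2, 2, 2, 2])
```

Each vector along `dim` then sums to 1, and taking `argmax` along `dim` recovers the original labels.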