Skip to content

Commit 1df3d89

Browse files
authored
新增MultiLabelMarginLoss (#7331)
1 parent 0c4a980 commit 1df3d89

File tree

6 files changed

+164
-2
lines changed

6 files changed

+164
-2
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
.. _cn_api_paddle_nn_MultiLabelMarginLoss:
2+
3+
MultiLabelMarginLoss
4+
-------------------------------
5+
6+
.. py:class:: paddle.nn.MultiLabelMarginLoss(reduction='mean', name=None)
7+
8+
创建一个 MultiLabelMarginLoss 的可调用类。通过计算输入 `input` 和 `label` 间的多类别多分类问题的 `hinge loss (margin-based loss)` 损失。
9+
10+
损失函数计算每一个 mini-batch 的 loss 按照下列公式计算
11+
12+
.. math::
13+
\text{loss}(input_i, label_i) = \frac{\sum_{j \in \text{valid_labels}} \sum_{k \neq \text{valid_labels}} \max(0, 1 - (input_i[\text{valid_labels}[j]] - input_i[k]))}{C}
14+
15+
其中 :math:`C` 是类别数量, :math:`\text{valid_labels}` 包含样本 :math:`i` 所有非负的标签索引(遇到第一个 -1 时停止),:math:`k` 遍历除了 :math:`\text{valid_labels}` 之外的所有类别索引。
16+
17+
该损失函数只考虑前面的非负标签值,允许不同样本具有不同数量的目标类别。
18+
19+
参数
20+
:::::::::
21+
- **reduction** (str,可选) - 指定应用于输出结果的计算方式,可选值有:``'none'``、``'mean'``、``'sum'``。默认为 ``'mean'``,计算 Loss 的均值;设置为 ``'sum'`` 时,计算 Loss 的总和;设置为 ``'none'`` 时,则返回原始 Loss。
22+
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
23+
24+
调用参数
25+
:::::::::
26+
- **input** (Tensor) - 数据类型是 float32、float64。
27+
- **label** (Tensor) - 标签的数据类型为 int32、int64。标签值应该是类别索引(非负值)和 -1 值。-1 值会被忽略并停止处理每个样本。
28+
29+
形状
30+
:::::::::
31+
- **input** (Tensor) - :math:`[N, C]`,其中 N 是 batch_size, C 是类别数量。
32+
- **label** (Tensor) - :math:`[N, C]`,与 input 形状相同。
33+
- **output** (Tensor) - 输出的 Tensor。如果 :attr:`reduction` 是 ``'none'``,则输出的维度为 :math:`[N]`。如果 :attr:`reduction` 是 ``'mean'`` 或 ``'sum'``,则输出的维度为 :math:`[]` 。
34+
35+
返回
36+
:::::::::
37+
返回计算 MultiLabelMarginLoss 的可调用对象。
38+
39+
代码示例
40+
:::::::::
41+
COPY-FROM: paddle.nn.MultiLabelMarginLoss

docs/api/paddle/nn/Overview_cn.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ Loss 层
281281
" :ref:`paddle.nn.TripletMarginWithDistanceLoss <cn_api_paddle_nn_TripletMarginWithDistanceLoss>` ", "TripletMarginWithDistanceLoss 层"
282282
" :ref:`paddle.nn.MultiLabelSoftMarginLoss <cn_api_paddle_nn_MultiLabelSoftMarginLoss>` ", "多标签 Hinge 损失层"
283283
" :ref:`paddle.nn.MultiMarginLoss <cn_api_paddle_nn_MultiMarginLoss>` ", "MultiMarginLoss 层"
284+
" :ref:`paddle.nn.MultiLabelMarginLoss <cn_api_paddle_nn_MultiLabelMarginLoss>` ", "MultiLabelMarginLoss 层"
284285
" :ref:`paddle.nn.AdaptiveLogSoftmaxWithLoss <cn_api_paddle_nn_AdaptiveLogSoftmaxWithLoss>` ", "自适应 logsoftmax 损失类"
285286

286287

@@ -523,6 +524,7 @@ Embedding 相关函数
523524
" :ref:`paddle.nn.functional.hinge_embedding_loss <cn_api_paddle_nn_functional_hinge_embedding_loss>` ", "计算输入 input 和标签 label(包含 1 和 -1) 间的 `hinge embedding loss` 损失"
524525
" :ref:`paddle.nn.functional.rnnt_loss <cn_api_paddle_nn_functional_rnnt_loss>` ", "计算 RNNT loss,也可以叫做 softmax with RNNT"
525526
" :ref:`paddle.nn.functional.multi_margin_loss <cn_api_paddle_nn_functional_multi_margin_loss>` ", "用于计算 multi margin loss 损失函数"
527+
" :ref:`paddle.nn.functional.multi_label_margin_loss <cn_api_paddle_nn_functional_multi_label_margin_loss>` ", "用于计算 multi label margin loss 损失函数"
526528
" :ref:`paddle.nn.functional.adaptive_log_softmax_with_loss <cn_api_paddle_nn_functional_adaptive_log_softmax_with_loss>` ", "自适应 logsoftmax 损失函数"
527529

528530

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
.. _cn_api_paddle_nn_functional_multi_label_margin_loss:
2+
3+
multi_label_margin_loss
4+
-------------------------------
5+
6+
.. py:function:: paddle.nn.functional.multi_label_margin_loss(input, label, reduction='mean', name=None)
7+
8+
计算输入 `input` 和 `label` 间的多类别多分类问题的 `hinge loss` 损失。
9+
10+
损失函数计算每一个 mini-batch 的 loss 按照下列公式计算
11+
12+
.. math::
13+
\text{loss}(input_i, label_i) = \frac{\sum_{j \in \text{valid_labels}} \sum_{k \neq \text{valid_labels}} \max(0, 1 - (input_i[\text{valid_labels}[j]] - input_i[k]))}{C}
14+
15+
其中 :math:`C` 是类别数量, :math:`\text{valid_labels}` 包含样本 :math:`i` 所有非负的标签索引(遇到第一个 -1 时停止),:math:`k` 遍历除了 :math:`\text{valid_labels}` 之外的所有类别索引。
16+
17+
该损失函数只考虑前面的非负标签值,允许不同样本具有不同数量的目标类别。
18+
19+
参数
20+
:::::::::
21+
- **input** (Tensor) - :math:`[N, C]`,其中 N 是 batch_size, `C` 是类别数量。数据类型是 float32、float64。
22+
- **label** (Tensor) - :math:`[N, C]`,与 input 形状相同。标签 ``label`` 的数据类型为 int32、int64。标签值应该是类别索引(非负值)和 -1 值。-1 值会被忽略并停止处理每个样本。
23+
- **reduction** (str,可选) - 指定应用于输出结果的计算方式,可选值有:``'none'``, ``'mean'``, ``'sum'``。默认为 ``'mean'``,计算 Loss 的均值;设置为 ``'sum'`` 时,计算 Loss 的总和;设置为 ``'none'`` 时,则返回原始 Loss。
24+
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
25+
26+
形状
27+
:::::::::
28+
- **input** (Tensor) - :math:`[N, C]`,其中 N 是 batch_size,`C` 是类别数量。数据类型是 float32、float64。
29+
- **label** (Tensor) - :math:`[N, C]`,与 input 形状相同,标签 ``label`` 的数据类型为 int32、int64。
30+
- **output** (Tensor) - 输出的 Tensor。如果 :attr:`reduction` 是 ``'none'``,则输出的维度为 :math:`[N]`,与 batch_size 相同。如果 :attr:`reduction` 是 ``'mean'`` 或 ``'sum'``,则输出的维度为 :math:`[]` 。
31+
32+
返回
33+
:::::::::
34+
返回计算的 Loss。
35+
36+
代码示例
37+
:::::::::
38+
COPY-FROM: paddle.nn.functional.multi_label_margin_loss
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
## [ torch 参数更多 ]torch.nn.functional.multilabel_margin_loss
2+
3+
### [torch.nn.functional.multilabel\_margin\_loss](https://pytorch.org/docs/stable/generated/torch.nn.functional.multilabel_margin_loss.html)
4+
5+
```python
6+
torch.nn.functional.multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean')
7+
```
8+
9+
### [paddle.nn.functional.multi\_label\_margin\_loss](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/functional/multi_label_margin_loss_cn.html#multi-label-margin-loss)
10+
11+
```python
12+
paddle.nn.functional.multi_label_margin_loss(input, label, reduction='mean', name=None)
13+
```
14+
15+
PyTorch 相比 Paddle 支持更多其他参数,具体如下:
16+
17+
### 参数映射
18+
19+
| PyTorch | PaddlePaddle | 备注 |
20+
| ------------ | ------------ | -- |
21+
| input | input | 输入 Tensor。 |
22+
| target | label | 标签 Tensor,仅参数名不一致。 |
23+
| size_average | - | PyTorch 已弃用, Paddle 无此参数,需要转写。 |
24+
| reduce | - | PyTorch 已弃用, Paddle 无此参数,需要转写。 |
25+
| reduction | reduction | 指定应用于输出结果的计算方式。 |
26+
27+
### 转写示例
28+
29+
#### size_average、reduce
30+
```python
31+
# PyTorch 的 size_average、reduce 参数转为 Paddle 的 reduction 参数
32+
if size_average is None:
33+
size_average = True
34+
if reduce is None:
35+
reduce = True
36+
37+
if size_average and reduce:
38+
reduction = 'mean'
39+
elif reduce:
40+
reduction = 'sum'
41+
else:
42+
reduction = 'none'
43+
```
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
## [torch 参数更多]torch.nn.MultiLabelMarginLoss
2+
3+
### [torch.nn.MultiLabelMarginLoss](https://pytorch.org/docs/stable/generated/torch.nn.MultiLabelMarginLoss.html#torch.nn.MultiLabelMarginLoss)
4+
5+
```python
6+
torch.nn.MultiLabelMarginLoss(size_average=None, reduce=None, reduction='mean')
7+
```
8+
9+
### [paddle.nn.MultiLabelMarginLoss](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/nn/MultiLabelMarginLoss_cn.html)
10+
11+
```python
12+
paddle.nn.MultiLabelMarginLoss(reduction='mean', name=None)
13+
```
14+
15+
PyTorch 相比 Paddle 支持更多其他参数,具体如下:
16+
17+
### 参数映射
18+
19+
| PyTorch | PaddlePaddle | 备注 |
20+
| ------------ | ------------ | ---------------------------------------------- |
21+
| size_average | - | 已废弃,和 reduce 组合决定损失计算方式。 |
22+
| reduce | - | 已废弃,和 size_average 组合决定损失计算方式。 |
23+
| reduction | reduction | 指定应用于输出结果的计算方式。 |
24+
25+
### 转写示例
26+
27+
```python
28+
# PyTorch 的 size_average、reduce 参数转为 Paddle 的 reduction 参数
29+
if size_average is None:
30+
size_average = True
31+
if reduce is None:
32+
reduce = True
33+
34+
if size_average and reduce:
35+
reduction = 'mean'
36+
elif reduce:
37+
reduction = 'sum'
38+
else:
39+
reduction = 'none'
40+
```

docs/guides/model_convert/convert_from_pytorch/pytorch_api_mapping_cn.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,8 +622,6 @@
622622
| NOT-IMPLEMENTED-ITEM(`torch.special.zeta`, https://pytorch.org/docs/stable/special.html#torch.special.zeta, 可新增,且框架底层有相关设计,成本低) |
623623
| NOT-IMPLEMENTED-ITEM(`torch.xpu.current_device`, https://pytorch.org/docs/stable/generated/torch.xpu.current_device.html#torch-xpu-current-device, 有对应相近功能但设计差异大无法映射,一般无需新增) |
624624
| NOT-IMPLEMENTED-ITEM(`torch.xpu.get_device_properties`, https://pytorch.org/docs/stable/generated/torch.xpu.get_device_properties.html#torch-xpu-get-device-properties, 有对应相近功能但设计差异大无法映射,一般无需新增) |
625-
| NOT-IMPLEMENTED-ITEM(`torch.nn.functional.multilabel_margin_loss`, https://pytorch.org/docs/stable/generated/torch.nn.functional.multilabel_margin_loss.html#torch-nn-functional-multilabel-margin-loss, 可新增,且框架底层有相关设计,成本低) |
626-
| NOT-IMPLEMENTED-ITEM(`torch.nn.functional.MultiLabelMarginLoss`, https://pytorch.org/docs/stable/generated/torch.nn.MultiLabelMarginLoss.html#torch.nn.MultiLabelMarginLoss, 可新增,且框架底层有相关设计,成本低) |
627625
| NOT-IMPLEMENTED-ITEM(`torch.gradient`, https://pytorch.org/docs/stable/generated/torch.gradient.html#torch-gradient, 可新增,且框架底层有相关设计,成本低) |
628626
| NOT-IMPLEMENTED-ITEM(`torch.Tensor.sparse_resize_`, https://pytorch.org/docs/stable/generated/torch.Tensor.sparse_resize_.html#torch-tensor-sparse-resize, 可新增,且框架底层有相关设计,成本低) |
629627
| NOT-IMPLEMENTED-ITEM(`torch.autograd.profiler.profile`, https://pytorch.org/docs/stable/autograd.html#torch.autograd.profiler.profile, 可新增,但框架底层无相关设计,成本高) |

0 commit comments

Comments
 (0)