
Commit 42df19c

Merge pull request #189 from kozistr/feature/loss-functions
[Feature] Implement loss functions
2 parents 6bd5277 + ba896c9 commit 42df19c

20 files changed: +1284 -48 lines

.github/pull_request_template.md

Lines changed: 3 additions & 2 deletions
@@ -3,11 +3,12 @@ Remove this part when you open the PR
 
 Here's a checklist before opening the Pull Request!
 
-1. PR Title convention : [Type of PR] [Summary] (e.g. [Feature] Implement AdamP optimizer)
+1. PR title convention : [Type of PR] [Summary] (e.g. [Feature] Implement AdamP optimizer)
 2. Attach `as much information as possible you can`. It helps the reviewers a lot :)
 3. Make sure the code is perfectly `runnable & compatible`.
 4. If your PR is not ready yet, make your `PR` to `Draft PR`.
-5. Make sure `make check` before opening the `PR`.
+5. Make sure `make format & check` before opening the `PR`.
+6. Or you just call the maintainer to help to fix code-style & test cases.
 ---
 
 ## Problem (Why?)

README.rst

Lines changed: 28 additions & 1 deletion
@@ -16,7 +16,7 @@ pytorch-optimizer
 
 | **pytorch-optimizer** is optimizer & lr scheduler collections in PyTorch.
 | I just re-implemented (speed & memory tweaks, plug-ins) the algorithm while based on the original paper. Also, It includes useful and practical optimization ideas.
-| Currently, 57 optimizers, 6 lr schedulers are supported!
+| Currently, 57 optimizers, 6 lr schedulers, and 10 loss functions are supported!
 |
 | Highly inspired by `pytorch-optimizer <https://github.com/jettify/pytorch-optimizer>`__.
 
@@ -240,6 +240,33 @@ You can check the supported learning rate schedulers with below code.
 | Chebyshev | *Acceleration via Fractal Learning Rate Schedules* | | `https://arxiv.org/abs/2103.01338 <https://arxiv.org/abs/2103.01338>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2021arXiv210301338A/exportcitation>`__ |
 +------------------+---------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
 
+Supported Loss Function
+-----------------------
+
+You can check the supported loss functions with below code.
+
+::
+
+    from pytorch_optimizer import get_supported_loss_functions
+
+    supported_loss_functions = get_supported_loss_functions()
+
++---------------------+-------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| Loss Functions | Description | Official Code | Paper | Citation |
++=====================+=========================================================================================================================+===================================================================================+===============================================================================================+======================================================================================================================+
+| Label Smoothing | *Rethinking the Inception Architecture for Computer Vision* | | `https://arxiv.org/abs/1512.00567 <https://arxiv.org/abs/1512.00567>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2015arXiv151200567S/exportcitation>`__ |
++---------------------+-------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| Focal | *Focal Loss for Dense Object Detection* | | `https://arxiv.org/abs/1708.02002 <https://arxiv.org/abs/1708.02002>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2017arXiv170802002L/exportcitation>`__ |
++---------------------+-------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| Focal Cosine | *Data-Efficient Deep Learning Method for Image Classification Using Data Augmentation, Focal Cosine Loss, and Ensemble* | | `https://arxiv.org/abs/2007.07805 <https://arxiv.org/abs/2007.07805>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2020arXiv200707805K/exportcitation>`__ |
++---------------------+-------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| LDAM | *Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss* | `github <https://github.com/kaidic/LDAM-DRW>`__ | `https://arxiv.org/abs/1906.07413 <https://arxiv.org/abs/1906.07413>`__ | `cite <https://github.com/kaidic/LDAM-DRW#reference>`__ |
++---------------------+-------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| Jaccard (IOU) | *IoU Loss for 2D/3D Object Detection* | | `https://arxiv.org/abs/1908.03851 <https://arxiv.org/abs/1908.03851>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2019arXiv190803851Z/exportcitation>`__ |
++---------------------+-------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+| Bi-Tempered | *The Principle of Unchanged Optimality in Reinforcement Learning Generalization* | | `https://arxiv.org/abs/1906.03361 <https://arxiv.org/abs/1906.03361>`__ | `cite <https://ui.adsabs.harvard.edu/abs/2019arXiv190600336I/exportcitation>`__ |
++---------------------+-------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------+
+
 Useful Resources
 ----------------
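
As a quick check of the snippet in the new README section, the registry added by this commit can be enumerated at runtime. A minimal sketch (the print loop is illustrative, not part of the library):

    from pytorch_optimizer import get_supported_loss_functions

    # get_supported_loss_functions() returns LOSS_FUNCTION_LIST, the list of loss
    # classes registered in pytorch_optimizer/__init__.py (see that diff below).
    for loss_cls in get_supported_loss_functions():
        print(loss_cls.__name__)  # BCELoss, BCEFocalLoss, FocalLoss, SoftF1Loss, ...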

docs/changelogs/v2.11.0.md

Lines changed: 9 additions & 0 deletions
@@ -6,6 +6,15 @@
     * [Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks](https://arxiv.org/abs/1806.06763)
 * Implement LOMO optimizer (#188)
     * [Full Parameter Fine-tuning for Large Language Models with Limited Resources](https://arxiv.org/abs/2306.09782)
+* Implement loss functions (#189)
+    * BCELoss
+    * BCEFocalLoss
+    * FocalLoss : [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
+    * FocalCosineLoss : [Data-Efficient Deep Learning Method for Image Classification Using Data Augmentation, Focal Cosine Loss, and Ensemble](https://arxiv.org/abs/2007.07805)
+    * DiceLoss : [Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations](https://arxiv.org/abs/1707.03237v3)
+    * LDAMLoss : [Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss](https://arxiv.org/abs/1906.07413)
+    * JaccardLoss
+    * BiTemperedLogisticLoss : [Robust Bi-Tempered Logistic Loss Based on Bregman Divergences](https://arxiv.org/abs/1906.03361)
 
 ### Diff
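
As a rough illustration of how one of the classes listed in the changelog slots into a training step, here is a minimal sketch using FocalLoss. The default construction and the criterion(prediction, target) call convention are assumptions based on standard torch.nn loss modules, not on this diff; check the class docstring for the actual signature and expected shapes.

    import torch

    from pytorch_optimizer import FocalLoss

    # Assumption: FocalLoss() accepts default arguments and is called like a
    # regular nn.Module loss on raw logits and binary float targets.
    criterion = FocalLoss()

    logits = torch.randn(8, 1, requires_grad=True)  # hypothetical model outputs
    targets = torch.randint(0, 2, (8, 1)).float()   # hypothetical binary labels

    loss = criterion(logits, targets)
    loss.backward()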

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ Contents
     base_api
     optimizer_api
     scheduler_api
+    loss_api
     util_api
 
 Indices and tables

docs/loss_api.rst

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+Loss Functions
+==============
+
+.. _BCELoss:
+
+BCELoss
+-------
+
+.. autoclass:: pytorch_optimizer.BCELoss
+    :members:
+
+.. _BCEFocal:
+
+BCEFocal
+--------
+
+.. autoclass:: pytorch_optimizer.BCEFocal
+    :members:
+
+.. _FocalLoss:
+
+FocalLoss
+---------
+
+.. autoclass:: pytorch_optimizer.FocalLoss
+    :members:
+
+.. _FocalCosineLoss:
+
+FocalCosineLoss
+---------------
+
+.. autoclass:: pytorch_optimizer.FocalCosineLoss
+    :members:
+
+.. _SoftF1Loss:
+
+SoftF1Loss
+----------
+
+.. autoclass:: pytorch_optimizer.SoftF1Loss
+    :members:
+
+.. _DiceLoss:
+
+DiceLoss
+--------
+
+.. autoclass:: pytorch_optimizer.DiceLoss
+    :members:
+
+.. _LDAMLoss:
+
+LDAMLoss
+--------
+
+.. autoclass:: pytorch_optimizer.LDAMLoss
+    :members:
+
+.. _JaccardLoss:
+
+JaccardLoss
+-----------
+
+.. autoclass:: pytorch_optimizer.JaccardLoss
+    :members:
+
+.. _BiTemperedLogisticLoss:
+
+BiTemperedLogisticLoss
+----------------------
+
+.. autoclass:: pytorch_optimizer.BiTemperedLogisticLoss
+    :members:
+
+.. _BinaryBiTemperedLogisticLoss:
+
+BinaryBiTemperedLogisticLoss
+----------------------------
+
+.. autoclass:: pytorch_optimizer.BinaryBiTemperedLogisticLoss
+    :members:

pyproject.toml

Lines changed: 7 additions & 7 deletions
@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.10.1"
-description = "optimizer & lr scheduler collections in PyTorch"
+version = "2.11.0"
+description = "optimizer & lr scheduler & objective function collections in PyTorch"
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
 maintainers = ["kozistr <[email protected]>"]
@@ -16,7 +16,7 @@ keywords = [
     "DAdaptSGD", "DiffGrad", "Fromage", "Gravity", "GSAM", "LARS", "Lamb", "Lion", "LOMO", "Lookahead", "MADGRAD",
     "MSVAG", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger",
     "Ranger21", "RotoGrad", "SAM", "SGDP", "SGDW", "SignSGD", "SM3", "SopihaH", "SRMM", "SWATS", "ScalableShampoo",
-    "Shampoo", "Yogi",
+    "Shampoo", "Yogi", "BCE", "BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered",
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
@@ -94,6 +94,8 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
 target-version = "py311"
 
 [tool.ruff.per-file-ignores]
+"./pytorch_optimizer/__init__.py" = ["F401"]
+"./pytorch_optimizer/lr_scheduler/__init__.py" = ["F401"]
 "./hubconf.py" = ["D", "INP001"]
 "./tests/__init__.py" = ["D"]
 "./tests/constants.py" = ["D"]
@@ -104,13 +106,11 @@ target-version = "py311"
 "./tests/test_optimizers.py" = ["D", "S101"]
 "./tests/test_optimizer_parameters.py" = ["D", "S101"]
 "./tests/test_general_optimizer_parameters.py" = ["D", "S101"]
-"./tests/test_load_optimizers.py" = ["D", "S101"]
-"./tests/test_load_lr_schedulers.py" = ["D", "S101"]
 "./tests/test_lr_schedulers.py" = ["D", "S101"]
 "./tests/test_lr_scheduler_parameters.py" = ["D", "S101"]
 "./tests/test_create_optimizer.py" = ["D"]
-"./pytorch_optimizer/__init__.py" = ["F401"]
-"./pytorch_optimizer/lr_scheduler/__init__.py" = ["F401"]
+"./tests/test_loss_functions.py" = ["D", "S101"]
+"./tests/test_load_modules.py" = ["D", "S101"]
 
 [tool.pytest.ini_options]
 testpaths = "tests"

pytorch_optimizer/__init__.py

Lines changed: 27 additions & 0 deletions
@@ -4,6 +4,13 @@
 from torch import nn
 
 from pytorch_optimizer.base.types import OPTIMIZER, PARAMETERS, SCHEDULER
+from pytorch_optimizer.loss.bi_tempered import BinaryBiTemperedLogisticLoss, BiTemperedLogisticLoss
+from pytorch_optimizer.loss.cross_entropy import BCELoss
+from pytorch_optimizer.loss.dice import DiceLoss, soft_dice_score
+from pytorch_optimizer.loss.f1 import SoftF1Loss
+from pytorch_optimizer.loss.focal import BCEFocalLoss, FocalCosineLoss, FocalLoss
+from pytorch_optimizer.loss.jaccard import JaccardLoss, soft_jaccard_score
+from pytorch_optimizer.loss.ldam import LDAMLoss
 from pytorch_optimizer.lr_scheduler import (
     ConstantLR,
     CosineAnnealingLR,
@@ -177,6 +184,22 @@
     str(lr_scheduler.__name__).lower(): lr_scheduler for lr_scheduler in LR_SCHEDULER_LIST
 }
 
+LOSS_FUNCTION_LIST: List = [
+    BCELoss,
+    BCEFocalLoss,
+    FocalLoss,
+    SoftF1Loss,
+    DiceLoss,
+    LDAMLoss,
+    FocalCosineLoss,
+    JaccardLoss,
+    BiTemperedLogisticLoss,
+    BinaryBiTemperedLogisticLoss,
+]
+LOSS_FUNCTIONS: Dict[str, nn.Module] = {
+    str(loss_function.__name__).lower(): loss_function for loss_function in LOSS_FUNCTION_LIST
+}
+
 
 def load_optimizer(optimizer: str) -> OPTIMIZER:
     optimizer: str = optimizer.lower()
@@ -245,3 +268,7 @@ def get_supported_optimizers() -> List[OPTIMIZER]:
 
 def get_supported_lr_schedulers() -> List[SCHEDULER]:
     return LR_SCHEDULER_LIST
+
+
+def get_supported_loss_functions() -> List[nn.Module]:
+    return LOSS_FUNCTION_LIST
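
The new LOSS_FUNCTIONS mapping mirrors the existing optimizer and lr-scheduler registries (lower-cased class name to class), but this commit does not add a load-by-name helper for it. Below is a sketch of how such a lookup could be written on top of the exported mapping; load_loss_function is a hypothetical helper modelled on load_optimizer, not part of the library:

    from typing import Type

    from torch import nn

    from pytorch_optimizer import LOSS_FUNCTIONS  # e.g. {'bceloss': BCELoss, 'focalloss': FocalLoss, ...}


    def load_loss_function(name: str) -> Type[nn.Module]:
        # Hypothetical helper: keys are lower-cased class names, exactly as built
        # in pytorch_optimizer/__init__.py above; the value is the loss class itself.
        name = name.lower()
        if name not in LOSS_FUNCTIONS:
            raise ValueError(f'unknown loss function: {name}')
        return LOSS_FUNCTIONS[name]


    focal_loss_cls = load_loss_function('FocalLoss')  # returns the class, not an instance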

pytorch_optimizer/base/types.py

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@
 SCHEDULER = Type[_LRScheduler]
 
 HUTCHINSON_G = Literal['gaussian', 'rademacher']
+CLASS_MODE = Literal['binary', 'multiclass', 'multilabel']
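
CLASS_MODE is a typing.Literal alias, so static type checkers restrict it to the three listed strings. A small sketch of how such an alias is typically consumed; describe_mode is illustrative only, not a function in the library:

    from typing import Literal

    # Mirrors the alias added in pytorch_optimizer/base/types.py.
    CLASS_MODE = Literal['binary', 'multiclass', 'multilabel']


    def describe_mode(mode: CLASS_MODE) -> str:
        # A checker such as mypy rejects describe_mode('softmax'), since 'softmax'
        # is not one of the allowed literal values; at runtime it is a plain str.
        return f'computing the loss in {mode} mode'


    print(describe_mode('multiclass'))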
