Skip to content

Commit 20f3779

Browse files
authored
Avoid nan loss when there are labels with no samples in the training data. (#12)
1 parent 1c71f91 commit 20f3779

File tree

7 files changed

+73
-46
lines changed

7 files changed

+73
-46
lines changed

.github/workflows/ci.yml

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,39 +11,39 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
operating-system: [ubuntu-latest, windows-latest, macos-latest]
15-
python-version: [3.7, 3.8, 3.9]
16-
torch-version: [1.10.2, 1.11.0, 1.12.0]
14+
os: [ubuntu-latest, windows-latest, macos-latest]
15+
python-version: [3.9, "3.10"]
16+
torch-version: [1.13.1, 2.5.1]
1717
fail-fast: false
1818

1919
steps:
2020
- name: Checkout
21-
uses: actions/checkout@v2
21+
uses: actions/checkout@v4
2222

2323
- name: Set up Python
24-
uses: actions/setup-python@v2
24+
uses: actions/setup-python@v5
2525
with:
2626
python-version: ${{ matrix.python-version }}
2727

2828
- name: Restore Ubuntu cache
29-
uses: actions/cache@v1
30-
if: matrix.operating-system == 'ubuntu-latest'
29+
uses: actions/cache@v4
30+
if: matrix.os == 'ubuntu-latest'
3131
with:
3232
path: ~/.cache/pip
3333
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
3434
restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-
3535

3636
- name: Restore MacOS cache
37-
uses: actions/cache@v1
38-
if: matrix.operating-system == 'macos-latest'
37+
uses: actions/cache@v4
38+
if: matrix.os == 'macos-latest'
3939
with:
4040
path: ~/Library/Caches/pip
4141
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
4242
restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-
4343

4444
- name: Restore Windows cache
45-
uses: actions/cache@v1
46-
if: matrix.operating-system == 'windows-latest'
45+
uses: actions/cache@v4
46+
if: matrix.os == 'windows-latest'
4747
with:
4848
path: ~\AppData\Local\pip\Cache
4949
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
@@ -52,9 +52,14 @@ jobs:
5252
- name: Update pip
5353
run: python -m pip install --upgrade pip
5454

55+
- name: Install package in development mode
56+
run: pip install -e .[dev]
57+
58+
- name: Show installed packages
59+
run: pip list
60+
5561
- name: Lint with flake8, black and isort
5662
run: |
57-
pip install -e .[dev]
5863
# stop the build if there are Python syntax errors or undefined names
5964
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
6065
black . --check --config pyproject.toml
@@ -66,17 +71,23 @@ jobs:
6671
run: >
6772
pip install numpy
6873
69-
- name: Install PyTorch on Linux and Windows
74+
- name: Install PyTorch==1.13.1 on Linux and Windows
7075
if: >
71-
matrix.operating-system == 'ubuntu-latest' ||
72-
matrix.operating-system == 'windows-latest'
76+
(matrix.os == 'ubuntu-latest' ||
77+
matrix.os == 'windows-latest') &&
78+
matrix.torch-version == '1.13.1'
7379
run: >
7480
pip install torch==${{ matrix.torch-version }}+cpu
7581
-f https://download.pytorch.org/whl/torch_stable.html
7682
77-
- name: Install PyTorch on MacOS
78-
if: matrix.operating-system == 'macos-latest'
79-
run: pip install torch==${{ matrix.torch-version }}
83+
- name: Install PyTorch==2.5.1 on Linux and Windows
84+
if: >
85+
(matrix.os == 'ubuntu-latest' ||
86+
matrix.os == 'windows-latest') &&
87+
matrix.torch-version == '2.5.1'
88+
run: >
89+
pip install torch==${{ matrix.torch-version }}
90+
-f https://download.pytorch.org/whl/torch_stable.html
8091
8192
- name: Install balanced-loss package from local setup.py
8293
run: >

.github/workflows/package_testing.yml

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,39 +10,39 @@ jobs:
1010

1111
strategy:
1212
matrix:
13-
operating-system: [ubuntu-latest, windows-latest, macos-latest]
14-
python-version: [3.7, 3.8, 3.9]
15-
torch-version: [1.10.2, 1.11.0, 1.12.0]
13+
os: [ubuntu-latest, windows-latest, macos-latest]
14+
python-version: [3.9, "3.10"]
15+
torch-version: [1.13.1, 2.5.1]
1616
fail-fast: false
1717

1818
steps:
1919
- name: Checkout
20-
uses: actions/checkout@v2
20+
uses: actions/checkout@v4
2121

2222
- name: Set up Python
23-
uses: actions/setup-python@v2
23+
uses: actions/setup-python@v5
2424
with:
2525
python-version: ${{ matrix.python-version }}
2626

2727
- name: Restore Ubuntu cache
28-
uses: actions/cache@v1
29-
if: matrix.operating-system == 'ubuntu-latest'
28+
uses: actions/cache@v4
29+
if: matrix.os == 'ubuntu-latest'
3030
with:
3131
path: ~/.cache/pip
3232
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
3333
restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-
3434

3535
- name: Restore MacOS cache
36-
uses: actions/cache@v1
37-
if: matrix.operating-system == 'macos-latest'
36+
uses: actions/cache@v4
37+
if: matrix.os == 'macos-latest'
3838
with:
3939
path: ~/Library/Caches/pip
4040
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
4141
restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-
4242

4343
- name: Restore Windows cache
44-
uses: actions/cache@v1
45-
if: matrix.operating-system == 'windows-latest'
44+
uses: actions/cache@v4
45+
if: matrix.os == 'windows-latest'
4646
with:
4747
path: ~\AppData\Local\pip\Cache
4848
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
@@ -55,16 +55,26 @@ jobs:
5555
run: >
5656
pip install numpy
5757
58-
- name: Install PyTorch on Linux and Windows
58+
- name: Install PyTorch==1.13.1 on Linux and Windows
5959
if: >
60-
matrix.operating-system == 'ubuntu-latest' ||
61-
matrix.operating-system == 'windows-latest'
60+
(matrix.os == 'ubuntu-latest' ||
61+
matrix.os == 'windows-latest') &&
62+
matrix.torch-version == '1.13.1'
6263
run: >
6364
pip install torch==${{ matrix.torch-version }}+cpu
6465
-f https://download.pytorch.org/whl/torch_stable.html
6566
67+
- name: Install PyTorch==2.5.1 on Linux and Windows
68+
if: >
69+
(matrix.os == 'ubuntu-latest' ||
70+
matrix.os == 'windows-latest') &&
71+
matrix.torch-version == '2.5.1'
72+
run: >
73+
pip install torch==${{ matrix.torch-version }}
74+
-f https://download.pytorch.org/whl/torch_stable.html
75+
6676
- name: Install PyTorch on MacOS
67-
if: matrix.operating-system == 'macos-latest'
77+
if: matrix.os == 'macos-latest'
6878
run: pip install torch==${{ matrix.torch-version }}
6979

7080
- name: Install latest balanced-loss package
@@ -74,4 +84,3 @@ jobs:
7484
- name: Unittest balanced-loss
7585
run: |
7686
python -m unittest
77-

.github/workflows/publish_pypi.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ jobs:
99
runs-on: ubuntu-latest
1010

1111
steps:
12-
- uses: actions/checkout@v2
12+
- uses: actions/checkout@v4
1313
- name: Set up Python
14-
uses: actions/setup-python@v2
14+
uses: actions/setup-python@v5
1515
with:
1616
python-version: '3.x'
1717
- name: Install dependencies

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ When training dataset labels are imbalanced, one thing to do is to balance the l
2424

2525
![alt-text](https://user-images.githubusercontent.com/34196005/180266198-e27d8cba-f5e1-49ca-9f82-d8656333e3c4.png)
2626

27-
2827
## Installation
2928

3029
```bash
@@ -134,6 +133,7 @@ What is the difference between this repo and vandit15's?
134133
- This repo implements loss functions as `torch.nn.Module`
135134
- In addition to class balanced losses, this repo also supports the standard versions of the cross entropy/focal loss etc. over the same API
136135
- All typos and errors in vandit15's source are fixed
136+
- Continiously tested on PyTorch 1.13.1 and 2.5.1
137137

138138
## References
139139

balanced_loss/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .losses import Loss
22

3-
__version__ = "0.1.0"
3+
__version__ = "0.1.1"

balanced_loss/losses.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
fl_gamma=2,
4545
samples_per_class=None,
4646
class_balanced=False,
47+
safe: bool = False,
4748
):
4849
"""
4950
Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
@@ -60,6 +61,7 @@ def __init__(
6061
samples_per_class: A python list of size [num_classes].
6162
Required if class_balance is True.
6263
class_balanced: bool. Whether to use class balanced loss.
64+
safe: bool. Whether to allow labels with no samples.
6365
Returns:
6466
Loss instance
6567
"""
@@ -73,12 +75,9 @@ def __init__(
7375
self.fl_gamma = fl_gamma
7476
self.samples_per_class = samples_per_class
7577
self.class_balanced = class_balanced
78+
self.safe = safe
7679

77-
def forward(
78-
self,
79-
logits: torch.tensor,
80-
labels: torch.tensor,
81-
):
80+
def forward(self, logits: torch.tensor, labels: torch.tensor):
8281
"""
8382
Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
8483
Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
@@ -97,8 +96,16 @@ def forward(
9796

9897
if self.class_balanced:
9998
effective_num = 1.0 - np.power(self.beta, self.samples_per_class)
99+
# Avoid division by 0 error for test cases without all labels present.
100+
if self.safe:
101+
effective_num_classes = np.sum(effective_num != 0)
102+
effective_num[effective_num == 0] = np.inf
103+
104+
else:
105+
effective_num_classes = num_classes
106+
100107
weights = (1.0 - self.beta) / np.array(effective_num)
101-
weights = weights / np.sum(weights) * num_classes
108+
weights = weights / np.sum(weights) * effective_num_classes
102109
weights = torch.tensor(weights, device=logits.device).float()
103110

104111
if self.loss_type != "cross_entropy":

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def get_version():
3737
setuptools.setup(
3838
name="balanced-loss",
3939
version=get_version(),
40-
author="",
40+
author="fcakyon",
4141
license="MIT",
4242
description="Easy to use class-balanced cross-entropy and focal loss implementation for Pytorch.",
4343
long_description=get_long_description(),
@@ -54,9 +54,9 @@ def get_version():
5454
"Intended Audience :: Developers",
5555
"Intended Audience :: Science/Research",
5656
"Programming Language :: Python :: 3",
57-
"Programming Language :: Python :: 3.7",
5857
"Programming Language :: Python :: 3.8",
5958
"Programming Language :: Python :: 3.9",
59+
"Programming Language :: Python :: 3.10",
6060
"Topic :: Software Development :: Libraries",
6161
"Topic :: Software Development :: Libraries :: Python Modules",
6262
"Topic :: Education",

0 commit comments

Comments
 (0)