
Commit 3a4d3ae

update imbalanced dataset comment
1 parent 0261a68 commit 3a4d3ae

File tree: 1 file changed (+24, −9 lines)

  • ML/Pytorch/Basics/Imbalanced_classes


ML/Pytorch/Basics/Imbalanced_classes/main.py

@@ -1,13 +1,28 @@
 """
-This code is for dealing with imbalanced datasets in PyTorch. Imbalanced datasets are those where the number of samples in one or more classes is significantly lower than the number of samples in the other classes. This can be a problem because it can lead to a model that is biased towards the more common classes, which can result in poor performance on the less common classes.
-
-To deal with imbalanced datasets, this code implements two methods: oversampling and class weighting.
-
-Oversampling involves generating additional samples for the underrepresented classes, while class weighting involves assigning higher weights to the loss of samples in the underrepresented classes, so that the model pays more attention to them.
-
-In this code, the get_loader function takes a root directory for a dataset and a batch size, and returns a PyTorch data loader. The data loader is used to iterate over the dataset in batches. The get_loader function first applies some transformations to the images in the dataset using the transforms module from torchvision. Then it calculates the class weights based on the number of samples in each class. It then creates a WeightedRandomSampler object, which is used to randomly select a batch of samples with a probability proportional to their weights. Finally, it creates the data loader using the dataset and the weighted random sampler.
-
-The main function then uses the data loader to iterate over the dataset for 10 epochs, and counts the number of samples in each class. Finally, it prints the counts for each class.
+This code is for dealing with imbalanced datasets in PyTorch. Imbalanced datasets
+are those where the number of samples in one or more classes is significantly lower
+than the number of samples in the other classes. This can be a problem because it
+can lead to a model that is biased towards the more common classes, which can result
+in poor performance on the less common classes.
+
+To deal with imbalanced datasets, this code implements two methods: oversampling and
+class weighting.
+
+Oversampling involves generating additional samples for the underrepresented classes,
+while class weighting involves assigning higher weights to the loss of samples in the
+underrepresented classes, so that the model pays more attention to them.
+
+In this code, the get_loader function takes a root directory for a dataset and a batch
+size, and returns a PyTorch data loader. The data loader is used to iterate over the
+dataset in batches. The get_loader function first applies some transformations to the
+images in the dataset using the transforms module from torchvision. Then it calculates
+the class weights based on the number of samples in each class. It then creates a
+WeightedRandomSampler object, which is used to randomly select a batch of samples with a
+probability proportional to their weights. Finally, it creates the data loader using the
+dataset and the weighted random sampler.
+
+The main function then uses the data loader to iterate over the dataset for 10 epochs,
+and counts the number of samples in each class. Finally, it prints the counts for each class.
 
 Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
 
 * 2020-04-08: Initial coding

0 commit comments
