Best practice: where to count number of samples per class #15199
Unanswered
mfoglio asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 0 comments
Hello, I am looking for best practices, as I know there could be multiple ways to solve this problem. I would just like to understand whether there is a Lightning approach that should be preferred.

In my `LightningModule` I initialize a `CrossEntropyLoss` with a specific `weight` to handle imbalanced classes: `torch.nn.CrossEntropyLoss(weight=my_weights)`. The weight for each class is defined as `1 / number_of_samples_in_the_class`.

In order to do this, I need to supply my `LightningModule` instance with the number of samples per class. However, you would usually load the data (and therefore count the number of samples per class) in the `setup` function of the `LightningDataModule` instance. So here's the problem: usually, when you initialize the `LightningModule`, you haven't loaded the data yet.

Example:

As possible solutions, I could manually call `my_data_module.setup()`, or I could compute the number of samples inside the `__init__` function of `MyDataModule`, but neither way seems to follow the Lightning philosophy. What would be the cleanest way to solve this?

Thank you!
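For context, the weighting scheme described in the question can be sketched in plain PyTorch (the Lightning classes are omitted for brevity; the `labels` tensor below is a hypothetical imbalanced dataset, not from the original post):

```python
import torch

# Hypothetical labels for an imbalanced 3-class dataset (class 0 dominates).
labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 2])

# Count samples per class, then weight each class by 1 / count,
# as the question describes.
counts = torch.bincount(labels, minlength=3)  # -> tensor([6, 2, 1])
weights = 1.0 / counts.float()

criterion = torch.nn.CrossEntropyLoss(weight=weights)

# Mistakes on rare classes now contribute proportionally more to the loss.
logits = torch.zeros(len(labels), 3)  # uniform logits, for illustration only
loss = criterion(logits, labels)
```

The tension the question raises is that `counts` is only available after the data has been loaded, which in Lightning normally happens in the `LightningDataModule`'s `setup` hook, after the `LightningModule` has already been constructed.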