Status: Closed
Labels: bug, needs triage, ver: 2.5.x
Description
Bug description
Setting the dropout probability doesn't work as expected in a LightningModule that wraps a pretrained Hugging Face transformer (BertForSequenceClassification).
Expected result: Setting the dropout probability to 1 should zero out the hidden activations and prevent the model from learning.
Actual result: The model keeps learning, i.e. the training loss decreases every epoch.
Why it looks like a bug in Lightning:
- After running the PyTorch Lightning code, printing the model shows `Dropout(p=1, inplace=False)`, so the config setting is applied correctly.
- The equivalent plain PyTorch code produces the expected result: the loss doesn't change from epoch to epoch.
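Printing a module shows the configured `p`, but not whether the module is in training mode, and `torch.nn.Dropout` only drops activations when its `training` flag is `True`; in evaluation mode it is the identity regardless of `p`. One difference between the two scripts that may be worth checking: the plain PyTorch script calls `model.train()` explicitly, while the Lightning script does not, and the Hugging Face `from_pretrained` documentation states that the returned model is put in evaluation mode. The minimal sketch below uses a toy `nn.Sequential` stand-in (not BERT) and a hypothetical helper, `dropout_modes`, to show both behaviours of `Dropout(p=1)` and how to inspect the flag.

```python
import torch
from torch import nn


def dropout_modes(model: nn.Module) -> dict[str, bool]:
    """Hypothetical helper: map each Dropout submodule's name to its `training` flag."""
    return {
        name: module.training
        for name, module in model.named_modules()
        if isinstance(module, nn.Dropout)
    }


# Toy stand-in for a classifier head (not BERT): Linear -> Dropout(p=1) -> Linear.
toy = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=1.0), nn.Linear(4, 2))
x = torch.randn(3, 4)

toy.train()
print(dropout_modes(toy))  # {'1': True}
# In training mode, p=1 zeroes every activation, so the last layer only sees zeros.
hidden_train = toy[1](toy[0](x))

toy.eval()
print(dropout_modes(toy))  # {'1': False}
# In evaluation mode, Dropout passes its input through unchanged, whatever p is.
hidden_eval = toy[1](toy[0](x))
```

Under the same assumptions, calling `dropout_modes(self.model)` from `training_step` in the Lightning script (for example on the first batch) would show whether BERT's dropout layers are actually active during `trainer.fit`.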
What version are you seeing the problem on?
v2.5
How to reproduce the bug
PyTorch Lightning code:
```python
import torch
from pytorch_lightning import LightningModule, Trainer
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification
from torch.utils.data import DataLoader, TensorDataset


class BertClassifier(LightningModule):
    def __init__(self, dropout_prob=1):
        super().__init__()
        self.config = BertConfig.from_pretrained('bert-base-uncased')
        self.config.hidden_dropout_prob = dropout_prob
        self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased', config=self.config)

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        # Unpack the batch into proper format for the model
        input_ids, attention_mask, labels = batch
        # Create the inputs dictionary expected by the BERT model
        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
        outputs = self.model(**inputs)
        loss = outputs.loss
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-5)


# Training data
texts = ["red", "green", "blue", "hot", "warm", "cold"]
labels = [0, 0, 0, 1, 1, 1]

# Prepare data
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encoded_inputs = tokenizer(texts, return_tensors="pt", padding=True)
input_ids = encoded_inputs['input_ids']
attention_mask = encoded_inputs['attention_mask']
labels_tensor = torch.tensor(labels)

# Create dataset and dataloader
dataset = TensorDataset(input_ids, attention_mask, labels_tensor)
dataloader = DataLoader(dataset, batch_size=6)

# Training
model = BertClassifier(dropout_prob=1)
trainer = Trainer(max_epochs=5)
trainer.fit(model, dataloader)
```
PyTorch code:
```python
import torch
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification

# Training data
texts = ["red", "green", "blue", "hot", "warm", "cold"]
labels = [0, 0, 0, 1, 1, 1]

# PyTorch implementation
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = BertConfig.from_pretrained('bert-base-uncased')
config.hidden_dropout_prob = 1
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', config=config)

inputs = tokenizer(texts, return_tensors="pt")
inputs["labels"] = torch.tensor(labels)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model.train()

for epoch in range(5):
    outputs = model(**inputs)
    loss = outputs.loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")
```
Error messages and logs
Console output (progress bar) from PyTorch Lightning:
Epoch 0: 100%|████████| 1/1 [00:00<00:00, 1.44it/s, v_num=8, train_loss=0.713]
Epoch 1: 100%|████████| 1/1 [00:00<00:00, 1.94it/s, v_num=8, train_loss=0.618]
Epoch 2: 100%|████████| 1/1 [00:00<00:00, 2.04it/s, v_num=8, train_loss=0.581]
Epoch 3: 100%|████████| 1/1 [00:00<00:00, 1.91it/s, v_num=8, train_loss=0.556]
Epoch 4: 100%|████████| 1/1 [00:00<00:00, 2.03it/s, v_num=8, train_loss=0.528]
Console output from PyTorch:
Epoch: 1, Loss: 0.6931471824645996
Epoch: 2, Loss: 0.6931471824645996
Epoch: 3, Loss: 0.6931471824645996
Epoch: 4, Loss: 0.6931471824645996
Epoch: 5, Loss: 0.6931471824645996
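The constant loss in the plain PyTorch run can be checked by hand. Assuming that `Dropout(p=1)` in training mode zeroes the classifier's input, every example gets the same logits (the classifier bias), and if both bias entries are equal (assumed 0.0 here) the softmax is [0.5, 0.5], so the mean cross-entropy is ln 2. With three labels of each class, the bias gradient averages to zero, and because dropout multiplies the activations by zero, no gradient reaches the earlier layers either, so nothing changes between epochs. This stdlib-only sketch (with a hypothetical `softmax` helper) reproduces that arithmetic and shows that ln 2 rounded to float32 is exactly the value printed above.

```python
import math
import struct

# Labels from the reproduction script: three examples of class 0, three of class 1.
labels = [0, 0, 0, 1, 1, 1]

# Assumption: with p=1 dropout active, the classifier sees an all-zero input, so every
# example's logits equal the classifier bias; both bias entries are assumed to be 0.0.
logits = [0.0, 0.0]


def softmax(z):
    """Numerically stable softmax over a list of floats (hypothetical helper)."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)  # [0.5, 0.5], identical for every example

# Mean cross-entropy over the batch: average of -log p(correct class).
loss = sum(-math.log(probs[y]) for y in labels) / len(labels)

# Gradient of the mean cross-entropy w.r.t. the logits, i.e. w.r.t. the classifier bias:
# the batch average of softmax(logits) - one_hot(label).
bias_grad = [
    sum(probs[k] - (1.0 if y == k else 0.0) for y in labels) / len(labels)
    for k in range(len(logits))
]

# ln 2 rounded to float32 (the dtype of the model's loss) and read back as a Python float.
float32_ln2 = struct.unpack("f", struct.pack("f", math.log(2)))[0]

print(loss)         # ln 2, up to floating-point rounding
print(bias_grad)    # [0.0, 0.0]
print(float32_ln2)  # 0.6931471824645996
```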
Environment
Current environment
- PyTorch Lightning Version: 2.5.0.post0
- PyTorch Version: 2.6.0+cpu
- Python version: 3.12.9
- OS: Linux
- CUDA/cuDNN version: NA
- GPU models and configuration: NA
- How you installed Lightning: pip