
Model with dropout probability=1 keeps learning (PyTorch version doesn't) #20646

@jedick


Bug description

Setting the dropout probability doesn't take effect as expected in a model that wraps a pretrained transformer from Hugging Face (BertForSequenceClassification).

Expected result: Setting the dropout probability to 1 should prevent the model from learning.

Actual result: The model keeps learning, i.e., the loss decreases each epoch.
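For reference, torch.nn.Dropout with p=1 zeroes every element of its input while the module is in training mode (and is a no-op in eval mode), so a layer with dropout probability 1 should pass no signal forward. A minimal sketch:

import torch
import torch.nn as nn

drop = nn.Dropout(p=1)
drop.train()          # dropout is only active in training mode
x = torch.ones(2, 3)
print(drop(x))        # all zeros: every element is dropped
drop.eval()
print(drop(x))        # unchanged: dropout is a no-op in eval mode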

Why it looks like a bug in Lightning:

  • After running the PyTorch Lightning code, printing the model shows Dropout(p=1, inplace=False), so the config setting is applied correctly; a quick way to verify this is sketched after this list.
  • The PyTorch code produces the expected result, where the loss doesn't change from epoch to epoch.
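A minimal way to run that verification is to enumerate the Dropout submodules after applying the same config override used in the reproduction below:

import torch
from transformers import BertConfig, BertForSequenceClassification

config = BertConfig.from_pretrained('bert-base-uncased')
config.hidden_dropout_prob = 1
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', config=config)

# Dropout modules driven by hidden_dropout_prob should report p=1;
# attention dropout (attention_probs_dropout_prob) keeps its default of 0.1.
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Dropout):
        print(name, module.p)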

What version are you seeing the problem on?

v2.5

How to reproduce the bug

PyTorch Lightning code:

import torch
from pytorch_lightning import LightningModule, Trainer
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification
from torch.utils.data import DataLoader, TensorDataset

class BertClassifier(LightningModule):
    def __init__(self, dropout_prob=1):
        super().__init__()
        self.config = BertConfig.from_pretrained('bert-base-uncased')
        self.config.hidden_dropout_prob = dropout_prob
        self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased', config=self.config)
    
    def forward(self, **inputs):
        return self.model(**inputs)
    
    def training_step(self, batch, batch_idx):
        # Unpack the batch into proper format for the model
        input_ids, attention_mask, labels = batch
        
        # Create the inputs dictionary expected by the BERT model
        inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
        
        outputs = self.model(**inputs)
        loss = outputs.loss
        self.log('train_loss', loss, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-5)

# Training data
texts = ["red", "green", "blue", "hot", "warm", "cold"]
labels = [0, 0, 0, 1, 1, 1]

# Prepare data
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encoded_inputs = tokenizer(texts, return_tensors="pt", padding=True)
input_ids = encoded_inputs['input_ids']
attention_mask = encoded_inputs['attention_mask']
labels_tensor = torch.tensor(labels)

# Create dataset and dataloader
dataset = TensorDataset(input_ids, attention_mask, labels_tensor)
dataloader = DataLoader(dataset, batch_size=6)

# Training
model = BertClassifier(dropout_prob=1)
trainer = Trainer(max_epochs=5)
trainer.fit(model, dataloader)
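One variable worth ruling out when comparing the two scripts is the train/eval mode of the wrapped Hugging Face module: Dropout is a no-op in eval mode, so a submodule left in eval mode during training would keep learning regardless of the configured probability. A hypothetical diagnostic (not part of the original report) is two lines at the top of training_step:

# Dropout is a no-op in eval mode, so a module stuck in eval mode
# would make p=1 have no effect.
assert self.model.training, "wrapped BERT module is unexpectedly in eval mode"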

PyTorch code:

import torch
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification

# Training data
texts = ["red", "green", "blue", "hot", "warm", "cold"]
labels = [0, 0, 0, 1, 1, 1]

# PyTorch implementation
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = BertConfig.from_pretrained('bert-base-uncased')
config.hidden_dropout_prob = 1
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', config=config)
inputs = tokenizer(texts, return_tensors="pt")
inputs["labels"] = torch.tensor(labels)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model.train()
for epoch in range(5):
    outputs = model(**inputs)
    loss = outputs.loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

Error messages and logs

Console output (progress bar) from PyTorch Lightning:

Epoch 0: 100%|████████| 1/1 [00:00<00:00,  1.44it/s, v_num=8, train_loss=0.713]
Epoch 1: 100%|████████| 1/1 [00:00<00:00,  1.94it/s, v_num=8, train_loss=0.618]
Epoch 2: 100%|████████| 1/1 [00:00<00:00,  2.04it/s, v_num=8, train_loss=0.581]
Epoch 3: 100%|████████| 1/1 [00:00<00:00,  1.91it/s, v_num=8, train_loss=0.556]
Epoch 4: 100%|████████| 1/1 [00:00<00:00,  2.03it/s, v_num=8, train_loss=0.528]

Console output from PyTorch:

Epoch: 1, Loss: 0.6931471824645996
Epoch: 2, Loss: 0.6931471824645996
Epoch: 3, Loss: 0.6931471824645996
Epoch: 4, Loss: 0.6931471824645996
Epoch: 5, Loss: 0.6931471824645996
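The constant value in the PyTorch output is not arbitrary: 0.6931471824645996 is ln 2 in float32, the cross-entropy a two-class classifier incurs whenever its two logits are equal, which is consistent with the classifier head receiving no usable signal. A quick check:

import math
import torch
import torch.nn.functional as F

print(math.log(2))                        # 0.6931471805599453
logits = torch.zeros(6, 2)                # equal logits for all 6 examples
labels = torch.tensor([0, 0, 0, 1, 1, 1])
print(F.cross_entropy(logits, labels))    # tensor(0.6931)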

Environment

  • PyTorch Lightning Version: 2.5.0.post0
  • PyTorch Version: 2.6.0+cpu
  • Python version: 3.12.9
  • OS: Linux
  • CUDA/cuDNN version: NA
  • GPU models and configuration: NA
  • How you installed Lightning: pip

More info

No response
