
_update_dataloader improperly copies state of subclassed dataloader with attribute names that differ from __init__ parameters. #20265

@spenceforce


Bug description

_update_dataloader in lightning/fabric/utilities/data.py does not copy a dataloader; instead, it infers the arguments that were passed to the dataloader and instantiates a new one. The inference in _get_dataloader_init_args_and_kwargs does not account for custom logic in __init__ that stores arguments under attribute names differing from the parameter names, so the new dataloader ends up with default arguments instead of those passed to the original dataloader.

This problem arises during distributed training when a new dataloader is created.

In the code example, the dataloader takes an argument x=2 and assigns it to the attribute self._x. When a new dataloader is created, _get_dataloader_init_args_and_kwargs infers x as its default value instead of 2. The log below shows the output of print statements in __init__ and __iter__: when the original dataloader is created, self._x is 2, but in the re-created dataloader self._x is the default value None. The relevant print lines in the log are surrounded with **.
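
For illustration, here is a minimal standalone sketch of the failure mode as I understand it (the helper infer_init_kwargs is my own stand-in, not Lightning's actual implementation): if re-instantiation looks for attributes whose names match the __init__ parameters, a value stored under a different name silently falls back to the parameter's default.

import inspect


class Original:
    def __init__(self, x=None):
        # The value is stored under a name that differs from the parameter name.
        self._x = x


def infer_init_kwargs(obj):
    # Toy stand-in for attribute-based argument inference (my naming, not Lightning's):
    # for every __init__ parameter, reuse an attribute of the same name if one exists,
    # otherwise fall back to the parameter's default.
    params = inspect.signature(type(obj).__init__).parameters
    return {
        name: getattr(obj, name, param.default)
        for name, param in params.items()
        if name != "self"
    }


orig = Original(x=2)
print(infer_init_kwargs(orig))  # {'x': None} -- the passed value 2 is not found
rebuilt = Original(**infer_init_kwargs(orig))
print(rebuilt._x)  # None instead of 2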

I assume there's a reason why deepcopy or some other copying mechanism isn't used?

Expected behavior

The state of the dataloader is copied as-is.

Actual behavior

The dataloader is re-initialized with the wrong arguments.
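
For what it's worth, a possible workaround (assuming the inference matches attributes by parameter name, as described above) is to also store the value under an attribute whose name matches the __init__ parameter:

import torch


class DummyDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, x=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Workaround (my assumption, based on the behavior described above): keep an
        # attribute named exactly like the parameter so the argument inference can find
        # the value when the dataloader is re-created.
        self.x = x

    def __iter__(self):
        print("DummyDataLoader.__iter__(): self.x is", self.x)
        yield from super().__iter__()

This only sidesteps the problem, though; it would be nice if the state were preserved regardless of how __init__ stores its arguments.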

Thank you for this awesome software!

What version are you seeing the problem on?

v2.3, v2.4

How to reproduce the bug

import lightning as L
import torch
import torch.nn.functional as F


class DummyDataLoader(torch.utils.data.DataLoader):

    def __init__(self, *args, x=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._x = x
        print("DummyDataLoader.__init__(): self._x is", self._x)

    def __iter__(self):
        print("DummyDataLoader.__iter__(): self._x is", self._x)

        for item in super().__iter__():
            yield item


class MyLinear(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1, device="cuda")

    def forward(self, X):
        return self.linear(X)


class MyLightning(L.LightningModule):

    def __init__(self):
        super().__init__()
        self.linear = MyLinear()

    def forward(self, X):
        return self.linear(X)

    def training_step(self, batch, batch_idx):
        X, y = batch
        y_hat = self(X)
        loss = F.mse_loss(y, y_hat)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


def train_func():
    X = y = torch.arange(4, device="cuda", dtype=torch.float32).view(-1, 1)
    train_data = DummyDataLoader(list(zip(X, y)), batch_size=4, x=2)

    model = MyLightning()
    model.to("cuda")
    trainer = L.Trainer(
        max_epochs=10,
        devices=1,
        num_nodes=1,
        accelerator="gpu",
        use_distributed_sampler=True,
        strategy="ddp",
        enable_checkpointing=False
    )
    trainer.fit(model, train_data)
    return


train_func()

Error messages and logs

**DummyDataLoader.__init__(): self._x is 2**
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/[email protected]/.anaconda3/envs/lightning_error2/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type     | Params | Mode
--------------------------------------------
0 | linear | MyLinear | 2      | train
--------------------------------------------
2         Trainable params
0         Non-trainable params
2         Total params
0.000     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode
**DummyDataLoader.__init__(): self._x is None**
/home/[email protected]/.anaconda3/envs/lightning_error2/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/[email protected]/.anaconda3/envs/lightning_error2/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 0:   0%|                                                                                                                                                                                 | 0/1 [00:00<?, ?it/s]
**DummyDataLoader.__iter__(): self._x is None**
Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 17.08it/s, v_num=15]
`Trainer.fit` stopped: `max_epochs=1` reached.
Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 17.02it/s, v_num=15]

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA GeForce RTX 3090
    • available: True
    • version: 11.8
  • Lightning:
    • lightning: 2.4.0
    • lightning-utilities: 0.11.7
    • pytorch-lightning: 2.4.0
    • torch: 2.4.1
    • torchmetrics: 1.4.0.post0
  • Packages:
    • autocommand: 2.2.2
    • backports.tarfile: 1.2.0
    • brotli: 1.0.9
    • certifi: 2024.8.30
    • cffi: 1.16.0
    • charset-normalizer: 3.3.2
    • colorama: 0.4.6
    • filelock: 3.13.1
    • fsspec: 2024.9.0
    • h2: 4.1.0
    • hpack: 4.0.0
    • hyperframe: 6.0.1
    • idna: 3.8
    • importlib-metadata: 8.0.0
    • importlib-resources: 6.4.0
    • inflect: 7.3.1
    • jaraco.context: 5.3.0
    • jaraco.functools: 4.0.1
    • jaraco.text: 3.12.1
    • jinja2: 3.1.4
    • lightning: 2.4.0
    • lightning-utilities: 0.11.7
    • markupsafe: 2.1.3
    • mkl-fft: 1.3.10
    • mkl-random: 1.2.7
    • mkl-service: 2.4.0
    • more-itertools: 10.3.0
    • mpmath: 1.3.0
    • networkx: 3.3
    • numpy: 1.26.4
    • ordered-set: 4.1.0
    • packaging: 24.1
    • pip: 24.2
    • platformdirs: 4.2.2
    • pycparser: 2.22
    • pysocks: 1.7.1
    • pytorch-lightning: 2.4.0
    • pyyaml: 6.0.1
    • requests: 2.32.3
    • setuptools: 72.1.0
    • sympy: 1.13.2
    • tomli: 2.0.1
    • torch: 2.4.1
    • torchmetrics: 1.4.0.post0
    • tqdm: 4.66.5
    • triton: 3.0.0
    • typeguard: 4.3.0
    • typing-extensions: 4.11.0
    • urllib3: 2.2.2
    • wheel: 0.43.0
    • zipp: 3.19.2
    • zstandard: 0.22.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.12.4
    • release: 4.18.0-553.16.1.el8_10.x86_64
    • version: #1 SMP Thu Aug 1 04:16:12 EDT 2024

More info

No response

cc @tchaton
