
Pickling error when using CUDA in collate_fn #20469

@Richienb

Description


Bug description

  • When I use CUDA inside the DataLoader's collate_fn to pre-process generated data in bulk, with num_workers > 0,
  • I have to use the ddp_spawn strategy in the Trainer
  • Then, I get this error:
Traceback (most recent call last):
  File "/home/myuser/myproject/scripts/../train.py", line 1139, in <module>
    trainer.fit(training, data)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/multiprocessing.py", line 136, in launch
    process_context = mp.start_processes(
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 268, in start_processes
    idx, process, tf_name = start_process(i)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 263, in start_process
    process.start()
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/home/myuser/anaconda3/envs/myproject/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'TorchGraph.create_forward_hook.<locals>.after_forward_hook'
  • Removing the line wandb_logger.watch(training) fixes the problem
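The traceback ends in `ForkingPickler(file, protocol).dump(obj)`. With the spawn start method, the parent process pickles the process object before the child starts. The object named in the error is a function defined inside another function (note `<locals>` in its qualified name). Pickle stores functions by module and qualified name, so such closures cannot be serialized. The standalone sketch below reproduces that failure mode without Lightning, wandb, or CUDA. `module_level_hook`, `make_forward_hook`, and `is_picklable` are hypothetical names chosen for illustration. The traceback does not confirm that the hooks registered by `wandb_logger.watch(training)` end up in the pickled object graph; that is an inference from the error message.

```python
import pickle
from multiprocessing.reduction import ForkingPickler


def module_level_hook(module, inputs, output):
    """Defined at module level, so it is pickled by reference (module + name)."""
    return output


def make_forward_hook():
    """Return a closure (hypothetical); its __qualname__ contains '<locals>'."""

    def after_forward_hook(module, inputs, output):
        return output

    return after_forward_hook


def is_picklable(obj):
    """Return True if the pickler used by multiprocessing accepts obj."""
    try:
        ForkingPickler.dumps(obj)
    except (AttributeError, pickle.PicklingError, TypeError):
        # Python 3.10 raises AttributeError("Can't pickle local object ...");
        # other Python versions may raise PicklingError instead.
        return False
    return True


if __name__ == "__main__":
    hook = make_forward_hook()
    print(hook.__qualname__)                # make_forward_hook.<locals>.after_forward_hook
    print(is_picklable(module_level_hook))  # True
    print(is_picklable(hook))               # False
```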

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0): 2.4.0
#- PyTorch Version (e.g., 2.4): 3.10.15
#- Python version (e.g., 3.12): 3.10
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: 12.4/11.5
#- GPU models and configuration: 1xRTX 3090
#- How you installed Lightning(`conda`, `pip`, source): `conda`

More info

No response
