
Lightning places model inputs and model on different devices #20276

@Kamichanw


Bug description

In the following code snippet, lmm is an instance of a class that inherits from nn.Module and wraps a Hugging Face model and its processor.

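import pytorch_lightning as pl
import torch

# setting (used below) is the author's configuration object; it is not shown in the issue.


class ICVModel(pl.LightningModule):
    def __init__(self, lmm, icv_encoder: torch.nn.Module) -> None:
        super().__init__()
        # lmm wraps a Hugging Face model (lmm.model) and its processor (lmm.processor)
        self.lmm = lmm
        self.lmm.requires_grad_(False)  # freeze the wrapped model
        self.icv_encoder = icv_encoder
        self.eos_token = self.lmm.processor.tokenizer.eos_token

    def forward(self, ice_texts, query_texts, answers, images):
        # Append the EOS token to every query + answer pair
        query_answer = [
            query + answer + self.eos_token
            for query, answer in zip(query_texts, answers)
        ]
        # Keep only the last num_image_in_query images of each sample
        query_images = [img[-setting.num_image_in_query :] for img in images]
        # The processor builds new tensors here; Lightning does not move tensors
        # created inside forward (the printout below shows pixel_values on cpu)
        query_inputs = self.lmm.process_input(query_answer, query_images)
        query_outputs = self.lmm.model(
            **query_inputs,
            labels=query_inputs["input_ids"],
        )
        # (the rest of forward is not shown in the issue)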

However, a device-mismatch error is raised at

query_outputs = self.lmm.model(
    **query_inputs,
    labels=query_inputs["input_ids"],
)

Printing inputs.pixel_values.device, self.device, and self.lmm.device (in that order) outside lmm.model.forward gives:

rank[0]: cpu cuda:0 cuda:0
rank[1]: cpu cuda:1 cuda:1
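The same per-rank report can be produced by a small helper. This is a sketch: report_devices is a hypothetical name, and it only assumes that the processor output behaves like a mapping from names to values (a plain dict or a transformers BatchFeature both do) and that tensors and modules expose a .device attribute, as torch.Tensor, LightningModule and PreTrainedModel do.

```python
from typing import Any, Mapping


def report_devices(rank: int, inputs: Mapping[str, Any], **modules: Any) -> str:
    """Format the device of every tensor-like input and of each named module.

    Any value with a .device attribute is reported; values without one
    (for example Python lists of strings) are skipped.
    """
    parts = [
        f"{name}={value.device}"
        for name, value in inputs.items()
        if hasattr(value, "device")
    ]
    # Modules are passed as keyword arguments so each gets a readable label
    parts += [f"{name}={module.device}" for name, module in modules.items()]
    return f"rank[{rank}]: " + " ".join(parts)
```

Inside ICVModel.forward it could be called as print(report_devices(self.global_rank, query_inputs, module=self, lmm=self.lmm)) just before the call to self.lmm.model.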

Inside the Idefics forward pass (self.lmm.model), printing inputs.pixel_values.device and self.device gives:

rank[0]: cuda:0 cuda:0
rank[1]: cuda:0 cuda:1
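One plausible reading of the two printouts, offered as an assumption rather than a confirmed diagnosis: the accelerate/hooks.py frames in the trace show accelerate hooks wrapping every Idefics submodule, which is what transformers attaches when a model is loaded with a device_map. Such a hook moves incoming tensors to the execution device recorded when it was attached (cuda:0 on both ranks here), while something in the Lightning/DeepSpeed setup moved the parameters to each rank's own GPU (self.device reports cuda:1 on rank 1). The pure-Python toy below, with hypothetical names FakeTensor, PinnedInputHook and run_rank, reproduces that pattern without torch or accelerate; it models the behavior only, not accelerate's actual code.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Stand-in for a tensor: only tracks which device it lives on."""
    device: str

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


class PinnedInputHook:
    """Toy pre-forward hook that always sends inputs to one fixed device,
    chosen when the hook is attached (mimicking an execution device)."""

    def __init__(self, execution_device: str) -> None:
        self.execution_device = execution_device

    def pre_forward(self, x: FakeTensor) -> FakeTensor:
        return x.to(self.execution_device)


def run_rank(rank: int, hook_device: str = "cuda:0") -> tuple[str, str]:
    """Return (input device seen by the layer, weight device) for one rank."""
    weight = FakeTensor("cpu")
    hook = PinnedInputHook(hook_device)  # attached at load time, before training
    weight = weight.to(f"cuda:{rank}")   # the strategy later moves the weights
    pixel_values = FakeTensor("cpu")     # processor output created in forward
    seen = hook.pre_forward(pixel_values)
    return seen.device, weight.device


for rank in (0, 1):
    seen, weight = run_rank(rank)
    status = "ok" if seen == weight else "device mismatch"
    print(f"rank[{rank}]: input={seen} weight={weight} -> {status}")
```

Running it prints "rank[0]: input=cuda:0 weight=cuda:0 -> ok" and "rank[1]: input=cuda:0 weight=cuda:1 -> device mismatch", the same pattern as the Idefics printout. In this toy, moving pixel_values to the rank's device before the call would not help, because the hook moves it again; only run_rank(1, hook_device="cuda:1") lines the two up.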

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Error messages and logs

Full stack trace:

[rank1]:   File "/home/jyc/ICLTestbed/dev/train.py", line 103, in <module>
[rank1]:     main()
[rank1]:   File "/home/jyc/ICLTestbed/dev/train.py", line 72, in main
[rank1]:     trainer.fit(
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 538, in fit
[rank1]:     call._call_and_handle_interrupt(
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 46, in _call_and_handle_interrupt
[rank1]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank1]:     return function(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 574, in _fit_impl
[rank1]:     self._run(model, ckpt_path=ckpt_path)
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 981, in _run
[rank1]:     results = self._run_stage()
[rank1]:               ^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1025, in _run_stage
[rank1]:     self.fit_loop.run()
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
[rank1]:     self.advance()
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
[rank1]:     self.epoch_loop.run(self._data_fetcher)
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
[rank1]:     self.advance(data_fetcher)
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 250, in advance
[rank1]:     batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 190, in run
[rank1]:     self._optimizer_step(batch_idx, closure)
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 268, in _optimizer_step
[rank1]:     call._call_lightning_module_hook(
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 167, in _call_lightning_module_hook
[rank1]:     output = fn(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/core/module.py", line 1306, in optimizer_step
[rank1]:     optimizer.step(closure=optimizer_closure)
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py", line 153, in step
[rank1]:     step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank1]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/ddp.py", line 270, in optimizer_step
[rank1]:     optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
[rank1]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 238, in optimizer_step
[rank1]:     return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/deepspeed.py", line 129, in optimizer_step
[rank1]:     closure_result = closure()
[rank1]:                      ^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 144, in __call__
[rank1]:     self._result = self.closure(*args, **kwargs)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 129, in closure
[rank1]:     step_output = self._step_fn()
[rank1]:                   ^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 317, in _training_step
[rank1]:     training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
[rank1]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 319, in _call_strategy_hook
[rank1]:     output = fn(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 389, in training_step
[rank1]:     return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 640, in __call__
[rank1]:     wrapper_output = wrapper_module(*args, **kwargs)
[rank1]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 1899, in forward
[rank1]:     loss = self.module(*inputs, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 633, in wrapped_forward
[rank1]:     out = method(*_args, **_kwargs)
[rank1]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/ICLTestbed/dev/icv_model.py", line 89, in training_step
[rank1]:     loss_dict = self(**batch)
[rank1]:                 ^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/ICLTestbed/dev/icv_model.py", line 42, in forward
[rank1]:     query_outputs = self.lmm.model(
[rank1]:                     ^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
[rank1]:     output = module._old_forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/transformers/models/idefics/modeling_idefics.py", line 1493, in forward
[rank1]:     outputs = self.model(
[rank1]:               ^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
[rank1]:     output = module._old_forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/transformers/models/idefics/modeling_idefics.py", line 1181, in forward
[rank1]:     image_hidden_states = self.vision_model(
[rank1]:                           ^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
[rank1]:     output = module._old_forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/transformers/models/idefics/vision.py", line 467, in forward
[rank1]:     hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
[rank1]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
[rank1]:     output = module._old_forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/transformers/models/idefics/vision.py", line 147, in forward
[rank1]:     patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
[rank1]:     output = module._old_forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 460, in forward
[rank1]:     return self._conv_forward(input, self.weight, self.bias)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
[rank1]:     return F.conv2d(input, weight, bias, self.stride,
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument weight in method wrapper_CUDA__cudnn_convolution)
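The error says the patch-embedding convolution received one tensor on cuda:0 and another on cuda:1. Assuming the Idefics model was loaded with a device_map that places it on cuda:0 (the accelerate/hooks.py frames suggest accelerate hooks are attached, but the loading code is not shown in the issue), one option to try is pinning the whole model to each process's own GPU at load time; a device_map of {"": device} tells from_pretrained to put the entire model on that device. The helper below, per_rank_device_map, is a hypothetical name; it reads the LOCAL_RANK environment variable, which torchrun and Lightning's subprocess launcher typically set for worker processes, and falls back to 0 when it is absent. This is a sketch, not a verified fix for this issue.

```python
import os
from typing import Mapping, Optional


def per_rank_device_map(env: Optional[Mapping[str, str]] = None) -> dict[str, str]:
    """Build a device_map that puts the whole model on this process's GPU.

    The empty-string key "" means "the entire model" in a Hugging Face
    device_map. LOCAL_RANK defaults to 0 when it is not set (for example
    in the parent process before any workers are spawned).
    """
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", "0"))
    return {"": f"cuda:{local_rank}"}


# Hypothetical usage; the model class and checkpoint are placeholders:
# model = IdeficsForVisionText2Text.from_pretrained(
#     checkpoint, device_map=per_rank_device_map()
# )
```

Loading without any device_map and letting the strategy move the module is the other obvious variant to compare against.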

Environment

No response

More info

No response


Labels: bug, needs triage, ver: 2.4.x
