Status: Closed
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.4.x
Description
Bug description
In the following code snippet, `lmm` is an instance of a class that inherits from `nn.Module` and wraps a Hugging Face model and its processor.
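```python
import torch
import pytorch_lightning as pl


class ICVModel(pl.LightningModule):
    def __init__(self, lmm, icv_encoder: torch.nn.Module) -> None:
        super().__init__()
        # `lmm` wraps the Hugging Face (Idefics) model and its processor; it is frozen.
        self.lmm = lmm
        self.lmm.requires_grad_(False)
        self.icv_encoder = icv_encoder
        self.eos_token = self.lmm.processor.tokenizer.eos_token

    def forward(self, ice_texts, query_texts, answers, images):
        # Build "query + answer + EOS" strings for the language-modeling loss.
        query_answer = [
            query + answer + self.eos_token
            for query, answer in zip(query_texts, answers)
        ]
        # `setting` is a config object defined elsewhere in the reporter's code.
        query_images = [img[-setting.num_image_in_query :] for img in images]
        query_inputs = self.lmm.process_input(query_answer, query_images)
        # The device mismatch error is raised inside this call.
        query_outputs = self.lmm.model(
            **query_inputs,
            labels=query_inputs["input_ids"],
        )
        # ... (remainder of forward not included in the report)
```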
However, a device mismatch error is raised at this call:

```python
query_outputs = self.lmm.model(
    **query_inputs,
    labels=query_inputs["input_ids"],
)
```
I printed `inputs.pixel_values.device`, `self.device`, and `self.lmm.device` outside of `lmm.model.forward` and got:

```
rank[0]: cpu cuda:0 cuda:0
rank[1]: cpu cuda:1 cuda:1
```
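To collect the same per-rank device information in one line, a small helper can summarize where every entry of the processor output and the module live. This is a minimal sketch; `describe_devices` is a hypothetical name (not part of Lightning or transformers), and it only assumes that tensors and the module expose a `.device` attribute, as `torch.Tensor` and `LightningModule` do.

```python
from typing import Any, Mapping


def describe_devices(inputs: Mapping[str, Any], module: Any, rank: int) -> str:
    """Hypothetical debugging helper: summarize input and module devices on one rank.

    Entries without a `.device` attribute (strings, lists, ...) are skipped.
    """
    parts = [
        f"{name}={value.device}"
        for name, value in inputs.items()
        if hasattr(value, "device")
    ]
    # The module's own device (e.g. LightningModule.device) goes last.
    parts.append(f"module={module.device}")
    return f"rank[{rank}]: " + " ".join(parts)
```

Called right before the failing line, e.g. `print(describe_devices(query_inputs, self, self.global_rank))`, it prints one labeled line per rank, which makes it easier to see which input ends up on an unexpected device.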
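Inside the Idefics forward pass (`self.lmm.model`), printing `inputs.pixel_values.device` and `self.device` gave:

```
rank[0]: cuda:0 cuda:0
rank[1]: cuda:0 cuda:1
```

So on rank 1, `pixel_values` is moved to `cuda:0` while the model itself is on `cuda:1`, which matches the `cuda:0 and cuda:1` mismatch in the error below.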
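The `accelerate/hooks.py ... new_forward` frames in the stack trace below show that accelerate hooks wrap the Idefics modules, which typically happens when a Hugging Face model is loaded with a `device_map`. One possible explanation for `pixel_values` landing on `cuda:0` on rank 1 is such a hook moving inputs to the device it recorded at load time. The sketch below lists hooked submodules and the device each hook targets. It is hypothetical and relies on accelerate internals (the `_hf_hook` attribute set by `add_hook_to_module`, and `execution_device` on `AlignDevicesHook`), so verify both against the installed accelerate version.

```python
from typing import Any, Dict, Iterable, Optional, Tuple


def find_hooked_modules(
    named_modules: Iterable[Tuple[str, Any]],
) -> Dict[str, Optional[Any]]:
    """Hypothetical helper: map each hooked submodule name to its hook's target device.

    Assumes accelerate stores hooks on `module._hf_hook` (an internal detail).
    """
    hooked: Dict[str, Optional[Any]] = {}
    for name, module in named_modules:
        hook = getattr(module, "_hf_hook", None)
        if hook is not None:
            # AlignDevicesHook records the device it moves inputs to; other hook
            # types (e.g. SequentialHook) may not, in which case None is stored.
            hooked[name] = getattr(hook, "execution_device", None)
    return hooked
```

For example, `find_hooked_modules(self.lmm.model.named_modules())` on each rank would show whether every rank's hooks point at the same device, and `getattr(self.lmm.model, "hf_device_map", None)` shows the device map used at load time, if any.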
What version are you seeing the problem on?
v2.4
How to reproduce the bug
No response
Error messages and logs
Full stack trace:
[rank1]: File "/home/jyc/ICLTestbed/dev/train.py", line 103, in <module>
[rank1]: main()
[rank1]: File "/home/jyc/ICLTestbed/dev/train.py", line 72, in main
[rank1]: trainer.fit(
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 538, in fit
[rank1]: call._call_and_handle_interrupt(
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 46, in _call_and_handle_interrupt
[rank1]: return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank1]: return function(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 574, in _fit_impl
[rank1]: self._run(model, ckpt_path=ckpt_path)
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 981, in _run
[rank1]: results = self._run_stage()
[rank1]: ^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1025, in _run_stage
[rank1]: self.fit_loop.run()
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
[rank1]: self.advance()
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
[rank1]: self.epoch_loop.run(self._data_fetcher)
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
[rank1]: self.advance(data_fetcher)
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 250, in advance
[rank1]: batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 190, in run
[rank1]: self._optimizer_step(batch_idx, closure)
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 268, in _optimizer_step
[rank1]: call._call_lightning_module_hook(
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 167, in _call_lightning_module_hook
[rank1]: output = fn(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/core/module.py", line 1306, in optimizer_step
[rank1]: optimizer.step(closure=optimizer_closure)
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py", line 153, in step
[rank1]: step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/ddp.py", line 270, in optimizer_step
[rank1]: optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 238, in optimizer_step
[rank1]: return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/deepspeed.py", line 129, in optimizer_step
[rank1]: closure_result = closure()
[rank1]: ^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 144, in __call__
[rank1]: self._result = self.closure(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 129, in closure
[rank1]: step_output = self._step_fn()
[rank1]: ^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 317, in _training_step
[rank1]: training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 319, in _call_strategy_hook
[rank1]: output = fn(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 389, in training_step
[rank1]: return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 640, in __call__
[rank1]: wrapper_output = wrapper_module(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank1]: ret_val = func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 1899, in forward
[rank1]: loss = self.module(*inputs, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 633, in wrapped_forward
[rank1]: out = method(*_args, **_kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/ICLTestbed/dev/icv_model.py", line 89, in training_step
[rank1]: loss_dict = self(**batch)
[rank1]: ^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/ICLTestbed/dev/icv_model.py", line 42, in forward
[rank1]: query_outputs = self.lmm.model(
[rank1]: ^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
[rank1]: output = module._old_forward(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/transformers/models/idefics/modeling_idefics.py", line 1493, in forward
[rank1]: outputs = self.model(
[rank1]: ^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
[rank1]: output = module._old_forward(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/transformers/models/idefics/modeling_idefics.py", line 1181, in forward
[rank1]: image_hidden_states = self.vision_model(
[rank1]: ^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
[rank1]: output = module._old_forward(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/transformers/models/idefics/vision.py", line 467, in forward
[rank1]: hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
[rank1]: output = module._old_forward(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/transformers/models/idefics/vision.py", line 147, in forward
[rank1]: patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
[rank1]: output = module._old_forward(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 460, in forward
[rank1]: return self._conv_forward(input, self.weight, self.bias)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/jyc/miniconda3/envs/icl/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
[rank1]: return F.conv2d(input, weight, bias, self.stride,
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument weight in method wrapper_CUDA__cudnn_convolution)
Environment
No response
More info
No response