diff --git a/vllm/transformers_utils/processors/ovis2_5.py b/vllm/transformers_utils/processors/ovis2_5.py index 9a763a157e89..9ff71673f7db 100644 --- a/vllm/transformers_utils/processors/ovis2_5.py +++ b/vllm/transformers_utils/processors/ovis2_5.py @@ -669,7 +669,7 @@ def preprocess_multidata( # grids/placeholder grid_t = 1 - grids = torch.tensor([[grid_t, Ty, Tx]] * B, device=flatten_patches.device) + grids = torch.tensor([[grid_t, Ty, Tx]] * B, device='cpu') visual_placeholders = [ self.construct_visual_placeholders([grid_t, Ty, Tx], is_video=False) for _ in range(B) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 9b4ec5da64c9..922cc037d1a3 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1901,14 +1901,16 @@ def _prepare_prompt( "false").lower() in ("1", "true", "yes") if use_mediapipe: # With mediapipe path some tensors will already be on HPU, we only move to HPU if needed - for key in multi_modal_kwargs.keys(): - if hasattr(multi_modal_kwargs[key], "device" - ) and multi_modal_kwargs[key].device != self.device: - multi_modal_kwargs[key] = self.move_to_device( - multi_modal_kwargs[key]) + for key, value in multi_modal_kwargs.items(): + if key == "grids": + continue # keeping grids on CPU + if hasattr(value, "device") and value.device != self.device: + multi_modal_kwargs[key] = self.move_to_device(value) else: multi_modal_kwargs = MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device) + if "grids" in multi_modal_kwargs: # keeping grids on CPU + multi_modal_kwargs["grids"] = multi_modal_kwargs["grids"].cpu() return PreparePromptMetadata(input_tokens=input_tokens_tensor, input_positions=input_positions, @@ -2775,7 +2777,7 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args, num_image_tokens = int(image_h * image_w // (vit_cfg.hidden_stride**2)) image_grid_thw = torch.tensor([[1, image_h, image_w]], - dtype=torch.int64) + dtype=torch.int64, device='cpu') pixel_values = torch.randn(1, image_grid_thw[0].prod(),