Skip to content

Commit a371aa3

Browse files
authored
fix qwen2_vl flash_attn deepspeed (#3484)
1 parent 2e02ee6 commit a371aa3

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
</p>
2727

2828
<p align="center">
29-
<a href="https://arxiv.org/abs/2408.05517">Paper</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
29+
<a href="https://arxiv.org/abs/2408.05517">Paper</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">Swift3.x En Doc</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">Swift3.x中文文档</a> &nbsp
3030
</p>
3131
<p align="center">
3232
<a href="https://swift2x-en.readthedocs.io/en/latest/">Swift2.x En Doc</a> &nbsp | &nbsp <a href="https://swift2x.readthedocs.io/zh-cn/latest/">Swift2.x中文文档</a> &nbsp

README_CN.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
</p>
2828

2929
<p align="center">
30-
<a href="https://arxiv.org/abs/2408.05517">论文</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
30+
<a href="https://arxiv.org/abs/2408.05517">论文</a> &nbsp | <a href="https://swift.readthedocs.io/en/latest/">Swift3.x En Doc</a> &nbsp | &nbsp <a href="https://swift.readthedocs.io/zh-cn/latest/">Swift3.x中文文档</a> &nbsp
3131
</p>
3232
<p align="center">
3333
<a href="https://swift2x-en.readthedocs.io/en/latest/">Swift2.x En Doc</a> &nbsp | &nbsp <a href="https://swift2x.readthedocs.io/zh-cn/latest/">Swift2.x中文文档</a> &nbsp

swift/llm/template/template/qwen.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
import torch.nn.functional as F
88

9+
from swift.llm import to_device
910
from swift.utils import get_env_args, is_deepspeed_enabled
1011
from ..base import Template
1112
from ..constant import LLMTemplateType, MLLMTemplateType
@@ -284,9 +285,8 @@ def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
284285
images = [Image.new('RGB', (32, 32), (0, 0, 0))]
285286
media_inputs = self.processor.image_processor(images=images, videos=None, return_tensors='pt')
286287
device = input_ids.device
287-
pixel_values = media_inputs['pixel_values'].to(device)
288-
289-
pixel_values = pixel_values.type(dtype)
288+
media_inputs = to_device(media_inputs, device)
289+
pixel_values = media_inputs['pixel_values'].type(dtype)
290290
image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
291291
inputs_embeds += image_embeds.mean() * 0.
292292
else:

0 commit comments

Comments
 (0)