Commit a650c31

[misc] update swift patch_conv3d (#7320)

1 parent: 2c19674

File tree: 7 files changed (+20, −17 lines)

docs/source/Instruction/Command-line-parameters.md (1 addition, 1 deletion)

@@ -800,7 +800,7 @@ In addition to the model-specific parameters of qwen2_5_vl and qwen2_audio, qwen2_5_omni also
 - Tip: ms-swift only fine-tunes the thinker part; setting this to False is recommended to reduce GPU memory usage (only the thinker part of the model structure is created).
 
 ### qwen3_vl
-The parameter meanings are the same as in the `qwen_vl_utils>=0.0.14` library; see [here](https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L24). The library's global default values can be overridden by passing the following environment variables.
+The parameter meanings are the same as in the `qwen_vl_utils>=0.0.14` library; see [here](https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L24). The library's global default values can be overridden by passing the following environment variables. (Environment variables used by `qwen2_5_vl`, such as `MAX_PIXELS` and `VIDEO_MAX_PIXELS`, are also accepted and converted automatically.)
 
 - SPATIAL_MERGE_SIZE: defaults to 2.
 - IMAGE_MIN_TOKEN_NUM: defaults to `4`, the minimum number of image tokens for an image.

docs/source_en/Instruction/Command-line-parameters.md (2 additions, 1 deletion)

@@ -825,7 +825,8 @@ qwen2_5_omni not only includes the model-specific parameters of qwen2_5_vl and q
 
 
 ### qwen3_vl
-The parameter meanings are the same as in the `qwen_vl_utils>=0.0.14` library; see [here](https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L24). By passing the following environment variables you can override the library's global default values:
+The parameter meanings are the same as in the `qwen_vl_utils>=0.0.14` library; see [here](https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L24). By passing the following environment variables you can override the library's global default values. (Environment variables used by `qwen2_5_vl`, such as `MAX_PIXELS` and `VIDEO_MAX_PIXELS`, are also accepted and converted automatically.)
+
 
 - SPATIAL_MERGE_SIZE: default 2.
 - IMAGE_MIN_TOKEN_NUM: default `4`, the minimum number of image tokens per image.
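For illustration, a minimal sketch of overriding these defaults from Python before the engine is constructed (the specific values here are hypothetical):

```python
import os

# Hypothetical values; set these before swift imports qwen_vl_utils.
os.environ['SPATIAL_MERGE_SIZE'] = '2'
os.environ['IMAGE_MIN_TOKEN_NUM'] = '4'
# qwen2_5_vl-style variables are also accepted and converted automatically:
os.environ['MAX_PIXELS'] = str(1280 * 28 * 28)
os.environ['VIDEO_MAX_PIXELS'] = str(768 * 28 * 28)
```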

examples/infer/demo_embedding.py (1 addition, 4 deletions)

@@ -4,10 +4,7 @@
 
 if __name__ == '__main__':
     engine = PtEngine(
-        'Qwen/Qwen3-Embedding-4B',
-        task_type='embedding',
-        torch_dtype=torch.float16,
-        attn_implementation='flash_attention_2')
+        'Qwen/Qwen3-Embedding-4B', task_type='embedding', torch_dtype=torch.float16, attn_impl='flash_attention_2')
 
     infer_requests = [
         InferRequest(messages=[

examples/infer/demo_reranker.py (1 addition, 1 deletion)

@@ -7,7 +7,7 @@
         'Qwen/Qwen3-Reranker-4B',
         task_type='generative_reranker',
         torch_dtype=torch.float16,
-        attn_implementation='flash_attention_2')
+        attn_impl='flash_attention_2')
 
     infer_request = InferRequest(
         messages=[{
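In both demos the change is the same keyword rename: the Transformers-style `attn_implementation` becomes `attn_impl`, presumably the parameter name that `PtEngine` itself accepts.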

swift/llm/infer/infer_engine/pt_engine.py (8 additions, 1 deletion)

@@ -21,6 +21,7 @@
 from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer, safe_snapshot_download, to_device
 from swift.plugin import Metric
 from swift.tuners import Swift
+from swift.utils import get_last_valid_indices
 from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
                         ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
                         EmbeddingResponseData, RequestConfig, random_uuid)
@@ -349,7 +350,13 @@ def _infer_forward(self, template: Template, inputs: Dict[str, Any], adapter_req
         negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
         token_false_id = template.tokenizer.convert_tokens_to_ids(negative_token)
         token_true_id = template.tokenizer.convert_tokens_to_ids(positive_token)
-        batch_scores = logits[:, -1, :]
+        attention_mask = inputs.get('attention_mask')
+        if attention_mask is None:
+            batch_scores = logits[:, -1, :]
+        else:
+            last_valid_indices = get_last_valid_indices(attention_mask)
+            batch_indices = torch.arange(attention_mask.shape[0], device=logits.device)
+            batch_scores = logits[batch_indices, last_valid_indices, :]
         true_vector = batch_scores[:, token_true_id]
         false_vector = batch_scores[:, token_false_id]
         batch_scores = torch.stack([false_vector, true_vector], dim=1).float()
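`get_last_valid_indices` is imported from `swift.utils`, but its body is not part of this diff. As a rough sketch of what such a helper might look like (an assumption, not the actual swift implementation), it would return, for each row of a `[batch, seq_len]` 0/1 attention mask, the index of the last position whose mask value is 1:

```python
import torch


def get_last_valid_indices(attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch of the swift.utils helper used above.
    # argmax over the reversed mask gives the distance of the last valid
    # token from the end of each row, so this handles both left- and
    # right-padded batches.
    seq_len = attention_mask.shape[1]
    return seq_len - 1 - attention_mask.flip(dims=[1]).long().argmax(dim=1)
```

The motivation for the change: with right padding, `logits[:, -1, :]` reads logits at pad positions, so indexing each row at its last valid position fixes reranker scores for padded batches.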

swift/llm/model/patcher.py (4 additions, 9 deletions)

@@ -9,6 +9,7 @@
 import accelerate
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import transformers
 from accelerate.utils import find_device
 from packaging import version
@@ -92,15 +93,9 @@ def _output_embedding_hook(module, args, kwargs, output):
     if attention_mask is None:
         attention_mask = output.get('attention_mask', None)
     hidden_states = output.logits
-    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
-    if left_padding:
-        embeddings = hidden_states[:, -1]
-    else:
-        sequence_lengths = attention_mask.sum(dim=1) - 1
-        batch_size = hidden_states.shape[0]
-        embeddings = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths]
-    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
+    sequence_lengths = get_last_valid_indices(attention_mask)
+    embeddings = hidden_states[torch.arange(hidden_states.shape[0], device=hidden_states.device), sequence_lengths]
+    embeddings = F.normalize(embeddings, p=2, dim=1)
     return {
         'last_hidden_state': embeddings.contiguous(),
     }
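A quick check of the simplification (hypothetical, reusing the `get_last_valid_indices` sketch above): the unified indexing reproduces both branches the hook previously special-cased.

```python
import torch

# Left-padded rows: the last valid token is the final position, which the
# removed left_padding branch selected as hidden_states[:, -1].
left = torch.tensor([[0, 0, 1, 1, 1],
                     [0, 1, 1, 1, 1]])
# Right-padded rows: the removed else branch used attention_mask.sum(dim=1) - 1.
right = torch.tensor([[1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 0]])

print(get_last_valid_indices(left))   # tensor([4, 4])
print(get_last_valid_indices(right))  # tensor([2, 3])
```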

swift/llm/model/utils.py (3 additions, 0 deletions)

@@ -9,6 +9,7 @@
 import torch.nn.functional as F
 from accelerate.utils import find_device
 from modelscope.hub.utils.utils import get_cache_dir
+from packaging import version
 from torch import nn
 from transformers import PretrainedConfig
 from transformers.utils import strtobool
@@ -549,6 +550,8 @@ def _patch_conv3d():
     nn.Conv3d._original_forward = nn.Conv3d.forward
 
     def forward(self, x):
+        if version.parse(torch.__version__) < version.parse('2.9.0'):
+            return self._original_forward(x)
         if any(s != k for s, k in zip(self.stride, self.kernel_size)) or any(p != 0 for p in self.padding) or any(
                 d != 1 for d in self.dilation) or self.groups != 1:
             raise NotImplementedError(
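The rest of the patched `forward` is outside this diff; the new check simply routes torch versions below 2.9.0 back to the stock implementation. The guard that follows only admits the patch-embedding case (stride equal to kernel size, zero padding, dilation 1, a single group), where a `Conv3d` reduces to a linear projection over flattened non-overlapping patches. A sketch of that standard equivalence, as an illustration rather than the actual swift code:

```python
import torch
import torch.nn.functional as F


def conv3d_as_linear(x: torch.Tensor, conv: torch.nn.Conv3d) -> torch.Tensor:
    # Sketch: evaluate a patch-embedding-style Conv3d as a matmul.
    # Assumes stride == kernel_size, padding 0, dilation 1, groups == 1,
    # i.e. exactly what the guard above allows through.
    n, c, t, h, w = x.shape
    kt, kh, kw = conv.kernel_size
    # Cut the volume into non-overlapping (kt, kh, kw) patches.
    x = x.reshape(n, c, t // kt, kt, h // kh, kh, w // kw, kw)
    x = x.permute(0, 2, 4, 6, 1, 3, 5, 7)    # [n, T', H', W', c, kt, kh, kw]
    x = x.reshape(n, -1, c * kt * kh * kw)   # [n, patches, patch_dim]
    out = F.linear(x, conv.weight.view(conv.out_channels, -1), conv.bias)
    # Restore the Conv3d output layout [n, out_channels, T', H', W'].
    return out.transpose(1, 2).reshape(n, conv.out_channels, t // kt, h // kh, w // kw)


# Quick equivalence check against the reference Conv3d:
conv = torch.nn.Conv3d(3, 8, kernel_size=(2, 4, 4), stride=(2, 4, 4))
x = torch.randn(1, 3, 4, 8, 8)
print(torch.allclose(conv(x), conv3d_as_linear(x, conv), atol=1e-5))  # True
```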
