Commit 6e5b58a

fix Internvl-int8 device map (#937)
1 parent 86188e2 commit 6e5b58a

4 files changed: +13 -8 lines changed

docs/source/LLM/命令行参数.md

Lines changed: 2 additions & 0 deletions
@@ -123,6 +123,8 @@
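 - `--use_loss_scale`: Default is `False`. When enabled, increases the loss weight of certain Agent fields (the Action/Action Input parts) to strengthen CoT; it has no effect in ordinary SFT scenarios.
 - `--custom_register_path`: Default is `None`. Pass in a `.py` file used to register templates, models, and datasets.
 - `--custom_dataset_info`: Default is `None`. Pass in the path to an external dataset_info.json, a JSON string, or a dict. Used to extend the available datasets. Format reference: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json
+- `--device_map_config_path`: Manually configure the model's device_map from a local file. Default is None.
+
 
 
 ### FSDP Parameters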

docs/source_en/LLM/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@
 - `--use_loss_scale`: Default is `False`. When taking effect, strengthens loss weight of some Agent fields (Action/Action Input part) to enhance CoT, has no effect in regular SFT scenarios.
 - `--custom_register_path`: Default is `None`. Pass in a `.py` file used to register templates, models, and datasets.
 - `--custom_dataset_info`: Default is `None`. Pass in the path to an external `dataset_info.json`, a JSON string, or a dictionary. Used to register custom datasets. The format example: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json
-
+- `device_map_config_path`: Manually configure the model's device map from a local file, defaults to None.
 
 ### FSDP Parameters
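
The documentation line added above says only that the device map is read from a local file; it does not state the file format. As a rough sketch, assuming the file is a JSON object that maps module names to device indices (the shape commonly used for a `device_map`), such a file could be written and read back as follows. The module names, the `device_map.json` file name, and the `load_device_map` helper are hypothetical and are not taken from this commit.

```python
import json
import os
import tempfile

# Hypothetical module names; the real names depend on the model architecture.
device_map = {
    'vision_model': 0,
    'language_model.model.layers.0': 0,
    'language_model.model.layers.1': 1,
    'language_model.lm_head': 1,
}


def load_device_map(path):
    """Read a device map from a local JSON file (assumed format)."""
    with open(path, 'r', encoding='utf-8') as f:
        loaded = json.load(f)
    if not isinstance(loaded, dict):
        raise ValueError('device map file must contain a JSON object')
    return loaded


# Write the map to a temporary file and read it back.
with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = os.path.join(tmp_dir, 'device_map.json')
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(device_map, f, indent=2)
    loaded_map = load_device_map(config_path)
```

Under that assumption, the path to the file would be the value passed to the new option.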

swift/llm/utils/argument.py

Lines changed: 1 addition & 0 deletions
@@ -1019,6 +1019,7 @@ def __post_init__(self) -> None:
             self.load_from_ckpt_dir()
         else:
             assert self.load_dataset_config is False, 'You need to first set `--load_args_from_ckpt_dir true`.'
+        self.handle_compatibility()
         self._handle_dataset_sample()
         self.handle_custom_register()
         self.handle_custom_dataset_info()

swift/llm/utils/model.py

Lines changed: 9 additions & 7 deletions
@@ -2609,13 +2609,15 @@ def fix_internvl_inplace_bug(model) -> None:
 
     embedding = model.language_model.get_input_embeddings()
     if not hasattr(embedding, '__old_forward'):  # Avoid double patching
-        if hasattr(embedding, '_old_forward'):  # device_map
-            __old_forward = embedding._old_forward
-            embedding._old_forward = lambda *args, **kwargs: __old_forward(*args, **kwargs).requires_grad_(True).clone()
-        else:
-            __old_forward = embedding.forward
-            embedding.forward = lambda *args, **kwargs: __old_forward(*args, **kwargs).requires_grad_(True).clone()
-        embedding.__old_forward = __old_forward
+        old_forward = embedding.forward
+
+        @wraps(old_forward)
+        def _new_forward(*args, **kwargs):
+            device = args[0].device
+            return old_forward(*args, **kwargs).requires_grad_(True).clone().to(device)
+
+        embedding.__old_forward = old_forward
+        embedding.forward = _new_forward
 
 
 @register_model(
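
The `model.py` hunk replaces two lambda patches with a single wrapped function, guarded by a sentinel attribute so that the patch is applied only once. The sketch below shows that pattern in isolation, without torch: `ToyEmbedding` and `patch_forward` are hypothetical stand-ins, and copying the rows stands in for the `.requires_grad_(True).clone().to(device)` post-processing in the commit. It is an illustration of the wrapping technique, not the library's implementation.

```python
from functools import wraps


class ToyEmbedding:
    """Hypothetical stand-in for an embedding layer (not a torch module)."""

    def __init__(self, table):
        self.table = table

    def forward(self, ids):
        """Look up rows and return the stored row objects themselves."""
        return [self.table[i] for i in ids]


def patch_forward(embedding) -> None:
    # Sentinel attribute: do nothing if this instance was already patched.
    if hasattr(embedding, '__old_forward'):
        return
    old_forward = embedding.forward

    @wraps(old_forward)  # keep the original name and docstring
    def _new_forward(*args, **kwargs):
        # Post-process the original output: return copies of the rows so
        # callers cannot modify the stored table in place.
        return [list(row) for row in old_forward(*args, **kwargs)]

    # Keep a handle to the original and install the wrapper on the instance.
    setattr(embedding, '__old_forward', old_forward)
    embedding.forward = _new_forward


emb = ToyEmbedding([[1, 2], [3, 4]])
patch_forward(emb)
patched_once = emb.forward
patch_forward(emb)  # second call is a no-op because of the sentinel
rows = emb.forward([1])
rows[0].append(9)  # mutating the result leaves emb.table untouched
```

Wrapping `embedding.forward` directly, rather than branching on whether another hook has already replaced it, keeps a single code path whichever way the model was loaded.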
