Skip to content

Commit 64cede0

Browse files
authored
fix windows (#2733)
1 parent 6e3fa6d commit 64cede0

File tree

7 files changed

+23
-9
lines changed

7 files changed

+23
-9
lines changed

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ RLHF参数继承于[训练参数](#训练参数)
365365
- 🔥output_dir: 导出结果存储路径,默认为None
366366

367367
- 🔥quant_method: 可选为'gptq', 'awq',默认为None
368-
- quant_n_samples: gptq/awq的校验集抽样数,默认为256
368+
- quant_n_samples: gptq/awq的校准集抽样数,默认为128
369369
- max_length: 校准集的max_length, 默认值2048
370370
- quant_batch_size: 量化batch_size,默认为1
371371
- group_size: 量化group大小,默认为128

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum
367367
- 🔥output_dir: Path for storing export results, default is None.
368368

369369
- 🔥quant_method: Options are 'gptq' and 'awq', default is None.
370-
- quant_n_samples: Sampling size for the validation set in gptq/awq, default is 256.
370+
- quant_n_samples: Number of samples drawn from the calibration set for gptq/awq, default is 128.
371371
- max_length: Max length for the calibration set, default value is 2048.
372372
- quant_batch_size: Quantization batch size, default is 1.
373373
- group_size: Group size for quantization, default is 128.

examples/export/quantize/awq.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ swift export \
33
--model Qwen/Qwen2.5-1.5B-Instruct \
44
--dataset AI-ModelScope/alpaca-gpt4-data-zh#500 \
55
AI-ModelScope/alpaca-gpt4-data-en#500 \
6-
--quant_n_samples 256 \
6+
--quant_n_samples 128 \
77
--quant_batch_size 1 \
88
--max_length 2048 \
99
--quant_method awq \

examples/export/quantize/gptq.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ swift export \
55
--model Qwen/Qwen2.5-1.5B-Instruct \
66
--dataset AI-ModelScope/alpaca-gpt4-data-zh#500 \
77
AI-ModelScope/alpaca-gpt4-data-en#500 \
8-
--quant_n_samples 256 \
8+
--quant_n_samples 128 \
99
--quant_batch_size 1 \
1010
--max_length 2048 \
1111
--quant_method gptq \

swift/llm/argument/export_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ExportArguments(MergeArguments, BaseArguments):
3333

3434
# awq/gptq
3535
quant_method: Literal['awq', 'gptq', 'bnb'] = None
36-
quant_n_samples: int = 256
36+
quant_n_samples: int = 128
3737
max_length: int = 2048
3838
quant_batch_size: int = 1
3939
group_size: int = 128

swift/llm/model/register.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
import os
3+
import platform
34
import re
45
from copy import deepcopy
56
from dataclasses import asdict, dataclass, field
@@ -333,9 +334,11 @@ def get_model_name(model_id_or_path: str) -> Optional[str]:
333334
model_id_or_path = model_id_or_path.rstrip('/')
334335
match_ = re.search('/models--.+?--(.+?)/snapshots/', model_id_or_path)
335336
if match_ is not None:
336-
model_name = match_.group(1)
337-
else:
338-
model_name = model_id_or_path.rsplit('/', 1)[-1]
337+
return match_.group(1)
338+
339+
model_name = model_id_or_path.rsplit('/', 1)[-1]
340+
if platform.system().lower() == 'windows':
341+
model_name = model_name.rsplit('\\', 1)[-1]
339342
# compat modelscope snapshot_download
340343
model_name = model_name.replace('___', '.')
341344
return model_name

tests/export/quant.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ def test_vlm_quant(quant_method: Literal['gptq', 'awq'] = 'awq'):
2424
quant_method=quant_method))
2525

2626

27+
def test_audio_quant(quant_method: Literal['gptq', 'awq'] = 'awq'):
28+
from swift.llm import export_main, ExportArguments
29+
export_main(
30+
ExportArguments(
31+
model='Qwen/Qwen2-Audio-7B-Instruct',
32+
quant_bits=4,
33+
dataset=['speech_asr/speech_asr_aishell1_trainsets:validation#1000'],
34+
quant_method=quant_method))
35+
36+
2737
def test_vlm_bnb_quant():
2838
from swift.llm import export_main, ExportArguments, infer_main, InferArguments
2939
export_main(ExportArguments(model='Qwen/Qwen2-VL-7B-Instruct', quant_bits=4, quant_method='bnb'))
@@ -34,4 +44,5 @@ def test_vlm_bnb_quant():
3444
if __name__ == '__main__':
3545
# test_llm_quant('gptq')
3646
# test_vlm_quant('gptq')
37-
test_vlm_bnb_quant()
47+
test_audio_quant('gptq')
48+
# test_vlm_bnb_quant()

0 commit comments

Comments
 (0)