Skip to content

Commit 6542c54

Browse files
authored
fix alpaca (#2771)
1 parent f2c0a49 commit 6542c54

File tree

7 files changed

+33
-38
lines changed

7 files changed

+33
-38
lines changed

swift/llm/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from .model import (register_model, MODEL_MAPPING, ModelType, get_model_tokenizer, safe_snapshot_download,
1919
HfConfigFactory, ModelInfo, ModelMeta, ModelKeys, register_model_arch, MultiModelKeys,
2020
ModelArch, get_model_arch, MODEL_ARCH_MAPPING, get_model_info_meta, get_model_name, ModelGroup,
21-
Model, get_model_tokenizer_with_flash_attn, get_model_tokenizer_multimodal, load_by_unsloth)
21+
Model, get_model_tokenizer_with_flash_attn, get_model_tokenizer_multimodal, load_by_unsloth,
22+
git_clone_github)
2223
from .dataset import (AlpacaPreprocessor, ResponsePreprocessor, MessagesPreprocessor, AutoPreprocessor,
2324
DATASET_MAPPING, MediaResource, register_dataset, register_dataset_info, EncodePreprocessor,
2425
LazyLLMDataset, ConstantLengthDataset, standard_keys, load_dataset, DATASET_TYPE,
@@ -51,7 +52,7 @@
5152
'ModelInfo', 'ModelMeta', 'ModelKeys', 'register_model_arch', 'MultiModelKeys', 'ModelArch',
5253
'MODEL_ARCH_MAPPING', 'get_model_arch', 'get_model_info_meta', 'get_model_name', 'register_model',
5354
'ModelGroup', 'Model', 'get_model_tokenizer_with_flash_attn', 'get_model_tokenizer_multimodal',
54-
'load_by_unsloth'
55+
'load_by_unsloth', 'git_clone_github'
5556
],
5657
'dataset': [
5758
'AlpacaPreprocessor', 'ClsPreprocessor', 'ComposePreprocessor', 'MessagesPreprocessor', 'DATASET_MAPPING',

swift/llm/dataset/dataset/llm.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,20 @@
99
from ..register import DatasetMeta, SubsetDataset, register_dataset
1010

1111

12-
def _concat_inst_inp_alpaca_zh(inst: str, inp: str) -> str:
13-
if inp.startswith('输入:'):
14-
inp = inp[3:]
15-
return f'{inst}\n{inp}'
12+
class AlpacaZhPreprocessor(AlpacaPreprocessor):
13+
14+
@classmethod
15+
def concat_inst_input(cls, instruction, input_):
16+
if input_ and input_.startswith('输入:'):
17+
input_ = input_[3:]
18+
return super().concat_inst_input(instruction, input_)
1619

1720

1821
register_dataset(
1922
DatasetMeta(
2023
ms_dataset_id='AI-ModelScope/alpaca-gpt4-data-zh',
2124
hf_dataset_id='llm-wizard/alpaca-gpt4-data-zh',
22-
preprocess_func=AlpacaPreprocessor(concat_inst_input=_concat_inst_inp_alpaca_zh),
25+
preprocess_func=AlpacaZhPreprocessor(),
2326
tags=['chat', 'general', '🔥'],
2427
))
2528

swift/llm/dataset/preprocessor/core.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -312,34 +312,22 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
312312

313313
class AlpacaPreprocessor(ResponsePreprocessor):
314314

315-
def __init__(self,
316-
*,
317-
concat_inst_input: Union[Callable[[str, str], str]] = '\n',
318-
columns_mapping: Optional[Dict[str, str]] = None,
319-
**kwargs) -> None:
320-
"""Alpaca format preprocessor
321-
322-
Args:
323-
concat_inst_input: The concat sep between instruction and input
324-
"""
325-
super().__init__(columns_mapping=columns_mapping, **kwargs)
326-
self.concat_inst_input = concat_inst_input
315+
@classmethod
316+
def concat_inst_input(cls, instruction, input_):
317+
if instruction and input_:
318+
query = f'{instruction}\n{input_}'
319+
else:
320+
query = instruction or input_
321+
assert isinstance(query, str), f'query: {query}'
322+
return query
327323

328324
def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
329325
instruction = row.pop('instruction', None)
330326
input_ = row.pop('input', None)
331327
output = row.pop('output', None)
332328
if output is not None:
333329
row['response'] = output
334-
335-
if instruction is not None or input_ is not None:
336-
instruction = instruction or ''
337-
input_ = input_ or ''
338-
if isinstance(self.concat_inst_input, str):
339-
query = instruction + self.concat_inst_input + input_
340-
else:
341-
query = self.concat_inst_input(instruction, input_)
342-
row['query'] = query
330+
row['query'] = self.concat_inst_input(instruction, input_)
343331
return super().preprocess(row)
344332

345333

swift/llm/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66
get_default_torch_dtype, get_model_info_meta, get_model_name, get_model_tokenizer,
77
get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, get_model_with_value_head,
88
load_by_unsloth, register_model)
9-
from .utils import HfConfigFactory, ModelInfo, safe_snapshot_download
9+
from .utils import HfConfigFactory, ModelInfo, git_clone_github, safe_snapshot_download

swift/llm/model/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,8 @@ def git_clone_github(github_url: str,
274274
local_repo_name: Optional[str] = None,
275275
branch: Optional[str] = None,
276276
commit_hash: Optional[str] = None) -> str:
277+
if github_url.endswith('.git'):
278+
github_url = github_url[:-4]
277279
git_cache_dir = os.path.join(get_cache_dir(), '_github')
278280
os.makedirs(git_cache_dir, exist_ok=True)
279281
if local_repo_name is None:
@@ -282,8 +284,7 @@ def git_clone_github(github_url: str,
282284
local_repo_path = os.path.join(git_cache_dir, local_repo_name)
283285
with safe_ddp_context(hash_id=local_repo_path):
284286
if not os.path.exists(local_repo_path):
285-
if not github_url.endswith('.git'):
286-
github_url = f'{github_url}.git'
287+
github_url = f'{github_url}.git'
287288
command = ['git', '-C', git_cache_dir, 'clone', github_url, local_repo_name]
288289
command_str = f"git -C '{git_cache_dir}' clone '{github_url}' {local_repo_name}"
289290
if branch is not None:

tests/general/test_dataset.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@ def test_sft():
1515
# _test_dataset(['AI-ModelScope/Duet-v0.5'])
1616
# _test_dataset(['swift/SlimOrca', 'swift/cosmopedia-100k'])
1717
# _test_dataset(['OmniData/Zhihu-KOL-More-Than-100-Upvotes'])
18-
_test_dataset(['OmniData/Zhihu-KOL'])
19-
# _test_dataset(['AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#200'])
18+
# _test_dataset(['OmniData/Zhihu-KOL'])
19+
_test_dataset([
20+
'AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#1000',
21+
'AI-ModelScope/LongAlpaca-12k#1000'
22+
])
2023
# _test_dataset(['swift/Infinity-Instruct:all'])
2124
# _test_dataset(['swift/sharegpt:all'])
2225
# _test_dataset(['AI-ModelScope/sharegpt_gpt4:all'])

tests/general/test_stream.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33

44
def test_local_dataset():
55
# please use git clone
6-
local_dataset = '/mnt/nas2/huangjintao.hjt/work/datasets/swift-sft-mixture:firefly#100'
7-
dataset = load_dataset(datasets=[local_dataset], streaming=True)[0]
8-
for i, x in enumerate(dataset):
9-
pass
10-
print(i, x)
6+
from swift.llm import git_clone_github
7+
model_dir = git_clone_github('https://www.modelscope.cn/datasets/swift/swift-sft-mixture.git')
8+
dataset = load_dataset(datasets=[f'{model_dir}:firefly'], streaming=True)[0]
9+
print(next(iter(dataset)))
1110

1211

1312
def test_hub_dataset():

0 commit comments

Comments (0)