
Commit f5b4585

fix ui, and support DATASET_ENABLE_CACHE variable (#1319)
1 parent fd1dd26 commit f5b4585

6 files changed: +82 -44 lines changed

README.md

Lines changed: 13 additions & 0 deletions

```diff
@@ -624,6 +624,19 @@ The complete list of supported models and datasets can be found at [Supported Mo
 | Computing cards A10/A100, etc. | Support BF16 and FlashAttn |
 | Huawei Ascend NPU | |
 
+### Environment variables
+
+- DATASET_ENABLE_CACHE: Enable caching when preprocessing datasets; accepts `1/True` or `0/False`, default `False`
+- WEBUI_SHARE: Share the web-ui publicly; accepts `1/True` or `0/False`, default `False`
+- SWIFT_UI_LANG: Web-ui language; accepts `en` or `zh`, default `zh`
+- WEBUI_SERVER: Host IP for the web-ui; `0.0.0.0` listens on all network interfaces, `127.0.0.1` accepts local connections only; default `127.0.0.1`
+- WEBUI_PORT: Web-ui port
+- USE_HF: Download models and datasets from the Hugging Face endpoint instead of the ModelScope endpoint; accepts `1/True` or `0/False`, default `False`
+- FORCE_REDOWNLOAD: Force re-download of the dataset
+
+Other variables such as `CUDA_VISIBLE_DEVICES` are also supported; they are not listed here.
+
 ## 📃 Documentation
 
 ### Documentation Compiling
```
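
These switches are read from the process environment. As a minimal sketch of how a boolean variable such as `DATASET_ENABLE_CACHE` or `WEBUI_SHARE` can be interpreted, the snippet below reuses the same `strtobool` helper that this commit imports in `swift/llm/utils/dataset.py`; the `read_bool_env` wrapper is illustrative only and not part of swift.

```python
import os

from transformers.utils import strtobool  # the helper the patched dataset.py also uses


def read_bool_env(name: str, default: str = 'False') -> bool:
    """Interpret '1'/'True' or '0'/'False' (case-insensitive) as a boolean."""
    return bool(strtobool(os.environ.get(name, default)))


# Caching and web-ui sharing stay off unless explicitly requested.
dataset_enable_cache = read_bool_env('DATASET_ENABLE_CACHE')
webui_share = read_bool_env('WEBUI_SHARE')
```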

README_CN.md

Lines changed: 13 additions & 0 deletions

```diff
@@ -621,6 +621,19 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
 | Huawei Ascend NPU | |
 
 
+### Environment variables
+
+- DATASET_ENABLE_CACHE: Enable caching when preprocessing datasets; accepts `1/True` or `0/False`, default `False`
+- WEBUI_SHARE: Share the web-ui; accepts `1/True` or `0/False`, default `False`
+- SWIFT_UI_LANG: Web-ui language; accepts `en` or `zh`, default `zh`
+- WEBUI_SERVER: IP the web-ui listens on; `0.0.0.0` accepts connections from all interfaces, `127.0.0.1` accepts local connections only; default `127.0.0.1`
+- WEBUI_PORT: Web-ui port
+- USE_HF: Download models and datasets from the Hugging Face endpoint or the ModelScope endpoint; accepts `1/True` or `0/False`, default `False`
+- FORCE_REDOWNLOAD: Force re-download of the dataset
+
+Other variables such as `CUDA_VISIBLE_DEVICES` are also supported; they are not listed here.
+
 ## 📃 Documentation
 
 ### Documentation Compiling
```

swift/llm/utils/dataset.py

Lines changed: 43 additions & 38 deletions

```diff
@@ -19,15 +19,16 @@
 from tqdm.auto import tqdm
 from transformers.utils import strtobool
 
-from swift.utils import (get_logger, get_seed, is_dist, is_local_master, read_from_jsonl, safe_ddp_context,
-                         transform_jsonl_to_df)
+from swift.utils import get_logger, get_seed, is_dist, is_local_master, read_from_jsonl, transform_jsonl_to_df
 from swift.utils.torch_utils import _find_local_mac
 from .media import MediaCache, MediaTag
 from .preprocess import (AlpacaPreprocessor, ClsPreprocessor, ComposePreprocessor, ConversationsPreprocessor,
                          ListPreprocessor, PreprocessFunc, RenameColumnsPreprocessor, SmartPreprocessor,
                          TextGenerationPreprocessor, preprocess_sharegpt)
 from .utils import download_dataset
 
+dataset_enable_cache = strtobool(os.environ.get('DATASET_ENABLE_CACHE', 'False'))
+
 
 def _update_fingerprint_mac(*args, **kwargs):
     mac = _find_local_mac().replace(':', '')
```
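
Every remaining hunk in this file makes the same substitution: the hard-coded `load_from_cache_file=False` becomes the module-level `dataset_enable_cache` flag defined above. A condensed, self-contained sketch of that pattern is shown below; the `add_prefix` function and the tiny in-memory dataset are illustrative stand-ins, not swift code.

```python
import os

from datasets import Dataset
from transformers.utils import strtobool

# Mirrors the module-level flag added above: caching is opt-in via DATASET_ENABLE_CACHE.
dataset_enable_cache = strtobool(os.environ.get('DATASET_ENABLE_CACHE', 'False'))


def add_prefix(row):
    # Stand-in for the real preprocess_row / preprocess_image helpers in dataset.py.
    return {'query': 'Q: ' + row['query']}


dataset = Dataset.from_dict({'query': ['hello', 'world']})
# With the default (flag unset) the map is recomputed on every run; with
# DATASET_ENABLE_CACHE=1, `datasets` may reuse a previously written cache file.
dataset = dataset.map(add_prefix, load_from_cache_file=dataset_enable_cache)
```
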
```diff
@@ -378,7 +379,7 @@ def _post_preprocess(
         train_sample = dataset_sample - val_sample
         assert isinstance(val_sample, int)
         train_dataset, val_dataset = train_dataset.train_test_split(
-            test_size=val_sample, seed=get_seed(random_state), load_from_cache_file=False).values()
+            test_size=val_sample, seed=get_seed(random_state), load_from_cache_file=dataset_enable_cache).values()
 
     assert train_sample > 0
     train_dataset = sample_dataset(train_dataset, train_sample, random_state)
@@ -445,7 +446,8 @@ def preprocess_row(row):
             return {'image': [], 'conversations': []}
         return {'image': [image]}
 
-    dataset = dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['conversations'])
+    dataset = dataset.map(
+        preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['conversations'])
     return ConversationsPreprocessor(
         user_role='human', assistant_role='gpt', media_type='image', error_strategy='delete')(
             dataset)
@@ -490,7 +492,7 @@ def preprocess_row(row):
         else:
             return {'images': []}
 
-    return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['images'])
+    return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['images'])
 
 
 def get_mantis_dataset(dataset_id: str,
@@ -575,7 +577,7 @@ def preprocess_image(example):
             example['images'] = []
         return example
 
-    dataset = dataset.map(preprocess_image, load_from_cache_file=False).filter(lambda row: row['images'])
+    dataset = dataset.map(preprocess_image, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['images'])
     return ConversationsPreprocessor(
         user_role='user',
         assistant_role='assistant',
@@ -666,7 +668,7 @@ def preprocess(row):
             'query': np.random.choice(caption_prompt),
         }
 
-    return dataset.map(preprocess, load_from_cache_file=False)
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
@@ -717,11 +719,9 @@ def _preprocess_aishell1_dataset(dataset: HfDataset) -> HfDataset:
 
 
 def _preprocess_video_chatgpt(dataset: HfDataset) -> HfDataset:
-    from datasets.download.download_manager import DownloadManager
     url = 'https://modelscope.cn/datasets/huangjintao/VideoChatGPT/resolve/master/videos.zip'
-    with safe_ddp_context():
-        local_dir = DownloadManager().download_and_extract(url)
-    local_dir = os.path.join(str(local_dir), 'Test_Videos')
+    local_dir = MediaCache.download(url, 'video_chatgpt')
+    local_dir = os.path.join(local_dir, 'Test_Videos')
     # only `.mp4`
     mp4_set = [file[:-4] for file in os.listdir(local_dir) if file.endswith('mp4')]
     query = []
```
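
Apart from the cache flag, the only behavioral change in this file is the `_preprocess_video_chatgpt` hunk above, which replaces the hand-rolled `DownloadManager` + `safe_ddp_context` download with swift's `MediaCache` helper. A rough before/after sketch, under the assumption that `MediaCache.download(url, name)` downloads and extracts the archive into a named cache directory and returns its path (its implementation lives in `swift/llm/utils/media.py` and is not part of this diff):

```python
import os

from datasets.download.download_manager import DownloadManager
from swift.llm.utils.media import MediaCache  # imported as `.media` by the patched dataset.py

url = 'https://modelscope.cn/datasets/huangjintao/VideoChatGPT/resolve/master/videos.zip'

# Before: download and extract through `datasets`, then locate the videos.
old_dir = os.path.join(str(DownloadManager().download_and_extract(url)), 'Test_Videos')

# After: delegate to MediaCache, which keys the extracted archive by a dataset name.
new_dir = os.path.join(MediaCache.download(url, 'video_chatgpt'), 'Test_Videos')
```
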
```diff
@@ -794,7 +794,7 @@ def map_row(row):
         return response
 
     dataset = AlpacaPreprocessor()(dataset)
-    return dataset.map(map_row, load_from_cache_file=False)
+    return dataset.map(map_row, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
@@ -821,7 +821,7 @@ def map_row(row):
         title = match.group(1)
         return {'response': title}
 
-    return dataset.map(map_row, load_from_cache_file=False).filter(lambda row: row['response'])
+    return dataset.map(map_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['response'])
 
 
 register_dataset(
@@ -1002,7 +1002,8 @@ def reorganize_row(row):
             'history': history,
         }
 
-    return dataset.map(reorganize_row, load_from_cache_file=False).filter(lambda row: row['query'] is not None)
+    return dataset.map(
+        reorganize_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'] is not None)
 
 
 register_dataset(
@@ -1067,7 +1068,7 @@ def row_can_be_parsed(row):
             return False
 
     return dataset.filter(row_can_be_parsed).map(
-        reorganize_row, load_from_cache_file=False).filter(lambda row: row['query'])
+        reorganize_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'])
 
 
 register_dataset(
@@ -1137,7 +1138,8 @@ def preprocess_image(example):
         return example
 
     dataset = dataset.map(
-        preprocess_image, load_from_cache_file=False).filter(lambda example: example['images'] is not None)
+        preprocess_image,
+        load_from_cache_file=dataset_enable_cache).filter(lambda example: example['images'] is not None)
     processer = ConversationsPreprocessor(
         user_role='human', assistant_role='gpt', media_type='image', media_key='images', error_strategy='delete')
     return processer(dataset)
@@ -1182,8 +1184,8 @@ def preprocess(row):
             return {'response': '', 'image': None}
 
     return dataset.map(
-        preprocess,
-        load_from_cache_file=False).filter(lambda row: row.get('response')).rename_columns({'image': 'images'})
+        preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response')).rename_columns(
+            {'image': 'images'})
 
 
 def preprocess_refcoco_unofficial_caption(dataset):
@@ -1209,7 +1211,7 @@ def preprocess(row):
             res['response'] = ''
         return res
 
-    return dataset.map(preprocess, load_from_cache_file=False).filter(lambda row: row.get('response'))
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response'))
 
 
 register_dataset(
@@ -1254,7 +1256,7 @@ def preprocess(row):
             res['response'] = ''
         return res
 
-    return dataset.map(preprocess, load_from_cache_file=False).filter(lambda row: row.get('response'))
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response'))
 
 
 register_dataset(
@@ -1323,7 +1325,8 @@ def preprocess_image(example):
         return example
 
     dataset = dataset.map(
-        preprocess_image, load_from_cache_file=False).filter(lambda example: example['images'] is not None)
+        preprocess_image,
+        load_from_cache_file=dataset_enable_cache).filter(lambda example: example['images'] is not None)
     processer = ConversationsPreprocessor(
         user_role='human', assistant_role='gpt', media_type='image', media_key='images', error_strategy='delete')
     return processer(dataset)
```

```diff
@@ -1386,7 +1389,7 @@ def preprocess(row):
         else:
             return {'image': ''}
 
-    dataset = dataset.map(preprocess, load_from_cache_file=False).filter(lambda row: row['image'])
+    dataset = dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['image'])
     return ConversationsPreprocessor(
         user_role='human', assistant_role='gpt', media_type='image', error_strategy='delete')(
             dataset)
@@ -1412,7 +1415,7 @@ def reorganize_row(row):
             'rejected_response': row['answer_en'],
         }
 
-    return dataset.map(reorganize_row, load_from_cache_file=False)
+    return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache)
 
 
 def process_ultrafeedback_kto(dataset: HfDataset):
@@ -1424,7 +1427,7 @@ def reorganize_row(row):
             'label': row['label'],
         }
 
-    return dataset.map(reorganize_row, load_from_cache_file=False)
+    return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
@@ -1466,7 +1469,8 @@ def preprocess_row(row):
             'response': output,
         }
 
-    return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['query'] and row['response'])
+    return dataset.map(
+        preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'] and row['response'])
 
 
 register_dataset(
@@ -1495,7 +1499,7 @@ def preprocess_row(row):
             'response': response,
         }
 
-    return dataset.map(preprocess_row, load_from_cache_file=False)
+    return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
@@ -1537,7 +1541,7 @@ def preprocess(row):
             'query': query,
         }
 
-    return dataset.map(preprocess, load_from_cache_file=False).rename_column('image', 'images')
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images')
 
 
 register_dataset(
@@ -1560,7 +1564,7 @@ def preprocess(row):
             'query': query,
         }
 
-    return dataset.map(preprocess, load_from_cache_file=False).rename_column('image', 'images')
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images')
 
 
 register_dataset(
@@ -1584,7 +1588,7 @@ def preprocess(row):
             'query': query,
         }
 
-    return dataset.map(preprocess, load_from_cache_file=False).rename_column('image', 'images')
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images')
 
 
 register_dataset(
@@ -1606,7 +1610,8 @@ def preprocess_row(row):
         return {'query': query, 'response': f'{solution}\nSo the final answer is:{response}'}
 
     return dataset.map(
-        preprocess_row, load_from_cache_file=False).filter(lambda row: row['image']).rename_columns({'image': 'images'})
+        preprocess_row,
+        load_from_cache_file=dataset_enable_cache).filter(lambda row: row['image']).rename_columns({'image': 'images'})
 
 
 register_dataset(
```

```diff
@@ -1660,7 +1665,7 @@ def preprocess_row(row):
 
         return {'images': images, 'response': response, 'objects': json.dumps(objects or [], ensure_ascii=False)}
 
-    return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['objects'])
+    return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['objects'])
 
 
 register_dataset(
@@ -1687,7 +1692,7 @@ def preprocess_row(row):
         else:
             return {'query': '', 'response': '', 'images': ''}
 
-    return dataset.map(preprocess_row, load_from_cache_file=False).filter(lambda row: row['query'])
+    return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'])
 
 
 register_dataset(
@@ -1720,7 +1725,7 @@ def preprocess_row(row):
         return {'messages': rounds}
 
     dataset = dataset.map(
-        preprocess_row, load_from_cache_file=False).map(
+        preprocess_row, load_from_cache_file=dataset_enable_cache).map(
             ConversationsPreprocessor(
                 user_role='user',
                 assistant_role='assistant',
@@ -1730,7 +1735,7 @@ def preprocess_row(row):
                 media_key='images',
                 media_type='image',
             ).preprocess,
-            load_from_cache_file=False)
+            load_from_cache_file=dataset_enable_cache)
     return dataset
 
 
@@ -1787,8 +1792,8 @@ def preprocess(row):
         }
 
     return dataset.map(
-        preprocess,
-        load_from_cache_file=False).filter(lambda r: r['source'] != 'toxic-dpo-v0.2' and r['query'] is not None)
+        preprocess, load_from_cache_file=dataset_enable_cache).filter(
+            lambda r: r['source'] != 'toxic-dpo-v0.2' and r['query'] is not None)
 
 
 register_dataset(
@@ -1814,7 +1819,7 @@ def preprocess(row):
             'response': response,
         }
 
-    return dataset.map(preprocess, load_from_cache_file=False)
+    return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
@@ -2116,7 +2121,7 @@ def reorganize_row(row):
             'response': convs[-1]['value']
         }
 
-    return dataset.map(reorganize_row, load_from_cache_file=False)
+    return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache)
 
 
 register_dataset(
```
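
One usage note follows from the first hunk: `dataset_enable_cache` is evaluated once, at module import time, so the variable must already be in the environment before swift is imported (or be supplied on the command line, e.g. `DATASET_ENABLE_CACHE=1 swift sft ...`). A small check, assuming the patched version of swift is installed:

```python
import os

# Set the flag before swift.llm.utils.dataset is imported; the module reads it
# at import time via `strtobool(os.environ.get('DATASET_ENABLE_CACHE', 'False'))`.
os.environ['DATASET_ENABLE_CACHE'] = '1'

from swift.llm.utils.dataset import dataset_enable_cache

print(dataset_enable_cache)  # 1 -- strtobool returns an int: 1 for true, 0 for false
```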
