Skip to content

Commit 2e02ee6

Browse files
BC-A and Jintao-Huang authored
Resume sample (#3460)
--------- Co-authored-by: Jintao Huang <[email protected]>
1 parent 77924b3 commit 2e02ee6

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

swift/llm/argument/sampling_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class SamplingArguments(BaseArguments):
2525
sampler_engine: Literal['pt', 'lmdeploy', 'vllm', 'no', 'client'] = 'pt'
2626
output_dir: str = 'sample_output'
2727
output_file: Optional[str] = None
28+
resume: bool = False
2829
override_exist_file: bool = False
2930
num_return_sequences: int = 64
3031
num_sampling_per_gpu_batch_size: int = 1

swift/llm/sampling/sampling.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import time
55
from typing import List, Union
66

7+
import json
8+
79
from swift.llm import SamplingArguments, SwiftPipeline, load_dataset
810
from swift.utils import get_logger
911

@@ -49,28 +51,53 @@ def _get_dataset(self):
4951
def run(self):
    """Sample the dataset batch by batch and write the results to the output file.

    Files used, all inside ``self.args.output_dir``:

    * ``<output_file>``          -- final result; an existing one is kept under
      a ``.<unix-timestamp>`` suffix before being replaced.
    * ``<output_file>.tmp``      -- file the samples are written to while running.
    * ``<output_file>.resume``   -- copy of the ``.tmp`` file taken after every
      finished batch.
    * ``ckpt_state.json``        -- ``{"index": <last finished batch index>}``.

    Returns early, doing nothing, when the output file already exists and
    ``self.args.override_exist_file`` is false.

    When ``self.args.resume`` is true and both the ``.resume`` snapshot and
    ``ckpt_state.json`` are present, batches up to and including the recorded
    index are skipped and new samples are appended to the snapshot. In every
    other case sampling starts from scratch and leftover ``.tmp`` / ``.resume``
    / ``ckpt_state.json`` files are removed first, so stale state cannot leak
    into the new output.

    Side effect: ``self.args.num_sampling_per_gpu_batches`` is set to the
    number of full batches in the dataset when it is ``None`` or larger than
    that number.
    """
    args = self.args
    os.makedirs(args.output_dir, exist_ok=True)
    iter_file = os.path.join(args.output_dir, args.output_file)
    resume_file = os.path.join(args.output_dir, args.output_file + '.resume')
    tmp_file = os.path.join(args.output_dir, args.output_file + '.tmp')
    ckpt_state_file = os.path.join(args.output_dir, 'ckpt_state.json')
    if os.path.exists(iter_file) and not args.override_exist_file:
        return

    index_resume = -1
    # Resuming needs BOTH the snapshot and its index; with only one of them the
    # recorded position and the saved samples cannot be matched up.
    can_resume = (args.resume and os.path.exists(resume_file) and os.path.exists(ckpt_state_file))
    if can_resume:
        write_mode = 'a'
        shutil.copyfile(resume_file, tmp_file)
        with open(ckpt_state_file, 'r') as ckpt_state:
            index_resume = json.load(ckpt_state).get('index', -1)
        logger.info(f'Loaded index_resume: {index_resume}')
    else:
        write_mode = 'w'
        # Fresh start: drop whatever an earlier (finished or crashed) run left.
        for stale_file in (tmp_file, resume_file, ckpt_state_file):
            if os.path.exists(stale_file):
                os.remove(stale_file)

    dataset = self._get_dataset()
    dataset_len = len(dataset)
    batch_size = args.num_sampling_per_gpu_batch_size
    total_iters = int(dataset_len // batch_size)

    if args.num_sampling_per_gpu_batches is None or args.num_sampling_per_gpu_batches > total_iters:
        args.num_sampling_per_gpu_batches = total_iters

    with open(tmp_file, write_mode) as f:
        for _index in range(args.num_sampling_per_gpu_batches):
            if _index <= index_resume:
                # Already present in the snapshot that was copied into tmp_file.
                continue
            logger.info(f' Sampling index:{_index}')
            slices = dataset[batch_size * _index:batch_size * (_index + 1)]
            slices = self.sampler.truncate_input(slices)
            generated = self.sampler.do_sample(slices)
            f.writelines(generated)
            # Flush before copying so the snapshot contains this batch.
            f.flush()
            shutil.copy(tmp_file, resume_file)
            with open(ckpt_state_file, 'w') as ckpt_state:
                json.dump({'index': _index}, ckpt_state)

    if os.path.exists(iter_file):
        shutil.move(iter_file, iter_file + '.' + str(int(time.time())))
    # tmp_file always exists here and holds the complete output (snapshot plus
    # newly appended batches), even when no batch was sampled in this call.
    shutil.move(tmp_file, iter_file)
    # The run is complete: remove the checkpoint files so that a later run does
    # not mistake them for unfinished work.
    for finished_file in (resume_file, ckpt_state_file):
        if os.path.exists(finished_file):
            os.remove(finished_file)
    logger.info(f'Sample file {iter_file} generated.')
75102

76103

0 commit comments

Comments
 (0)