|
4 | 4 | import time |
5 | 5 | from typing import List, Union |
6 | 6 |
|
| 7 | +import json |
| 8 | + |
7 | 9 | from swift.llm import SamplingArguments, SwiftPipeline, load_dataset |
8 | 10 | from swift.utils import get_logger |
9 | 11 |
|
@@ -49,28 +51,53 @@ def _get_dataset(self): |
def run(self):
    """Sample the dataset batch by batch and publish the result file.

    Generated lines are accumulated in ``<output_file>.tmp``. After every
    batch the accumulated output is snapshotted to ``<output_file>.resume``
    and the index of the finished batch is stored in ``ckpt_state.json``,
    so an interrupted run can continue when ``self.args.resume`` is set.
    On success the accumulated output is moved to ``<output_file>`` (an
    existing file is first renamed with a timestamp suffix) and the
    resume artefacts are deleted.

    Returns early without doing anything when the output file already
    exists and ``self.args.override_exist_file`` is false.
    """
    os.makedirs(self.args.output_dir, exist_ok=True)
    iter_file = os.path.join(self.args.output_dir, self.args.output_file)
    resume_file = os.path.join(self.args.output_dir, self.args.output_file + '.resume')
    tmp_file = os.path.join(self.args.output_dir, self.args.output_file + '.tmp')
    ckpt_state_file = os.path.join(self.args.output_dir, 'ckpt_state.json')
    if os.path.exists(iter_file) and not self.args.override_exist_file:
        return

    index_resume = -1
    write_mode = 'w'  # 'w' truncates any stale tmp file, no explicit remove needed
    # Only resume when BOTH the snapshot and its checkpoint exist. With just
    # one of them we would either append to output of unknown progress or
    # skip batches whose output is gone, so fall back to a fresh start.
    if self.args.resume and os.path.exists(resume_file) and os.path.exists(ckpt_state_file):
        shutil.copyfile(resume_file, tmp_file)
        write_mode = 'a'
        with open(ckpt_state_file, 'r') as ckpt_state:
            data = json.load(ckpt_state)
        index_resume = data.get('index', -1)
        logger.info(f'Loaded index_resume: {index_resume}')

    dataset = self._get_dataset()
    dataset_len = len(dataset)
    batch_size = self.args.num_sampling_per_gpu_batch_size
    total_iters = int(dataset_len // batch_size)

    if self.args.num_sampling_per_gpu_batches is None or self.args.num_sampling_per_gpu_batches > total_iters:
        self.args.num_sampling_per_gpu_batches = total_iters

    with open(tmp_file, write_mode) as f:
        for _index in range(self.args.num_sampling_per_gpu_batches):
            if _index <= index_resume:
                # Already present in the restored snapshot.
                continue
            logger.info(f' Sampling index:{_index}')
            slices = dataset[batch_size * _index:batch_size * (_index + 1)]
            slices = self.sampler.truncate_input(slices)
            generated = self.sampler.do_sample(slices)
            f.writelines(generated)
            # Flush so the snapshot below contains the batch just written.
            f.flush()
            shutil.copy(tmp_file, resume_file)
            # NOTE(review): a crash between the copy above and this write
            # leaves the snapshot one batch ahead of the checkpoint, so that
            # batch would be sampled again on resume.
            with open(ckpt_state_file, 'w') as ckpt_state:
                json.dump({'index': _index}, ckpt_state)

    if os.path.exists(iter_file):
        shutil.move(iter_file, iter_file + '.' + str(int(time.time())))
    # Publish the tmp file: it always exists (even when no batch ran), whereas
    # the resume snapshot is only created after the first completed batch.
    shutil.move(tmp_file, iter_file)
    # Drop resume artefacts so a later run cannot pick up a stale checkpoint.
    for stale_path in (resume_file, ckpt_state_file):
        if os.path.exists(stale_path):
            os.remove(stale_path)
    logger.info(f'Sample file {iter_file} generated.')
75 | 102 |
|
76 | 103 |
|
|
0 commit comments