|
9 | 9 | from transformers.trainer_utils import IntervalStrategy, has_length |
10 | 10 |
|
11 | 11 | from swift.trainers import TrainingArguments |
| 12 | +from swift.utils import is_pai_training_job |
12 | 13 |
|
13 | 14 |
|
14 | 15 | class ProgressCallbackNew(ProgressCallback): |
@@ -47,7 +48,7 @@ def on_log(self, |
47 | 48 | for k, v in logs.items(): |
48 | 49 | if isinstance(v, float): |
49 | 50 | logs[k] = round(logs[k], 8) |
50 | | - if state.is_local_process_zero and self.training_bar is not None: |
| 51 | + if not is_pai_training_job() and state.is_local_process_zero: |
51 | 52 | jsonl_path = os.path.join(args.output_dir, 'logging.jsonl') |
52 | 53 | with open(jsonl_path, 'a', encoding='utf-8') as f: |
53 | 54 | f.write(json.dumps(logs) + '\n') |
@@ -77,7 +78,7 @@ def on_log(self, args, state, control, logs=None, **kwargs): |
77 | 78 | for k, v in logs.items(): |
78 | 79 | if isinstance(v, float): |
79 | 80 | logs[k] = round(logs[k], 8) |
80 | | - if state.is_local_process_zero: |
| 81 | + if not is_pai_training_job() and state.is_local_process_zero: |
81 | 82 | jsonl_path = os.path.join(args.output_dir, 'logging.jsonl') |
82 | 83 | with open(jsonl_path, 'a', encoding='utf-8') as f: |
83 | 84 | f.write(json.dumps(logs) + '\n') |
|
0 commit comments