     DefaultFlowCallback,
     PrinterCallback,
     ProgressCallback,
+    SPGradSyncCallback,
     TrainerCallback,
     TrainerControl,
     TrainerState,

@@ -444,9 +445,8 @@ def _save_ckpt_func(state_dict, path, signal_path=None):
             ), "should_save_sharding_stage1_model should be True when using zero cost checkpoint"
             assert (
                 ShardingOption.FULL_SHARD not in self.args.sharding
-            ), "FULL_SHARD is not supported when using zero cost checkpoint"
-            assert not self.args.save_tokenizer, "save_tokenizer is not supported when using zero cost checkpoint"
-            assert not self.args.save_rng_states, "save_rng_states is not supported when using zero cost checkpoint"
+            ), "FULL_SHARD is not supported when using flash save mode"
+            assert not self.args.save_tokenizer, "save_tokenizer is not supported when using flash save mode"

         # init attributes for zero cost checkpoint mode
         self.zcc_manager = None

@@ -2021,34 +2021,18 @@ def _load_rng_state(self, checkpoint):
         if checkpoint is None:
             return

-        # if use distributed training
-        if self.args.world_size > 1:
-            process_index = self.args.process_index
-            rng_file_list = [None for x in range(self.args.world_size)]
-            if self.args.should_save:
-                rng_file = os.path.join(checkpoint, f"rng_state_{self.args.world_size}.pth")
-                if os.path.isfile(rng_file):
-                    rng_file_list = paddle.load(rng_file, return_numpy=True)
-            paddle.distributed.broadcast_object_list(rng_file_list, src=0)
-            # if rng_file_list still empty, not log rng state.
-            if rng_file_list[0] is None:
-                logger.info(
-                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
-                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
-                )
-                return
-            else:
-                checkpoint_rng_state = rng_file_list[process_index]
-        else:
-            rng_file = os.path.join(checkpoint, "rng_state.pth")
-            if not os.path.isfile(rng_file):
-                logger.info(
-                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
-                    "fashion, reproducibility is not guaranteed."
-                )
-                return
+        rng_file = os.path.join(checkpoint, f"rng_state_{dist.get_rank()}.pth")
+        if not os.path.isfile(rng_file):
+            logger.info(
+                "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                "fashion, reproducibility is not guaranteed."
+            )
+            return

-            checkpoint_rng_state = paddle.load(rng_file, return_numpy=True)
+        checkpoint_rng_state = paddle.load(rng_file, return_numpy=True)
+        if checkpoint_rng_state.get("world_size", None) != self.args.world_size:
+            logger.warn("Cannot load rng states when changing world size of training job.")
+            return

         random.setstate(checkpoint_rng_state["python"])
         np.random.set_state(checkpoint_rng_state["numpy"])

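The new load path above reads a per-rank file, rng_state_<rank>.pth, and refuses to restore RNG streams if the recorded world_size no longer matches the current job (see the matching save-side hunk further down). A minimal standalone sketch of that convention, assuming an initialized distributed environment; the load_rng_checkpoint helper name is illustrative and not part of this PR, and device generator states are omitted for brevity:

    import os
    import random

    import numpy as np
    import paddle
    import paddle.distributed as dist


    def load_rng_checkpoint(checkpoint_dir, expected_world_size):
        # Illustrative helper (not in the PR): restore this rank's RNG state
        # from the per-rank file written at save time.
        rng_file = os.path.join(checkpoint_dir, f"rng_state_{dist.get_rank()}.pth")
        if not os.path.isfile(rng_file):
            # No per-rank file: resume without restoring RNG state.
            return False

        state = paddle.load(rng_file, return_numpy=True)
        if state.get("world_size", None) != expected_world_size:
            # RNG streams are per-rank, so a resized job cannot reuse them safely.
            return False

        random.setstate(state["python"])
        np.random.set_state(state["numpy"])
        # The trainer additionally restores the framework/device generators
        # ("cuda", "cpu", and the hybrid-parallel tracker) in the same way.
        return True
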
@@ -2210,11 +2194,6 @@ def _wrap_model(self, model, training=True):
         else:
             model, self.optimizer = decorated

-        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
-            register_sequence_parallel_allreduce_hooks(
-                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
-            )
-
         if self.args.world_size == 1:
             if self.args.amp_master_grad:
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)

@@ -2403,6 +2382,17 @@ def get_expected_keys(inputs, keys):
         ):
             self.optimizer._set_broadcast_overlap(True, model)

+        # use callback for sp grad sync in case of unexpected behaviour (except sharding stage 2&3)
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            if ShardingOption.SHARD_GRAD_OP in self.args.sharding or ShardingOption.FULL_SHARD in self.args.sharding:
+                register_sequence_parallel_allreduce_hooks(
+                    unwrap_model(model),
+                    self.args.gradient_accumulation_steps,
+                    self.args.fuse_sequence_parallel_allreduce,
+                )
+            else:
+                self.add_callback(SPGradSyncCallback(model._layers))
+
         return model

     def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:

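Outside of sharding stage 2/3, gradient synchronization for sequence-parallel parameters now goes through a trainer callback (the SPGradSyncCallback imported in the first hunk) rather than backward hooks registered on the model. As a rough sketch of what such a callback can look like, assuming sequence-parallel parameters are marked with a sequence_parallel attribute and that the callback API exposes an on_optimizer_begin hook; this is a stand-in, not the PR's implementation:

    import paddle.distributed as dist
    from paddle.distributed import fleet

    from paddlenlp.trainer import TrainerCallback


    class NaiveSPGradSyncCallback(TrainerCallback):
        # Illustrative stand-in for SPGradSyncCallback.
        def __init__(self, layers):
            hcg = fleet.get_hybrid_communicate_group()
            self._mp_group = hcg.get_model_parallel_group()
            # Assumes sequence-parallel parameters carry a `sequence_parallel` marker.
            self._sp_params = [
                p for p in layers.parameters() if getattr(p, "sequence_parallel", False)
            ]

        def on_optimizer_begin(self, args, state, control, **kwargs):
            # Sequence-parallel activations are split along the sequence axis, so the
            # matching weight gradients must be summed over the tensor-parallel group
            # before the optimizer step (a master-grad setup would reduce main_grad instead).
            for param in self._sp_params:
                if param.grad is not None:
                    dist.all_reduce(param.grad, group=self._mp_group)

In the hunk above, the callback is handed the wrapped model's inner layers (model._layers) so it sees the raw parameters rather than the data-parallel wrapper.
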
@@ -2739,28 +2729,24 @@ def _save_checkpoint(self, model, metrics=None):
         if self.args.should_save:
             self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

-        # Save RNG state in non-distributed training
-        rng_states = {
-            "python": random.getstate(),
-            "numpy": np.random.get_state(),
-            "cuda": paddle.get_rng_state(),
-            "cpu": paddle.framework.core.default_cpu_generator().get_state(),
-        }
-        if self.args.use_hybrid_parallel:
-            rng_states[
-                "hybrid_parallel_rng_state_tracker"
-            ] = fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()
+        if self.args.save_rng_states:
+            # Save RNG state in non-distributed training
+            rng_states = {
+                "python": random.getstate(),
+                "numpy": np.random.get_state(),
+                "cuda": paddle.get_rng_state(),
+                "cpu": paddle.framework.core.default_cpu_generator().get_state(),
+                "world_size": self.args.world_size,
+            }
+            if self.args.use_hybrid_parallel:
+                rng_states[
+                    "hybrid_parallel_rng_state_tracker"
+                ] = fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()

         if self.args.save_rng_states:
-            if self.args.world_size > 1:
-                rng_states_list = []
-                paddle.distributed.all_gather_object(rng_states_list, rng_states)
-                if self.args.should_save:
-                    os.makedirs(output_dir, exist_ok=True)
-                    paddle.save(rng_states_list, os.path.join(output_dir, f"rng_state_{self.args.world_size}.pth"))
-            else:
-                os.makedirs(output_dir, exist_ok=True)
-                paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
+            rng_state_file = os.path.join(output_dir, f"rng_state_{dist.get_rank()}.pth")
+            os.makedirs(output_dir, exist_ok=True)
+            paddle.save(rng_states, rng_state_file)

         # only save model state dict, ignore optimizer and scheduler
         if not self.args.ignore_save_lr_and_optim:

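On the save side, every rank now writes its own rng_state_<rank>.pth instead of rank 0 gathering all states into a single world-size-keyed file, and the stored world_size is what lets the loader reject a resized resume. A rough sketch of that layout under the same assumptions as the load-side sketch; the save_rng_checkpoint helper name is illustrative:

    import os
    import random

    import numpy as np
    import paddle
    import paddle.distributed as dist


    def save_rng_checkpoint(output_dir, world_size):
        # Illustrative helper (not in the PR): persist this rank's RNG state.
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cuda": paddle.get_rng_state(),
            "cpu": paddle.framework.core.default_cpu_generator().get_state(),
            "world_size": world_size,  # lets the loader detect a resized job
        }
        os.makedirs(output_dir, exist_ok=True)
        paddle.save(rng_states, os.path.join(output_dir, f"rng_state_{dist.get_rank()}.pth"))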