@@ -12,16 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import os
 import sys
 from functools import partial
 
 import paddle
 
+from paddleformers.datasets.data_utils import estimate_training
 from paddleformers.datasets.finetuning import collate_fn
 from paddleformers.datasets.finetuning import create_dataset as create_dataset_sft
 from paddleformers.peft import LoRAConfig, LoRAModel
-from paddleformers.trainer import PdArgumentParser, get_last_checkpoint, set_seed
+from paddleformers.trainer import (
+    IntervalStrategy,
+    PdArgumentParser,
+    get_last_checkpoint,
+    set_seed,
+)
 from paddleformers.transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -155,7 +162,7 @@ def main():
     if model_args.fuse_attention_ffn is not None:
         model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
     model_config.pp_seg_method = training_args.pp_seg_method
-    model_config.seq_length = data_args.max_length
+    model_config.seq_length = training_args.max_seq_len
     model_config.max_sequence_length = training_args.max_seq_len
     model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
     logger.info(f"Final model config: {model_config}")
@@ -262,6 +269,47 @@ def neft_post_hook(module, input, output):
         model_args=model_args,
         max_seq_len=training_args.max_seq_len + model_config.num_nextn_predict_layers,
     )
+
+    if training_args.max_steps == -1:
+        if data_args.mix_strategy == "random":
+            raise ValueError(
+                "When using 'random' mix_strategy, max_steps must be explicitly set (cannot be -1). "
+                "Random mixing requires a fixed number of training steps to properly sample data."
+            )
+        if paddle.distributed.get_rank() == 0:
+            training_args.max_steps = estimate_training(train_dataset, data_args, training_args, model_args)
+            # estimate_training consumed the dataset iterator; free it and rebuild a fresh one.
+            del train_dataset
+            gc.collect()
+            train_dataset = create_dataset_sft(
+                task_group=data_args.train_dataset_path,
+                task_group_prob=data_args.train_dataset_prob,
+                sub_dataset_type=data_args.train_dataset_type,
+                **dataset_config,
+            )
+
+        if paddle.distributed.get_world_size() > 1:
+            paddle.distributed.barrier()
+            max_steps = paddle.to_tensor([training_args.max_steps])
+            paddle.distributed.broadcast(max_steps, src=0)
+            training_args.max_steps = int(max_steps.item())
+        if training_args.max_steps <= 0:
+            raise ValueError(f"Invalid max_steps: {training_args.max_steps}. Please check your dataset.")
+
+        logger.info(f"Re-setting training_args.max_steps to {training_args.max_steps}.")
+        # Create the learning_rate scheduler and optimizer
+        if training_args.decay_steps is None:
+            training_args.decay_steps = training_args.max_steps
+
+        if training_args.save_strategy == IntervalStrategy.EPOCH:
+            training_args.save_strategy = IntervalStrategy.STEPS
+            training_args.save_steps = int(training_args.max_steps / training_args.num_train_epochs)
+        if training_args.evaluation_strategy == IntervalStrategy.EPOCH:
+            training_args.evaluation_strategy = IntervalStrategy.STEPS
+            training_args.eval_steps = int(training_args.max_steps / training_args.num_train_epochs)
+        if training_args.logging_strategy == IntervalStrategy.EPOCH:
+            training_args.logging_strategy = IntervalStrategy.STEPS
+            training_args.logging_steps = int(training_args.max_steps / training_args.num_train_epochs)
+
     trainer = SFTTrainer(
         model=model,
         args=training_args,
|