
Commit ad14dc4

[Trainer] Support skip data intervals (#8989)
* support skip data intervals
* add debug_data arg
* fix loss compute
* remove callback while skip data
* remove debug data
* add callback_handler
* remove debug_data
* fix conflict
1 parent 353d278 commit ad14dc4

File tree

4 files changed: +116 -18 lines changed

paddlenlp/trainer/argparser.py

Lines changed: 18 additions & 2 deletions
@@ -24,7 +24,17 @@
 from enum import Enum
 from inspect import isclass
 from pathlib import Path
-from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    NewType,
+    Optional,
+    Tuple,
+    Union,
+    get_args,
+    get_type_hints,
+)

 DataClass = NewType("DataClass", Any)
 DataClassType = NewType("DataClassType", Any)
@@ -129,7 +139,13 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
         # This is the value that will get picked if we do --field_name (without value)
         kwargs["const"] = True
     elif isclass(origin_type) and issubclass(origin_type, list):
-        kwargs["type"] = field.type.__args__[0]
+        # support one-dimensional and two-dimensional lists
+        if hasattr(get_args(field.type)[0], "__args__"):
+            kwargs["type"] = field.type.__args__[0].__args__[0]
+            kwargs["action"] = "append"
+        else:
+            kwargs["type"] = field.type.__args__[0]
+
         kwargs["nargs"] = "+"
         if field.default_factory is not dataclasses.MISSING:
             kwargs["default"] = field.default_factory()

paddlenlp/trainer/trainer.py

Lines changed: 77 additions & 16 deletions
@@ -127,6 +127,7 @@
     PREFIX_CHECKPOINT_DIR,
     EvalLoopOutput,
     EvalPrediction,
+    IntervalStrategy,
     IterableDatasetShard,
     OptimizerNames,
     PredictionOutput,
@@ -139,6 +140,7 @@
     get_scheduler,
     has_length,
     set_seed,
+    should_skip_data,
     speed_metrics,
 )
 from .training_args import TrainingArguments
@@ -287,9 +289,16 @@ def __init__(

         # Seed must be set before instantiating the model when using model
         set_seed(seed=self.args.seed)
-
+        self._skip_global_steps = 0  # total skip global steps
+        self._skip_steps_since_last_logged = 0  # skip steps since last logged
         if model is None:
-            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
+            logger.warning("Model is None.")
+            self.model = None
+            self.train_dataset = train_dataset
+            self.tokenizer = tokenizer
+            default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
+            self.data_collator = data_collator if data_collator is not None else default_collator
+            return

         if self.args.to_static:
             model = paddle.jit.to_static(model)
@@ -945,6 +954,7 @@ def _inner_training_loop(
             step_control = 0  # used in loop control, reset to 0 after every step
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

+            step = -1
             for step, inputs in enumerate(epoch_iterator):
                 if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
                     inputs = split_inputs_sequence_dim(inputs)
@@ -981,6 +991,44 @@
                     steps_trained_progress_bar.close()
                     steps_trained_progress_bar = None

+                if should_skip_data(self.state.global_step, self.args.skip_data_intervals):
+                    # skip this step
+
+                    if (step_control + 1) % self.args.gradient_accumulation_steps == 0 or (
+                        # last step in epoch but step is always smaller than gradient_accumulation_steps
+                        steps_in_epoch <= args.gradient_accumulation_steps
+                        and (step + 1) == steps_in_epoch
+                    ):
+                        # update current global step and skip step
+                        self.state.global_step += 1
+                        self._skip_global_steps += 1
+                        self._skip_steps_since_last_logged += 1
+
+                        self.state.epoch = epoch + (step + 1) / steps_in_epoch
+
+                        if self.state.global_step == 1 and self.args.logging_first_step:
+                            self.control.should_log = True
+                        if (
+                            self.args.logging_strategy == IntervalStrategy.STEPS
+                            and self.state.global_step % self.args.logging_steps == 0
+                        ):
+                            self.control.should_log = True
+
+                        self.control.should_evaluate = False
+                        self.control.should_save = False
+
+                        # log loss and memory usage
+                        self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
+                        self._print_timer()
+                        step_control = 0
+                    else:
+                        step_control += 1
+                    if self.state.global_step >= self.state.max_steps:
+                        break
+
+                    self.timers and self.timers("read-data").start()
+                    continue
+
                 if step_control % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                     self.timers and self.timers("forward-backward").start()
@@ -1202,7 +1250,13 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
             )

         self._total_loss_scalar += tr_loss.item()
-        train_loss = self._total_loss_scalar / self.state.global_step
+
+        # In case all steps were skipped, the total loss is set to 0.
+        if self.state.global_step == self._skip_global_steps:
+            logger.info("All steps were skipped, the total loss is set to 0.")
+            train_loss = 0.0
+        else:
+            train_loss = self._total_loss_scalar / (self.state.global_step - self._skip_global_steps)

         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)

@@ -1321,15 +1375,20 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
         if self.control.should_log:

             logs: Dict[str, float] = {}
-
+            num_steps = self.state.global_step - self._globalstep_last_logged - self._skip_steps_since_last_logged
+            self._skip_steps_since_last_logged = 0
             # all_gather + mean() to get average loss over all processes
             avg_loss = self._nested_gather(tr_loss).mean()
             tr_loss_scalar = self._get_item_from_loss(avg_loss)

             # reset tr_loss to zero
             tr_loss.subtract_(tr_loss)
+            # set loss to zero if all steps are skipped since last log
+            if num_steps == 0:
+                logs["loss"] = 0.0
+            else:
+                logs["loss"] = round(tr_loss_scalar / num_steps, 8)

-            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
             logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)
             if in_auto_parallel_align_mode():
@@ -1352,7 +1411,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             total_train_batch_size = (
                 self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
             )
-            num_steps = self.state.global_step - self._globalstep_last_logged
+
             seq_length = None
             model_flops = None
             if getattr(self, "is_pretraining", False) and hasattr(self.model, "config"):
@@ -1362,16 +1421,18 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
                 except NotImplementedError:
                     model_flops = None

-            logs.update(
-                speed_metrics(
-                    "interval",
-                    self._globalstep_last_start_time,
-                    num_samples=total_train_batch_size * num_steps,
-                    num_steps=num_steps,
-                    seq_length=seq_length,
-                    model_flops=model_flops,
+            # Do not log speed metrics if all steps are skipped since last log.
+            if num_steps > 0:
+                logs.update(
+                    speed_metrics(
+                        "interval",
+                        self._globalstep_last_start_time,
+                        num_samples=total_train_batch_size * num_steps,
+                        num_steps=num_steps,
+                        seq_length=seq_length,
+                        model_flops=model_flops,
+                    )
                 )
-            )

             self._total_loss_scalar += tr_loss_scalar
             self._globalstep_last_logged = self.state.global_step
@@ -3255,7 +3316,7 @@ def _set_signature_columns_if_needed(self):
             self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

     def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
-        if not self.args.remove_unused_columns:
+        if not self.args.remove_unused_columns or self.model is None:
             return dataset
         if self._signature_columns is None:
             # Inspect model forward signature to keep only the arguments it accepts.
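
Note on the loss bookkeeping above: skipped steps are excluded both from the end-of-training average and from the per-interval logged loss. A small illustrative sketch of the arithmetic follows; the helper names are hypothetical, since the real logic lives inline in Trainer.

# Illustrative arithmetic for the adjusted loss averaging above (not Trainer code).
def average_train_loss(total_loss_scalar: float, global_step: int, skip_global_steps: int) -> float:
    """Average loss over the steps that actually trained; 0.0 if every step was skipped."""
    trained_steps = global_step - skip_global_steps
    if trained_steps == 0:
        return 0.0
    return total_loss_scalar / trained_steps


def interval_loss(tr_loss_scalar: float, global_step: int, last_logged: int, skipped_since_log: int) -> float:
    """Per-log-interval loss, excluding steps skipped since the last log."""
    num_steps = global_step - last_logged - skipped_since_log
    if num_steps == 0:
        return 0.0
    return round(tr_loss_scalar / num_steps, 8)


# e.g. 100 global steps, 40 of them skipped, accumulated loss 75.0 -> 1.25 per trained step
print(average_train_loss(75.0, 100, 40))  # 1.25
print(interval_loss(7.5, 100, 90, 5))     # 1.5 over the 5 trained steps since the last log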

paddlenlp/trainer/trainer_utils.py

Lines changed: 17 additions & 0 deletions
@@ -1105,3 +1105,20 @@ def set_hyrbid_parallel_seed(basic_seed, dataset_rank, tp_rank, pp_rank=0):
         tracker.add("global_seed", global_seed)
     if "local_seed" not in tracker.states_ and local_seed not in tracker.seeds_:
         tracker.add("local_seed", local_seed)
+
+
+def should_skip_data(global_step, skip_data_intervals):
+    """Whether to skip the data for the current step."""
+
+    if skip_data_intervals is None:
+        return False
+    skip_flag = False
+    for interval in skip_data_intervals:
+        if len(interval) != 2 or interval[0] > interval[1] or interval[0] <= 0:
+            raise ValueError(f"Please check your skip interval {interval}")
+        start_global_step, end_global_step = interval[0], interval[1]
+        # start_global_step and end_global_step start from 1, while global_step starts from 0
+        if start_global_step <= global_step + 1 <= end_global_step:
+            skip_flag = True
+            break
+    return skip_flag
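
Note: intervals are interpreted as 1-based, inclusive ranges over the global step, while the global_step passed in is 0-based. A quick usage sketch, assuming the helper is imported from paddlenlp.trainer.trainer_utils:

# Hypothetical usage of should_skip_data; intervals are 1-based and inclusive,
# the global_step argument is 0-based (as in the training loop above).
from paddlenlp.trainer.trainer_utils import should_skip_data

intervals = [[3, 5], [10, 12]]

for global_step in range(15):
    if should_skip_data(global_step, intervals):
        # prints for global_step 2..4 and 9..11 (i.e. 1-based steps 3-5 and 10-12)
        print(f"skip step {global_step}")

should_skip_data(0, None)        # False: no intervals configured
# should_skip_data(0, [[5, 3]])  # would raise ValueError: start > end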

paddlenlp/trainer/training_args.py

Lines changed: 4 additions & 0 deletions
@@ -867,6 +867,10 @@ class TrainingArguments:
     release_grads: Optional[bool] = field(
         default=False, metadata={"help": "Whether to release gradients during training. Default is `False`."}
     )
+    skip_data_intervals: Optional[List[List[int]]] = field(
+        default=None,
+        metadata={"help": "The intervals to skip; pass the start global step and end global step of each interval."},
+    )

     def __post_init__(self):
         env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
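
Note on usage: each occurrence of --skip_data_intervals on the command line supplies one [start, end] pair (via the argparser change above), or the nested list can be passed directly in Python. An illustrative sketch, assuming PaddleNLP is installed and output_dir is the only other required argument:

# Illustrative only: constructing TrainingArguments with skip_data_intervals in Python,
# equivalent to "--skip_data_intervals 100 200 --skip_data_intervals 500 600" on the CLI.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    skip_data_intervals=[[100, 200], [500, 600]],  # skip global steps 100-200 and 500-600 (1-based, inclusive)
)
print(args.skip_data_intervals)  # [[100, 200], [500, 600]]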
