4 changes: 2 additions & 2 deletions docs/datasets.md
@@ -315,9 +315,9 @@ Here is a multi-image example of SFT VL dataset:
}
```

## chatml Format
## messages Format

The chatml Format is used for training thinking models and function call training:
The messages format is used for training thinking models and for function call training:

Demo data for thinking models:
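(The demo record itself is collapsed in this view.) A minimal sketch of one messages-format record, assuming an OpenAI-style `messages` list with the assistant's reasoning wrapped in `<think>` tags; the field names, the tag convention, and the file name are illustrative assumptions, not the repository's authoritative schema:

```python
import json

# Hypothetical single training record in the messages format; the <think>
# wrapper and the field names are assumptions for illustration only.
record = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {
            "role": "assistant",
            "content": "<think>Add the operands: 2 + 2 = 4.</think>The answer is 4.",
        },
    ]
}

# One record per line of a .jsonl file, matching dataset paths such as
# ./examples/data/sft_think-train.jsonl used in the configs below.
with open("demo.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```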

3 changes: 2 additions & 1 deletion ernie/dataset/text_sft_reader/finetuning.py
@@ -33,7 +33,8 @@
sampling_pseudo_examples,
sampling_pseudo_examples_fc,
)
from paddleformers.datasets.data_utils import pad_batch_data, round_up_to_multiple_of_8
from paddleformers.datasets.collate import pad_batch_data
from paddleformers.datasets.data_utils import round_up_to_multiple_of_8

logger = logging.getLogger(__name__)
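For context on the retained import: batch lengths are typically rounded up to a multiple of 8 so padded shapes stay aligned for tensor cores. A conventional re-implementation sketch of such a helper; the actual round_up_to_multiple_of_8 in paddleformers.datasets.data_utils may differ in details:

```python
def round_up_to_multiple_of_8(n: int) -> int:
    # Round n up to the nearest multiple of 8 (sketch of the conventional
    # behavior; not the paddleformers implementation).
    return (n + 7) // 8 * 8


assert round_up_to_multiple_of_8(13) == 16
assert round_up_to_multiple_of_8(16) == 16
```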

14 changes: 10 additions & 4 deletions ernie/fusion_ops/common_fusion_ops.py
@@ -71,11 +71,12 @@ def _fusion_flash_attention(
"""

if attn_mask_startend_row_indices is not None:
if attn_mask_startend_row_indices.ndim == 3:
attn_mask_startend_row_indices = attn_mask_startend_row_indices.unsqueeze(
-1
)
if use_sparse_flash_attn:
# attn_mask_startend_row_indices.ndim must be 4
if attn_mask_startend_row_indices.ndim == 3:
attn_mask_startend_row_indices = (
attn_mask_startend_row_indices.unsqueeze(-1)
)
if rr_flash_attn is None:
out = flashmask_attention(
q,
@@ -94,6 +95,11 @@ def _fusion_flash_attention(
causal=True,
)
else:
# attn_mask_startend_row_indices.ndim must be 3
if attn_mask_startend_row_indices.ndim == 4:
attn_mask_startend_row_indices = attn_mask_startend_row_indices.squeeze(
-1
)
attention_mask = _gen_from_sparse_attn_mask_indices(
attn_mask_startend_row_indices, q.dtype
)
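Taken together, the two hunks above enforce complementary shape contracts: the sparse flashmask path expects 4-D row indices, while the dense fallback that materializes an attention mask expects 3-D. A minimal standalone sketch of that normalization, assuming Paddle tensors (the helper name is hypothetical, not part of the repository):

```python
import paddle


def normalize_row_indices(indices, use_sparse_flash_attn: bool):
    # Hypothetical helper mirroring the diff: add or drop the trailing axis
    # so each attention path receives the tensor rank it expects.
    if indices is None:
        return None
    if use_sparse_flash_attn and indices.ndim == 3:
        indices = indices.unsqueeze(-1)  # 3-D -> 4-D for the sparse path
    elif not use_sparse_flash_attn and indices.ndim == 4:
        indices = indices.squeeze(-1)  # 4-D -> 3-D for the dense fallback
    return indices


# Example: a 3-D indices tensor gains a trailing axis for the sparse path only.
demo = paddle.zeros([2, 4, 8], dtype="int32")
assert normalize_row_indices(demo, True).ndim == 4
assert normalize_row_indices(demo, False).ndim == 3
```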
33 changes: 29 additions & 4 deletions erniekit/eval/eval.py
@@ -36,6 +36,7 @@
from paddleformers.trainer.trainer_utils import ShardingOption
from paddleformers.utils.log import logger
from paddleformers import __version__ as paddleformers_version
from paddleformers.datasets.template.template import get_template_and_fix_tokenizer

from ernie.configuration import Ernie4_5_MoeConfig
from ernie.modeling_moe import Ernie4_5_MoeForCausalLM
@@ -418,15 +419,39 @@ def run_eval(args: Optional[dict[str, Any]] = None) -> None:
"encode_one_turn": data_args.encode_one_turn,
"use_template": data_args.use_template,
"is_pretraining": True if model_args.stage.lower() == "pt" else False,
"truncate_packing": data_args.truncate_packing,
"stage": model_args.stage,
"is_valid": False,
"template_backend": data_args.template_backend,
"split_multi_turn": data_args.split_multi_turn,
}
from paddleformers.datasets.finetuning import collate_fn
dataset_config.update(
{
"template": data_args.template,
"train_on_prompt": False,
"tool_format": None,
"default_system": None,
"enable_thinking": True,
}
)

if dataset_config["template_backend"] == "custom":
template_instance = get_template_and_fix_tokenizer(dataset_config)
else:
template_instance = None
dataset_config.update(
{
"template_instance": template_instance,
}
)
from paddleformers.datasets.collate import collate_fn

if data_args.dataset_type == "map":
from paddleformers.datasets.finetuning import (
from paddleformers.datasets.loader import (
create_indexed_dataset as create_dataset,
)
else:
from paddleformers.datasets.finetuning import create_dataset
from paddleformers.datasets.loader import create_dataset
dataset_config.update(
{
"num_samples_each_epoch": data_args.num_samples_each_epoch,
@@ -440,11 +465,11 @@ def run_eval(args: Optional[dict[str, Any]] = None) -> None:
eval_file_path = os.path.join(data_args.offline_dataset_path, "eval")
eval_dataset = create_dataset(data_file_prefix=eval_file_path)
else:
dataset_config["is_valid"] = True
eval_dataset = create_dataset(
task_group=data_args.eval_dataset_path,
task_group_prob=data_args.eval_dataset_prob,
sub_dataset_type=data_args.eval_dataset_type,
is_valid=True,
**dataset_config,
)

22 changes: 22 additions & 0 deletions erniekit/hparams/data_args.py
@@ -146,3 +146,25 @@ class DataArguments:
default=True,
metadata={"help": "Whether to use cls to predict RM score."},
)
truncate_packing: bool = field(
default=True,
metadata={
"help": "Whether to truncate data in packing (only valid in pretrain online dataflow)."
},
)
template: str = field(
default=None,
metadata={"help": "The chat template used in training."},
)
split_multi_turn: bool = field(
default=False,
metadata={
"help": "Whether to split multi-round dialogues into multiple pieces of data for training"
},
)
template_backend: str = field(
default="jinja",
metadata={
"help": "jinja means using apply_chat_template, custom means using a custom template"
},
)
4 changes: 4 additions & 0 deletions erniekit/hparams/model_args.py
@@ -95,6 +95,10 @@ class ModelArguments:
"help": "Under use attn_mask_startend_row_indices=True, whether use sparse flash attention or not."
},
)
use_global_causal_attn: bool = field(
default=False,
metadata={"help": "Whether to use global causal attention in packing data"},
)
use_sparse_head_and_loss_fn: bool = field(
default=False,
metadata={"help": "Whether to use sparse LM Head and loss function."},
2 changes: 1 addition & 1 deletion erniekit/train/dpo/dpo_estimate_training.py
@@ -24,7 +24,7 @@
# isort: off
# fmt: off
# isort: on
from paddleformers.datasets.dpo import create_dataset
from paddleformers.datasets.loader import create_dataset


def calculate_acc_steps(num_samples, train_batch, dataset_world_size, per_device_train_batch_size):
28 changes: 26 additions & 2 deletions erniekit/train/dpo/workflow.py
@@ -46,10 +46,12 @@
from paddleformers.trainer.trainer_utils import ShardingOption
from paddleformers.utils.log import logger
from paddleformers import __version__ as paddleformers_version
from paddleformers.datasets.template.template import get_template_and_fix_tokenizer

from ernie.callbacks import LayerwiseDropoutCallback
from ernie.configuration import Ernie4_5_MoeConfig
from paddleformers.datasets.dpo import collate_fn, create_dataset
from paddleformers.datasets.collate import dpo_collate_fn as collate_fn
from paddleformers.datasets.loader import create_dataset
from ernie.modeling_moe import Ernie4_5_MoeForCausalLM
from ernie.modeling_moe_pp import Ernie4_5_MoeForCausalLMPipe
from ernie.tokenizer import Ernie4_5_Tokenizer
@@ -498,7 +500,29 @@ def run_dpo(
"packing": data_args.packing,
"mix_strategy": data_args.mix_strategy,
"encode_one_turn": data_args.encode_one_turn,
"stage": model_args.stage,
"is_valid": False,
"template_backend": data_args.template_backend,
}
dataset_config.update(
{
"template": data_args.template,
"train_on_prompt": False,
"tool_format": None,
"default_system": None,
"enable_thinking": True,
}
)

if dataset_config["template_backend"] == "custom":
template_instance = get_template_and_fix_tokenizer(dataset_config)
else:
template_instance = None
dataset_config.update(
{
"template_instance": template_instance,
}
)

if finetuning_args.max_steps == -1:
if data_args.mix_strategy == "random":
@@ -549,11 +573,11 @@ def run_dpo(
)

if finetuning_args.do_eval and finetuning_args.should_load_dataset:
dataset_config["is_valid"] = True
eval_dataset = create_dataset(
task_group=data_args.eval_dataset_path,
task_group_prob=data_args.eval_dataset_prob,
sub_dataset_type=data_args.eval_dataset_type,
is_valid=True,
**dataset_config,
)
logger.info("Creating dataset successfully ...")
35 changes: 31 additions & 4 deletions erniekit/train/sft/workflow.py
@@ -47,6 +47,7 @@
)
from paddleformers.trainer.trainer_utils import ShardingOption
from paddleformers.transformers.model_utils import unwrap_model
from paddleformers.datasets.template.template import get_template_and_fix_tokenizer
from paddleformers.data.causal_dataset import (
build_train_valid_test_datasets,
check_data_split,
@@ -529,15 +530,41 @@ def run_sft(
"encode_one_turn": data_args.encode_one_turn,
"use_template": data_args.use_template,
"is_pretraining": True if model_args.stage.lower() == "pt" else False,
"truncate_packing": data_args.truncate_packing,
"stage": model_args.stage,
"is_valid": False,
"template_backend": data_args.template_backend,
"split_multi_turn": data_args.split_multi_turn,
}
from paddleformers.datasets.finetuning import collate_fn

dataset_config.update(
{
"template": data_args.template,
"train_on_prompt": False,
"tool_format": None,
"default_system": None,
"enable_thinking": True,
}
)

if dataset_config["template_backend"] == "custom":
template_instance = get_template_and_fix_tokenizer(dataset_config)
else:
template_instance = None
dataset_config.update(
{
"template_instance": template_instance,
}
)

from paddleformers.datasets.collate import collate_fn

if data_args.dataset_type == "map":
from paddleformers.datasets.finetuning import (
from paddleformers.datasets.loader import (
create_indexed_dataset as create_dataset,
)
else:
from paddleformers.datasets.finetuning import create_dataset
from paddleformers.datasets.loader import create_dataset
dataset_config.update(
{
"num_samples_each_epoch": data_args.num_samples_each_epoch,
@@ -570,11 +597,11 @@ def run_sft(
eval_file_path = os.path.join(data_args.offline_dataset_path, "eval")
eval_dataset = create_dataset(data_file_prefix=eval_file_path)
else:
dataset_config["is_valid"] = True
eval_dataset = create_dataset(
task_group=data_args.eval_dataset_path,
task_group_prob=data_args.eval_dataset_prob,
sub_dataset_type=data_args.eval_dataset_type,
is_valid=True,
**dataset_config,
)

5 changes: 3 additions & 2 deletions examples/configs/ERNIE-4.5-21B-A3B-Thinking/fc/run_fc_8k.yaml
@@ -1,12 +1,13 @@
### data
train_dataset_type: "chatml"
eval_dataset_type: "chatml"
train_dataset_type: "messages"
eval_dataset_type: "messages"
train_dataset_path: "./examples/data/function-call-train.jsonl"
train_dataset_prob: "1.0"
eval_dataset_path: "./examples/data/function-call-eval.jsonl"
eval_dataset_prob: "1.0"
max_seq_len: 8192
num_samples_each_epoch: 6000000
split_multi_turn: True

### model
model_name_or_path: baidu/ERNIE-4.5-21B-A3B-Thinking
@@ -1,12 +1,13 @@
### data
train_dataset_type: "chatml"
eval_dataset_type: "chatml"
train_dataset_type: "messages"
eval_dataset_type: "messages"
train_dataset_path: "./examples/data/function-call-train.jsonl"
train_dataset_prob: "1.0"
eval_dataset_path: "./examples/data/function-call-eval.jsonl"
eval_dataset_prob: "1.0"
max_seq_len: 8192
num_samples_each_epoch: 6000000
split_multi_turn: True

### model
model_name_or_path: baidu/ERNIE-4.5-21B-A3B-Thinking
3 changes: 2 additions & 1 deletion examples/configs/ERNIE-4.5-21B-A3B-Thinking/run_eval.yaml
@@ -1,9 +1,10 @@
### data
eval_dataset_type: "chatml"
eval_dataset_type: "messages"
eval_dataset_path: "./examples/data/sft_think-eval.jsonl"
eval_dataset_prob: "1.0"
max_seq_len: 8192
num_samples_each_epoch: 6000000
split_multi_turn: True

### model
model_name_or_path: baidu/ERNIE-4.5-21B-A3B-Thinking
@@ -1,12 +1,13 @@
### data
train_dataset_type: "chatml"
eval_dataset_type: "chatml"
train_dataset_type: "messages"
eval_dataset_type: "messages"
train_dataset_path: "./examples/data/sft_think-train.jsonl"
train_dataset_prob: "1.0"
eval_dataset_path: "./examples/data/sft_think-eval.jsonl"
eval_dataset_prob: "1.0"
max_seq_len: 131072
num_samples_each_epoch: 6000000
split_multi_turn: True

### model
model_name_or_path: baidu/ERNIE-4.5-21B-A3B-Thinking
@@ -1,12 +1,13 @@
### data
train_dataset_type: "chatml"
eval_dataset_type: "chatml"
train_dataset_type: "messages"
eval_dataset_type: "messages"
train_dataset_path: "./examples/data/sft_think-train.jsonl"
train_dataset_prob: "1.0"
eval_dataset_path: "./examples/data/sft_think-eval.jsonl"
eval_dataset_prob: "1.0"
max_seq_len: 32768
num_samples_each_epoch: 6000000
split_multi_turn: True

### model
model_name_or_path: baidu/ERNIE-4.5-21B-A3B-Thinking
@@ -1,12 +1,13 @@
### data
train_dataset_type: "chatml"
eval_dataset_type: "chatml"
train_dataset_type: "messages"
eval_dataset_type: "messages"
train_dataset_path: "./examples/data/sft_think-train.jsonl"
train_dataset_prob: "1.0"
eval_dataset_path: "./examples/data/sft_think-eval.jsonl"
eval_dataset_prob: "1.0"
max_seq_len: 8192
num_samples_each_epoch: 6000000
split_multi_turn: True

### model
model_name_or_path: baidu/ERNIE-4.5-21B-A3B-Thinking
@@ -1,12 +1,13 @@
### data
train_dataset_type: "chatml"
eval_dataset_type: "chatml"
train_dataset_type: "messages"
eval_dataset_type: "messages"
train_dataset_path: "./examples/data/sft_think-train.jsonl"
train_dataset_prob: "1.0"
eval_dataset_path: "./examples/data/sft_think-eval.jsonl"
eval_dataset_prob: "1.0"
max_seq_len: 131072
num_samples_each_epoch: 6000000
split_multi_turn: True

### model
model_name_or_path: baidu/ERNIE-4.5-21B-A3B-Thinking