Skip to content

Commit a2110e1

Browse files
jenchen13ko3n1g
andauthored
SFT Script with HuggingFace chat template support (NVIDIA-NeMo#13273)
* chat template fix Signed-off-by: jenchen13 <jennifchen@nvidia.com> Co-authored-by: oliver könig <okoenig@nvidia.com>
1 parent a528802 commit a2110e1

File tree

9 files changed

+537
-207
lines changed

9 files changed

+537
-207
lines changed

nemo/collections/llm/api.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565

6666
@run.cli.entrypoint(namespace="llm")
6767
def train(
68-
model: pl.LightningModule,
68+
model: Union[pl.LightningModule, AnyPath],
6969
data: pl.LightningDataModule,
7070
trainer: Trainer,
7171
log: Annotated[Optional[NeMoLogger], run.Config[NeMoLogger]] = None,
@@ -79,7 +79,7 @@ def train(
7979
Trains a model using the specified data and trainer, with optional tokenizer, source, and export.
8080
8181
Args:
82-
model (pl.LightningModule): The model to be trained.
82+
model (Union[pl.LightningModule, AnyPath]): The model to be trained or a path to the NeMo 2 checkpoint.
8383
data (pl.LightningDataModule): The data module containing training data.
8484
trainer (Trainer): The trainer instance configured with a MegatronStrategy.
8585
log (NeMoLogger): A nemologger instance.
@@ -106,6 +106,8 @@ def train(
106106
>>> llm.train(model, data, trainer, tokenizer="data")
107107
PosixPath('/path/to/log_dir')
108108
"""
109+
model = _load_model_from_path(model)
110+
109111
# [ModelOpt]: If modelopt_state exists, overwrite transformer_layer_spec to modelopt spec
110112
if resume:
111113
if resume.restore_config and resume.restore_config.path:
@@ -131,7 +133,7 @@ def train(
131133

132134
@run.cli.entrypoint(namespace="llm")
133135
def pretrain(
134-
model: pl.LightningModule,
136+
model: Union[pl.LightningModule, AnyPath],
135137
data: pl.LightningDataModule,
136138
trainer: Trainer,
137139
log: Annotated[Optional[NeMoLogger], run.Config[NeMoLogger]] = None,
@@ -145,7 +147,7 @@ def pretrain(
145147
Note, by default it will use the tokenizer from the model.
146148
147149
Args:
148-
model (pl.LightningModule): The model to be pretrained.
150+
model (Union[pl.LightningModule, AnyPath]): The model to be pretrained or a path to the NeMo 2 checkpoint.
149151
data (pl.LightningDataModule): The data module containing pretraining data.
150152
trainer (Trainer): The trainer instance configured with a MegatronStrategy.
151153
log (NeMoLogger): A nemologger instance.
@@ -166,6 +168,7 @@ def pretrain(
166168
>>> llm.pretrain(model, data, trainer)
167169
PosixPath('/path/to/log_dir')
168170
"""
171+
model = _load_model_from_path(model)
169172
_validate_config(model, data, trainer, log=log, resume=resume, optim=optim)
170173

171174
return train(
@@ -181,28 +184,33 @@ def pretrain(
181184

182185
@run.cli.entrypoint(namespace="llm")
183186
def finetune(
184-
model: pl.LightningModule,
187+
model: Union[pl.LightningModule, AnyPath],
185188
data: pl.LightningDataModule,
186189
trainer: Trainer,
187190
log: Annotated[Optional[NeMoLogger], run.Config[NeMoLogger]] = None,
188191
resume: Annotated[Optional[AutoResume], run.Config[AutoResume]] = None,
189192
optim: Optional[OptimizerModule] = None,
190193
peft: Optional[Union[PEFT, ModelTransform, Callable]] = None,
194+
tokenizer: Optional[TokenizerType] = "model",
191195
) -> Path:
192196
"""
193197
Finetunes a model using the specified data and trainer, with optional logging, resuming, and PEFT.
194198
195199
Note, by default it will use the tokenizer from the model.
196200
197201
Args:
198-
model (pl.LightningModule): The model to be finetuned.
202+
model (Union[pl.LightningModule, AnyPath]): The model to be finetuned.
199203
data (pl.LightningDataModule): The data module containing finetuning data.
200204
trainer (Trainer): The trainer instance configured with a MegatronStrategy.
201205
log (NeMoLogger): A nemologger instance.
202206
resume (Optional[AutoResume]): Resume training from a checkpoint.
203207
optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default
204208
optimizer from the model will be used.
205209
peft (Optional[PEFT]): A PEFT (Parameter-Efficient Fine-Tuning) configuration to be applied.
210+
tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model'
211+
or an instance of TokenizerSpec. If 'data' uses the data loader's tokenizer instead of the tokenizer
212+
from the model checkpoint, which is useful for expanding vocabulary or adding special tokens
213+
(such as chat template tokens).
206214
207215
Returns:
208216
Path: The directory path where finetuning artifacts are saved.
@@ -217,7 +225,7 @@ def finetune(
217225
>>> llm.finetune(model, data, trainer, peft=llm.peft.LoRA()])
218226
PosixPath('/path/to/log_dir')
219227
"""
220-
228+
model = _load_model_from_path(model)
221229
_validate_config(model, data, trainer, log=log, resume=resume, optim=optim, model_transform=peft)
222230
return train(
223231
model=model,
@@ -226,7 +234,7 @@ def finetune(
226234
log=log,
227235
resume=resume,
228236
optim=optim,
229-
tokenizer="model",
237+
tokenizer=tokenizer,
230238
model_transform=peft,
231239
)
232240

@@ -630,6 +638,7 @@ def deploy(
630638
the trtllm backend).
631639
"""
632640
import os
641+
633642
import uvicorn
634643

635644
from nemo.deploy import DeployPyTriton
@@ -1345,3 +1354,9 @@ def _build_directory_tree(path, tree=None, root_name=None):
13451354
tree.add(f"[white]{item.name}[/white]")
13461355

13471356
return tree
1357+
1358+
1359+
def _load_model_from_path(model: Union[pl.LightningModule, AnyPath]):
1360+
if isinstance(model, AnyPath):
1361+
model = io.load_context(ckpt_to_context_subdir(model), subpath="model")
1362+
return model

nemo/collections/llm/gpt/data/chat.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,16 @@
1313
# limitations under the License.
1414

1515
from functools import lru_cache
16+
from pathlib import Path
17+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
1618

1719
from nemo.collections.llm.gpt.data.core import create_sft_dataset
1820
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
1921

22+
if TYPE_CHECKING:
23+
from nemo.collections.common.tokenizers import TokenizerSpec
24+
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs
25+
2026

2127
class ChatDataModule(FineTuningDataModule):
2228
"""
@@ -26,6 +32,48 @@ class ChatDataModule(FineTuningDataModule):
2632
See base class `FineTuningDataModule` for more details.
2733
"""
2834

35+
def __init__(
36+
self,
37+
dataset_root: Union[str, Path],
38+
seq_length: int = 2048,
39+
tokenizer: Optional["TokenizerSpec"] = None,
40+
micro_batch_size: int = 4,
41+
global_batch_size: int = 8,
42+
rampup_batch_size: Optional[List[int]] = None,
43+
seed: int = 1234,
44+
memmap_workers: int = 1,
45+
num_workers: int = 8,
46+
pin_memory: bool = True,
47+
persistent_workers: bool = False,
48+
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
49+
dataset_kwargs: Optional[Dict[str, Any]] = None,
50+
use_hf_tokenizer_chat_template: bool = False,
51+
):
52+
"""Data module for finetuning on chat datasets.
53+
See base class `FineTuningDataModule` for more details of the arguments.
54+
55+
Args:
56+
use_hf_tokenizer_chat_template: Whether to use the chat template from the HuggingFace tokenizer. If True,
57+
uses the tokenizer's built-in chat template. If False, uses default chat template from
58+
GPTSFTChatDataset. Defaults to False.
59+
"""
60+
super().__init__(
61+
dataset_root,
62+
seq_length,
63+
tokenizer,
64+
micro_batch_size,
65+
global_batch_size,
66+
rampup_batch_size,
67+
seed,
68+
memmap_workers,
69+
num_workers,
70+
pin_memory,
71+
persistent_workers,
72+
packed_sequence_specs,
73+
dataset_kwargs,
74+
)
75+
self.use_hf_tokenizer_chat_template = use_hf_tokenizer_chat_template
76+
2977
@lru_cache
3078
def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs):
3179
# pylint: disable=C0115,C0116
@@ -39,5 +87,6 @@ def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs
3987
is_test=is_test,
4088
pack_metadata_file_path=None, # packing is not supported
4189
pad_cu_seqlens=False,
90+
use_hf_tokenizer_chat_template=self.use_hf_tokenizer_chat_template,
4291
**kwargs,
4392
)

0 commit comments

Comments
 (0)