
Commit 8b96660

Add support for finetuning and huggingface datasets
Signed-off-by: Hemil Desai <[email protected]>
1 parent 752edac

10 files changed: +710 −59 lines changed

nemo/tron/api.py

Lines changed: 6 additions & 3 deletions
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable
+from typing import Callable, Optional
 
 from nemo.tron.checkpointing import save_checkpoint
 from nemo.tron.config import ConfigContainer
-from nemo.tron.data.dataset import train_valid_test_datasets_provider
+from nemo.tron.data.utils import get_dataset_provider
 from nemo.tron.eval import evaluate_and_print_results
 from nemo.tron.setup import setup
 from nemo.tron.train import _finish_train, train
@@ -26,9 +26,12 @@
 def megatron_pretrain(
     config: ConfigContainer,
     forward_step_func: Callable,
-    dataset_provider: Callable = train_valid_test_datasets_provider,
+    dataset_provider: Optional[Callable] = None,
 ):
     ## SETUP ##
+    if dataset_provider is None:
+        dataset_provider = get_dataset_provider(config.dataset_config)
+
     setup_output = setup(config, dataset_provider)
     state = setup_output.state
     model = setup_output.model
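
With this change, callers of megatron_pretrain no longer have to pass a dataset provider: when dataset_provider is None, one is resolved from config.dataset_config via get_dataset_provider. The body of nemo/tron/data/utils.py is not part of this excerpt, so the following is only a sketch of how such a dispatch could look, assuming it branches on the config type; the fine-tuning module path and provider name below are illustrative, while train_valid_test_datasets_provider is the pretraining default that api.py previously imported directly.

from typing import Callable

from nemo.tron.config import FinetuningDatasetConfig, GPTDatasetConfig


def get_dataset_provider(dataset_config) -> Callable:
    # Sketch only: dispatch on the dataset config's type.
    if isinstance(dataset_config, FinetuningDatasetConfig):
        # Illustrative module/function names; the actual provider wrapping the
        # new FinetuningDatasetBuilder is not shown in this excerpt.
        from nemo.tron.data.finetuning import finetuning_datasets_provider

        return finetuning_datasets_provider
    if isinstance(dataset_config, GPTDatasetConfig):
        # Pretraining default, previously hard-coded in api.py.
        from nemo.tron.data.dataset import train_valid_test_datasets_provider

        return train_valid_test_datasets_provider
    raise ValueError(f"Unsupported dataset config type: {type(dataset_config)}")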

nemo/tron/config.py

Lines changed: 19 additions & 4 deletions
@@ -14,7 +14,8 @@
 
 import os
 from dataclasses import dataclass, field
-from typing import List, Literal, Optional
+from pathlib import Path
+from typing import Any, List, Literal, Optional, Union
 
 from megatron.core.datasets.gpt_dataset import GPTDatasetConfig as MCoreGPTDatasetConfig
 from megatron.core.distributed import DistributedDataParallelConfig
@@ -104,8 +105,8 @@ class TokenizerConfig:
     padded_vocab_size: Optional[int] = None
 
 
-@dataclass
-class GPTDatasetConfig(MCoreGPTDatasetConfig):
+@dataclass(kw_only=True)
+class DataloaderConfig:
     dataloader_type: Optional[Literal["single", "cyclic", "external"]] = None
     """Single pass vs multiple pass data loader"""
 
@@ -115,6 +116,9 @@ class GPTDatasetConfig(MCoreGPTDatasetConfig):
     data_sharding: bool = True
     """Disable data sharding."""
 
+
+@dataclass
+class GPTDatasetConfig(MCoreGPTDatasetConfig, DataloaderConfig):
     def __post_init__(self) -> None:
         super(MCoreGPTDatasetConfig, self).__post_init__()
 
@@ -123,6 +127,17 @@ def __post_init__(self) -> None:
         assert self.eod_mask_loss is not None
 
 
+@dataclass(kw_only=True)
+class FinetuningDatasetConfig(DataloaderConfig):
+    dataset_root: Optional[Union[str, Path]] = None
+    seq_length: int = 1024
+    seed: int = 1234
+    memmap_workers: int = 1
+    max_train_samples: Optional[int] = None
+    packed_sequence_specs: Optional[dict] = None
+    dataset_kwargs: Optional[dict[str, Any]] = None
+
+
 @dataclass
 class TrainingConfig:
     # ---------------- Training config. ----------------
@@ -512,7 +527,7 @@ class ConfigContainer:
     optimizer_config: OptimizerConfig
     ddp_config: DistributedDataParallelConfig = field(default_factory=DistributedDataParallelConfig)
     scheduler_config: SchedulerConfig
-    dataset_config: GPTDatasetConfig
+    dataset_config: GPTDatasetConfig | FinetuningDatasetConfig
     logger_config: LoggerConfig
     tokenizer_config: TokenizerConfig
     checkpoint_config: CheckpointConfig
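
In config.py, the dataloader-level fields (dataloader_type, data_sharding, and the fields the hunks skip between them) move into a shared DataloaderConfig base class, GPTDatasetConfig now mixes that base into the Megatron-Core config, and ConfigContainer.dataset_config accepts either a GPTDatasetConfig or the new FinetuningDatasetConfig. A hedged example of filling in the fine-tuning variant follows; the dataset root and values are placeholders, and the expected training.jsonl / validation.jsonl / test.jsonl layout comes from the builder added later in this commit.

from nemo.tron.config import FinetuningDatasetConfig

dataset_config = FinetuningDatasetConfig(
    dataset_root="/data/my_sft_dataset",  # placeholder; should contain training.jsonl, validation.jsonl, test.jsonl
    seq_length=2048,
    seed=1234,
    memmap_workers=2,
    dataloader_type="single",             # inherited from DataloaderConfig
    packed_sequence_specs={"packed_sequence_size": 2048},  # optional; omit to disable packing
)

# The container now accepts either dataset config flavor:
#   config = ConfigContainer(..., dataset_config=dataset_config, ...)
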
Lines changed: 289 additions & 0 deletions
@@ -0,0 +1,289 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+import torch
+
+from nemo.collections.llm.gpt.data.core import create_sft_dataset
+from nemo.tron.tokenizers.tokenizer import _HuggingFaceTokenizer
+from nemo.tron.utils.common_utils import get_rank_safe, print_rank_0
+from nemo.utils import logging
+
+
+class FinetuningDatasetBuilder:
+    """Builder class for fine-tuning datasets.
+
+    This class provides methods to build datasets for fine-tuning large language models.
+    It follows a builder pattern similar to BlendedMegatronDatasetBuilder but adapted for
+    fine-tuning scenarios.
+
+    Args:
+        dataset_root (Union[str, Path]): The root directory containing training, validation, and test data.
+        tokenizer: The tokenizer to use for preprocessing text.
+        seq_length (int, optional): The maximum sequence length. Defaults to 2048.
+        seed (int, optional): Random seed for data shuffling. Defaults to 1234.
+        memmap_workers (int, optional): Number of worker processes for memmap datasets. Defaults to 1.
+        is_built_on_rank (Callable): Function that returns True if the dataset should be built on current rank.
+        max_train_samples (int, optional): Maximum number of training samples. Defaults to None.
+        packed_sequence_specs (Optional[dict], optional): Specifications for packed sequences. Defaults to None.
+        dataset_kwargs (Optional[dict[str, Any]], optional): Additional dataset creation arguments. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        dataset_root: Union[str, Path],
+        tokenizer,
+        is_built_on_rank: Callable,
+        seq_length: int = 2048,
+        seed: int = 1234,
+        memmap_workers: int = 1,
+        max_train_samples: Optional[int] = None,
+        packed_sequence_specs: Optional[dict] = None,
+        dataset_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        self.dataset_root = Path(dataset_root)
+        self.tokenizer = tokenizer
+        self.seq_length = seq_length
+        self.seed = seed
+        self.memmap_workers = memmap_workers
+        self.is_built_on_rank = is_built_on_rank
+        self.max_train_samples = max_train_samples
+        self.packed_sequence_specs = packed_sequence_specs
+        self.packed_sequence_size = (
+            -1 if not packed_sequence_specs else packed_sequence_specs.get("packed_sequence_size", -1)
+        )
+        self.dataset_kwargs = dataset_kwargs or {}
+        self._pad_cu_seqlens = (
+            False if not packed_sequence_specs else packed_sequence_specs.get("pad_cu_seqlens", False)
+        )
+
+        print_rank_0(f"Building FinetuningDatasetBuilder with root={self.dataset_root}")
+
+        if self.packed_sequence_size > 0:
+            print_rank_0(f"Using packed sequences with size {self.packed_sequence_size}")
+
+    def prepare_data(self) -> None:
+        """Prepare data if needed."""
+        self.prepare_packed_data()
+
+    def prepare_packed_data(self) -> None:
+        """Prepare packed sequence data if needed."""
+        if self.packed_sequence_size > 0:
+            from nemo.collections.llm.gpt.data.packed_sequence import prepare_packed_sequence_data
+
+            if not self.train_path_packed.is_file():
+                print_rank_0(f"Preparing packed training data at {self.train_path_packed}")
+                prepare_packed_sequence_data(
+                    input_path=self.train_path,
+                    output_path=self.train_path_packed,
+                    packed_sequence_size=self.packed_sequence_size,
+                    tokenizer=self.tokenizer,
+                    max_seq_length=self.seq_length,
+                    seed=self.seed,
+                    output_metadata_path=self.pack_metadata,
+                )
+
+            if not self.validation_path_packed.is_file():
+                print_rank_0(f"Preparing packed validation data at {self.validation_path_packed}")
+                prepare_packed_sequence_data(
+                    input_path=self.validation_path,
+                    output_path=self.validation_path_packed,
+                    packed_sequence_size=self.packed_sequence_size,
+                    tokenizer=self.tokenizer,
+                    max_seq_length=self.seq_length,
+                    seed=self.seed,
+                    output_metadata_path=self.pack_metadata,
+                )
+
+    def build(self) -> list[Optional[Any]]:
+        """Build train, validation, and test datasets.
+
+        This method creates the necessary datasets based on the configuration.
+        It first prepares packed data if needed, then builds the datasets in parallel
+        on the appropriate ranks.
+
+        Returns:
+            list[Optional[Any]]: A list containing the train, validation, and test datasets.
+                Any of these may be None if not available or not built on current rank.
+        """
+        # Prepare packed data if needed
+        if get_rank_safe() == 0:
+            self.prepare_data()
+
+        # Use a similar parallel building approach as BlendedMegatronDatasetBuilder
+        if torch.distributed.is_initialized():
+            rank = torch.distributed.get_rank()
+
+            datasets = [None, None, None]  # train, valid, test
+
+            # First, build on rank 0
+            if rank == 0 and self.is_built_on_rank():
+                try:
+                    datasets = self._build_datasets()
+                except Exception as err:
+                    logging.error(f"Failed to build datasets on rank 0: {err}")
+                    raise
+
+            # Synchronize all ranks
+            torch.distributed.barrier()
+
+            # Then build on other ranks
+            if rank != 0 and self.is_built_on_rank():
+                datasets = self._build_datasets()
+
+            return datasets
+        else:
+            # Not distributed
+            return self._build_datasets()
+
+    def _build_datasets(self) -> list[Optional[Any]]:
+        """Internal method to build all datasets.
+
+        Returns:
+            list[Optional[Any]]: The train, validation, and test datasets.
+        """
+        train_ds = self._create_dataset(
+            self.train_path if self.packed_sequence_size <= 0 else self.train_path_packed,
+            pack_metadata_path=None if self.packed_sequence_size <= 0 else self.pack_metadata,
+            max_num_samples=self.max_train_samples,
+            **self.dataset_kwargs,
+        )
+
+        valid_ds = self._create_dataset(
+            self.validation_path if self.packed_sequence_size <= 0 else self.validation_path_packed,
+            pack_metadata_path=None if self.packed_sequence_size <= 0 else self.pack_metadata,
+            is_test=True,
+            **self.dataset_kwargs,
+        )
+
+        test_ds = (
+            self._create_dataset(
+                self.test_path,
+                tokens_to_generate=32,
+                is_test=True,
+                **self.dataset_kwargs,
+            )
+            if self.test_path.exists()
+            else None
+        )
+
+        return [train_ds, valid_ds, test_ds]
+
+    @lru_cache
+    def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs):
+        """Create a dataset from the given path and parameters.
+
+        Args:
+            path: Path to the dataset file
+            pack_metadata_path: Path to the packed sequence metadata
+            is_test: Whether this is a test dataset
+            **kwargs: Additional arguments to pass to the dataset constructor
+
+        Returns:
+            The created dataset
+        """
+        if not Path(path).exists():
+            print_rank_0(f"Warning: Dataset path {path} does not exist")
+            return None
+
+        is_not_packing = self.packed_sequence_size <= 0
+        return create_sft_dataset(
+            path,
+            tokenizer=self.tokenizer,
+            seq_length=(self.seq_length if is_not_packing else self.packed_sequence_size),
+            memmap_workers=self.memmap_workers,
+            seed=self.seed,
+            is_test=is_test,
+            pack_metadata_file_path=None if is_not_packing else pack_metadata_path,
+            pad_cu_seqlens=False if is_not_packing else self._pad_cu_seqlens,
+            **kwargs,
+        )
+
+    @property
+    def train_path(self) -> Path:
+        """Path to training dataset file"""
+        return self.dataset_root / "training.jsonl"
+
+    @property
+    def default_pack_path(self) -> Path:
+        """The default directory to write packing files."""
+        tokenizer_model_name = self._extract_tokenizer_model_name()
+        default_pack_path = self.dataset_root / "packed" / tokenizer_model_name
+        if not default_pack_path.exists():
+            default_pack_path.mkdir(parents=True, exist_ok=True)
+        logging.info(f"Using default path for packing files: {str(default_pack_path)}")
+
+        return default_pack_path
+
+    @property
+    def pack_metadata(self) -> Path:
+        """Path to metadata dataset file for packed sequence."""
+        if self.packed_sequence_size > 0:
+            if self.packed_sequence_specs.get("packed_metadata_path") is not None:
+                return self.packed_sequence_specs["packed_metadata_path"]
+            return self.default_pack_path / f"{self.packed_sequence_size}_metadata.jsonl"
+        else:
+            raise ValueError("pack_metadata invalid since packed sequence size is not specified.")
+
+    @property
+    def train_path_packed(self) -> Path:
+        """Path to training dataset file for packed sequence."""
+        if self.packed_sequence_size > 0:
+            if self.packed_sequence_specs.get("packed_train_data_path") is not None:
+                return self.packed_sequence_specs["packed_train_data_path"]
+            return self.default_pack_path / f"training_{self.packed_sequence_size}.npy"
+        else:
+            raise ValueError("`train_path_packed` invalid since packed sequence size is not specified.")
+
+    @property
+    def validation_path_packed(self) -> Path:
+        """Path to validation dataset file for packed sequence."""
+        if self.packed_sequence_size > 0:
+            if self.packed_sequence_specs.get("packed_val_data_path") is not None:
+                return self.packed_sequence_specs["packed_val_data_path"]
+            return self.default_pack_path / f"validation_{self.packed_sequence_size}.npy"
+        else:
+            raise ValueError("`validation_path_packed` invalid since packed sequence size is not specified.")
+
+    @property
+    def validation_path(self) -> Path:
+        """Path to validation dataset file"""
+        return self.dataset_root / "validation.jsonl"
+
+    @property
+    def test_path(self) -> Path:
+        """Path to test dataset file"""
+        return self.dataset_root / "test.jsonl"
+
+    def _extract_tokenizer_model_name(self) -> str:
+        """Automatically get the model name from model path."""
+        if self.packed_sequence_specs and self.packed_sequence_specs.get("tokenizer_model_name") is not None:
+            return self.packed_sequence_specs["tokenizer_model_name"]
+        elif isinstance(self.tokenizer, _HuggingFaceTokenizer):
+            name = self.tokenizer._tokenizer.name_or_path
+            if name.endswith("context/nemo_tokenizer"):
+                # NEMO_HOME/hf_org/hf_model/context/nemo_tokenizer => hf_org--hf_model
+                tokenizer_model_name = "--".join(name.split("/")[-4:-2])
+            elif name.endswith("nemo_tokenizer"):
+                # NEMO_HOME/hf_org/hf_model/nemo_tokenizer => hf_org--hf_model
+                tokenizer_model_name = "--".join(name.split("/")[-3:-1])
+            else:
+                # hf_org/hf_model => hf_org--hf_model
+                tokenizer_model_name = name.replace("/", "--")
+            return tokenizer_model_name
+        else:
+            return f"unknown_tokenizer_{hash(self.tokenizer)}"
