# SFT Nemo - Qwen #17
@@ -0,0 +1,13 @@

## Run instructions
### Update run.sh
- Any changes to hyperparameters used during training can be passed via flags to `model/train.py`. More details about the flags are in `train.py`; a hypothetical example follows this list.
- All data processing in this example lives in `model/data.py`. Replace this, or add data modules that reflect your dataset. Also remember to update `train.py` to use the correct data module.
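For example, a hyperparameter override in `run.sh` might look like this (the flag names below are illustrative assumptions, not the actual flags defined in `train.py`):

```
# hypothetical flags; check train.py for the real set
python model/train.py --max-steps 1000 --lr 2e-5 --global-batch-size 8
```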
### Launch run

```
truss train push config.py
```

Upon successful submission, the CLI will output helpful information about your job, including the job ID to track your run.
@@ -0,0 +1,39 @@

```python
from truss_train import definitions
from truss.base import truss_config

BASE_IMAGE = "nvcr.io/nvidia/nemo:25.07"

training_runtime = definitions.Runtime(
    start_commands=[
        "/bin/sh -c 'chmod +x ./run.sh && ./run.sh'"
    ],
    environment_variables={
        "HF_TOKEN": definitions.SecretReference(name="hf_access_token"),
        # "WANDB_API_KEY": definitions.SecretReference(name="wandb_api_key"),
    },
    cache_config=definitions.CacheConfig(
        enabled=True,
    ),
    checkpointing_config=definitions.CheckpointingConfig(
        enabled=True,
    ),
)

training_compute = definitions.Compute(
    accelerator=truss_config.AcceleratorSpec(
        accelerator=truss_config.Accelerator.H100,
        count=8,
    ),
    node_count=1,
)

training_job = definitions.TrainingJob(
    image=definitions.Image(base_image=BASE_IMAGE),
    compute=training_compute,
    runtime=training_runtime,
)

training_project = definitions.TrainingProject(
    name="Nemo-qwen2.5-nemo 1node",
    job=training_job,
)
```
@@ -0,0 +1,147 @@

```python
import json
import shutil
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from datasets import load_dataset

from nemo.collections.llm.gpt.data.core import create_sft_dataset, get_dataset_root
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.lightning.io.mixin import IOMixin
from nemo.utils import logging

if TYPE_CHECKING:
    from nemo.collections.common.tokenizers import TokenizerSpec
    from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs

class BespokeDataModule(FineTuningDataModule, IOMixin):
    """A data module for fine-tuning on the Bespoke dataset.

    This class inherits from `FineTuningDataModule` and is specifically designed for fine-tuning
    models on the "bespokelabs/Bespoke-Stratos-17k" dataset. It handles data download,
    preprocessing, splitting, and preparing the data in a format suitable for training,
    validation, and testing.

    Args:
        force_redownload (bool, optional): Whether to force re-download the dataset even if it
            exists locally. Defaults to False.
        delete_raw (bool, optional): Whether to delete the raw downloaded dataset after
            preprocessing. Defaults to True.
        See FineTuningDataModule for the other args.
    """

    def __init__(
        self,
        seq_length: int = 2048,
        tokenizer: Optional["TokenizerSpec"] = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        force_redownload: bool = False,
        delete_raw: bool = True,
        seed: int = 1234,
        memmap_workers: int = 1,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
        dataset_kwargs: Optional[Dict[str, Any]] = None,
        dataset_root: str = "./bespoke",
    ):
        self.force_redownload = force_redownload
        self.delete_raw = delete_raw

        super().__init__(
            dataset_root=dataset_root,
            seq_length=seq_length,
            tokenizer=tokenizer,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            rampup_batch_size=rampup_batch_size,
            seed=seed,
            memmap_workers=memmap_workers,
            # num_workers=num_workers,  # note: accepted above but not forwarded to the parent
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            packed_sequence_specs=packed_sequence_specs,
            dataset_kwargs=dataset_kwargs,
        )

    def prepare_data(self) -> None:
        # If the processed train file already exists, skip download and preprocessing.
        if not self.train_path.exists() or self.force_redownload:
```

**Contributor:** this is a bit error-prone for partial download failures. I'd just delegate to the Hugging Face `load_dataset` logic and rely on it for skipping the download.
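A minimal sketch of the suggested alternative (an assumption about the reviewer's intent, not code from this PR): `load_dataset` already caches under `cache_dir` and skips completed downloads on its own, so the explicit existence check could be dropped.

```python
# Hypothetical simplification: rely on load_dataset's own caching to skip
# re-downloads, instead of checking train_path ourselves.
def prepare_data(self) -> None:
    dset = self._download_data()  # effectively a no-op when already cached
    self._preprocess_and_split_data(dset)
    super().prepare_data()
```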
```python
            dset = self._download_data()
            self._preprocess_and_split_data(dset)
        super().prepare_data()

    def _download_data(self):
        logging.info(f"Downloading {self.__class__.__name__}...")
        return load_dataset(
            "bespokelabs/Bespoke-Stratos-17k",
            cache_dir=str(self.dataset_root),
            download_mode="force_redownload" if self.force_redownload else None,
        )

    def _preprocess_and_split_data(self, dset, train_ratio: float = 0.80, val_ratio: float = 0.15):
        logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")
        # The remaining fraction goes to the test split (0.05 with the defaults).
        test_ratio = 1 - train_ratio - val_ratio
        save_splits = {}
        dataset = dset.get("train")
        # First carve off the training split, then divide the remainder into validation and test.
        split_dataset = dataset.train_test_split(test_size=val_ratio + test_ratio, seed=self.seed)
        split_dataset2 = split_dataset["test"].train_test_split(
            test_size=test_ratio / (val_ratio + test_ratio), seed=self.seed
        )
        save_splits["training"] = split_dataset["train"]
        save_splits["validation"] = split_dataset2["train"]
        save_splits["test"] = split_dataset2["test"]

        print("len training: ", len(save_splits["training"]))
        print("len validation: ", len(save_splits["validation"]))
        print("len test: ", len(save_splits["test"]))

        for split_name, dataset in save_splits.items():
            output_file = self.dataset_root / f"{split_name}.jsonl"
            with output_file.open("w", encoding="utf-8") as f:
                for example in dataset:
                    conversations = example["conversations"]

                    # Normalize role names to the capitalized forms the chat dataset expects.
                    for conversation in conversations:
                        if conversation["from"] == "user":
                            conversation["from"] = "User"
                        elif conversation["from"] == "assistant":
                            conversation["from"] = "Assistant"
                        else:
                            raise ValueError(f"Unknown role: {conversation['from']}")

                    # Mask user turns out of the loss computation.
                    example["mask"] = "User"
                    example["type"] = "VALUE_TO_TEXT"

                    f.write(json.dumps(example) + "\n")

            logging.info(f"{split_name} split saved to {output_file}")

        if self.delete_raw:
            for p in self.dataset_root.iterdir():
                if p.is_dir():
                    shutil.rmtree(p)
                elif ".jsonl" not in str(p.name):
                    p.unlink()

    @lru_cache
    def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs):
        # pylint: disable=C0115,C0116
        return create_sft_dataset(
            path,
            tokenizer=self.tokenizer,
            seq_length=(self.seq_length if is_test or self.packed_sequence_size <= 0 else self.packed_sequence_size),
            memmap_workers=self.memmap_workers,
            seed=self.seed,
            chat=True,
            is_test=is_test,
            pack_metadata_file_path=None,  # packing is not supported
            pad_cu_seqlens=False,
            **kwargs,
        )
```
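For reference, with the default `train_ratio=0.80` and `val_ratio=0.15`, the remaining 5% becomes the test split, so a ~17k-example dataset yields roughly 13.6k training, 2.6k validation, and 0.9k test examples. A minimal usage sketch (hypothetical; `train.py` itself is not shown in this diff):

```python
# Hypothetical instantiation, mirroring the constructor defaults above.
data_module = BespokeDataModule(
    seq_length=2048,
    micro_batch_size=4,
    global_batch_size=8,
    dataset_root="./bespoke",
)
```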
@@ -0,0 +1,102 @@

```python
import argparse

import lightning.pytorch as pl
import nemo_run as run
import torch
from nemo.collections import llm

def parse_args():
    parser = argparse.ArgumentParser(description="Convert a Hugging Face checkpoint to NeMo format")

    # Model ID configuration
    parser.add_argument(
        "--model_id",
        type=str,
        default="Qwen/Qwen2.5-7B-Instruct",
        help="Source path for the model (e.g., model-name or local path)",
    )

    # Overwrite option (use --no-overwrite to disable).
    parser.add_argument(
        "--overwrite",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to overwrite an existing checkpoint",
    )

    # Executor type
    parser.add_argument(
        "--executor",
        type=str,
        default="local",
        choices=["local"],
        help="Executor type to use",
    )
```
**Contributor** (on lines +26 to +33, the `--executor` argument): we shouldn't have this since we don't support slurm, right?
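If the flag is dropped as suggested, the executor plumbing collapses to a direct construction (a sketch of the simplification, not code from this PR):

```python
# Hypothetical: no --executor flag or get_executor() helper is needed
# when only local execution is supported.
executor = run.LocalExecutor()
```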
```python
    # Output directory
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for the converted checkpoint",
    )

    return parser.parse_args()

def get_model_config(model_name):
    """Get the model configuration based on the model name."""
    model_configs = {
        "Qwen/Qwen2.5-7B-Instruct": llm.qwen2_7b.model(),
    }

    if model_name not in model_configs:
        raise ValueError(f"Unsupported model: {model_name}")

    return model_configs[model_name]


def configure_checkpoint_conversion(args):
    """Configure checkpoint conversion with command line arguments."""
    model_config = get_model_config(args.model_id)

    conversion_config = {
        "model": model_config,
        "source": f"hf://{args.model_id}",
        "overwrite": args.overwrite,
    }

    # Add output directory if specified
    if args.output_dir:
        conversion_config["output_dir"] = args.output_dir

    return run.Partial(llm.import_ckpt, **conversion_config)


def get_executor(executor_type):
    """Get an executor based on type."""
    if executor_type == "local":
        return run.LocalExecutor()
    else:
        raise ValueError(f"Unsupported executor type: {executor_type}")

def main():
    # Parse command line arguments
    args = parse_args()

    # Print CUDA information
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"Model/Source: {args.model_id}")
    print(f"Overwrite: {args.overwrite}")
    print(f"Executor: {args.executor}")

    # Configure checkpoint conversion
    import_ckpt = configure_checkpoint_conversion(args)

    # Define executor
    executor = get_executor(args.executor)

    # Run experiment
    print("Starting checkpoint conversion...")
    run.run(import_ckpt, executor=executor)
    print("Checkpoint conversion completed!")


if __name__ == "__main__":
    main()
```
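A hypothetical invocation of this script (its filename is not shown in this diff; `convert_ckpt.py` is an assumed name):

```
python convert_ckpt.py --model_id Qwen/Qwen2.5-7B-Instruct --output-dir /workspace/nemo_ckpt
```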
**Review comment:** can you add a basic linter to this repo and clean up all the newline ends and formatting in a separate PR? Check with Nico what's preferred for our repos.
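A minimal sketch of what that could look like (assuming `ruff` is an acceptable choice; the actual tool is deferred to Nico per the comment above):

```
# hypothetical linter setup for a follow-up PR
pip install ruff
ruff check . --fix   # lint and autofix
ruff format .        # normalize formatting and trailing newlines
```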