diff --git a/config_hub/finetune/llama-2-7b/longlora.yaml b/config_hub/finetune/llama-2-7b/longlora.yaml new file mode 100644 index 0000000000..13a0dd0b16 --- /dev/null +++ b/config_hub/finetune/llama-2-7b/longlora.yaml @@ -0,0 +1,141 @@ + +# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) +checkpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf + +# Directory in which to save checkpoints and logs. (type: , default: out/lora) +out_dir: out/finetune/lora-llama2-7b + +# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-true + +# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) +quantize: + +# How many devices/GPUs to use. (type: Union[int, str], default: 1) +devices: 1 + +# The LoRA rank. (type: int, default: 8) +lora_r: 8 + +# The LoRA alpha. (type: int, default: 16) +lora_alpha: 16 + +# The LoRA dropout value. (type: float, default: 0.05) +lora_dropout: 0.0 + +# Whether to apply LoRA to the query weights in attention. (type: bool, default: True) +lora_query: true + +# Whether to apply LoRA to the key weights in attention. (type: bool, default: False) +lora_key: true + +# Whether to apply LoRA to the value weights in attention. (type: bool, default: True) +lora_value: true + +# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) +lora_projection: true + +# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) +lora_mlp: false + +# Whether to apply LoRA to output head in GPT. (type: bool, default: False) +lora_head: false + +# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. +data: + class_path: litgpt.data.Alpaca2k + init_args: + mask_prompt: false + prompt_style: alpaca + ignore_index: -100 + seed: 42 + num_workers: 4 + +# Training-related arguments. See ``litgpt.args.TrainArgs`` for details +train: + + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 200 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 1 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) + global_batch_size: 8 + + # Number of samples per data-parallel rank (type: int, default: 4) + micro_batch_size: 2 + + # Number of iterations with learning rate warmup active (type: int, default: 100) + lr_warmup_steps: 10 + + # Number of epochs to train on (type: Optional[int], default: 5) + epochs: 4 + + # Total number of tokens to train on (type: Optional[int], default: null) + max_tokens: + + # Limits the number of optimizer steps to run. (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples. Off by default (type: Optional[int], default: null) + max_seq_length: 512 + + # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) + tie_embeddings: + + # (type: Optional[float], default: null) + max_norm: + + # (type: float, default: 6e-05) + min_lr: 6.0e-05 + +# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details +eval: + + # Number of optimizer steps between evaluation calls (type: int, default: 100) + interval: 100 + + # Number of tokens to generate (type: Optional[int], default: 100) + max_new_tokens: 100 + + # Number of iterations (type: int, default: 100) + max_iters: 100 + +# LongLoRA-related arguments. See ``litgpt.args.LongLoRAArgs`` for details +longlora: + # Whether to use LongLoRA. (type: bool, default: false) + use_longlora: true + + # The enlarged context length for LongLoRA. (type: int, default: 8192) + context_length: 8192 + + # The number of groups to split the sequence into. (type: int, default: 4) + n_groups: 4 + + # The additional trainable parameters for LongLoRA. (type: str, default: "wte,norm,ln") + trainable_params: "wte,norm,ln" + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) +logger_name: csv + +# The random seed to use for reproducibility. (type: int, default: 1337) +seed: 1337 + +# Optimizer-related arguments +optimizer: + + class_path: torch.optim.AdamW + + init_args: + + # (type: float, default: 0.001) + lr: 0.0002 + + # (type: float, default: 0.01) + weight_decay: 0.0 + + # (type: tuple, default: (0.9,0.999)) + betas: + - 0.9 + - 0.95 \ No newline at end of file diff --git a/config_hub/finetune/mistral-7b/longlora.yaml b/config_hub/finetune/mistral-7b/longlora.yaml new file mode 100644 index 0000000000..1369e1c459 --- /dev/null +++ b/config_hub/finetune/mistral-7b/longlora.yaml @@ -0,0 +1,141 @@ + +# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) +checkpoint_dir: checkpoints/mistralai/Mistral-7B-v0.1 + +# Directory in which to save checkpoints and logs. (type: , default: out/lora) +out_dir: out/finetune/lora-mistral-7b + +# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-true + +# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) +quantize: + +# How many devices/GPUs to use. (type: Union[int, str], default: 1) +devices: 1 + +# The LoRA rank. (type: int, default: 8) +lora_r: 8 + +# The LoRA alpha. (type: int, default: 16) +lora_alpha: 16 + +# The LoRA dropout value. (type: float, default: 0.05) +lora_dropout: 0.0 + +# Whether to apply LoRA to the query weights in attention. (type: bool, default: True) +lora_query: true + +# Whether to apply LoRA to the key weights in attention. (type: bool, default: False) +lora_key: true + +# Whether to apply LoRA to the value weights in attention. (type: bool, default: True) +lora_value: true + +# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) +lora_projection: true + +# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) +lora_mlp: false + +# Whether to apply LoRA to output head in GPT. (type: bool, default: False) +lora_head: false + +# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. +data: + class_path: litgpt.data.Alpaca2k + init_args: + mask_prompt: false + prompt_style: alpaca + ignore_index: -100 + seed: 42 + num_workers: 4 + +# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details +train: + + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 200 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 1 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) + global_batch_size: 8 + + # Number of samples per data-parallel rank (type: int, default: 4) + micro_batch_size: 2 + + # Number of iterations with learning rate warmup active (type: int, default: 100) + lr_warmup_steps: 10 + + # Number of epochs to train on (type: Optional[int], default: 5) + epochs: 4 + + # Total number of tokens to train on (type: Optional[int], default: null) + max_tokens: + + # Limits the number of optimizer steps to run. (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples. Off by default (type: Optional[int], default: null) + max_seq_length: 512 + + # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) + tie_embeddings: + + # (type: Optional[float], default: null) + max_norm: + + # (type: float, default: 6e-05) + min_lr: 6.0e-05 + +# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details +eval: + + # Number of optimizer steps between evaluation calls (type: int, default: 100) + interval: 100 + + # Number of tokens to generate (type: Optional[int], default: 100) + max_new_tokens: 100 + + # Number of iterations (type: int, default: 100) + max_iters: 100 + +# LongLoRA-related arguments. See ``litgpt.args.LongLoRAArgs`` for details +longlora: + # Whether to use LongLoRA. (type: bool, default: false) + use_longlora: true + + # The enlarged context length for LongLoRA. (type: int, default: 8192) + context_length: 8192 + + # The number of groups to split the sequence into. (type: int, default: 4) + n_groups: 4 + + # The additional trainable parameters for LongLoRA. (type: str, default: "wte,norm,ln") + trainable_params: "wte,norm,ln" + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) +logger_name: csv + +# The random seed to use for reproducibility. (type: int, default: 1337) +seed: 1337 + +# Optimizer-related arguments +optimizer: + + class_path: torch.optim.AdamW + + init_args: + + # (type: float, default: 0.001) + lr: 0.0002 + + # (type: float, default: 0.01) + weight_decay: 0.0 + + # (type: tuple, default: (0.9,0.999)) + betas: + - 0.9 + - 0.95 \ No newline at end of file diff --git a/litgpt/args.py b/litgpt/args.py index e3bac05ef2..13beb5fca2 100644 --- a/litgpt/args.py +++ b/litgpt/args.py @@ -36,14 +36,6 @@ class TrainArgs: max_norm: Optional[float] = None min_lr: float = 6e-5 - def __post_init__(self) -> None: - if self.lr_warmup_fraction and self.lr_warmup_steps: - raise ValueError( - "Can't provide both `--train.lr_warmup_fraction` and `--train.lr_warmup_steps`. Choose one." 
- ) - if self.lr_warmup_fraction and not (0 <= self.lr_warmup_fraction <= 1): - raise ValueError("`--train.lr_warmup_fraction` must be between 0 and 1.") - def gradient_accumulation_iters(self, devices: int) -> int: """Number of iterations between gradient synchronizations""" gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size @@ -77,3 +69,17 @@ class EvalArgs: """Number of iterations""" initial_validation: bool = False """Whether to evaluate on the validation set at the beginning of the training""" + + +@dataclass +class LongLoraArgs: + """LongLora-related arguments""" + + use_longlora: bool = False + """Whether to enable LongLora.""" + n_groups: int = 4 + """Number of groups to divide the sequence length into.""" + context_length: int = 8192 + """Length of the enlarged context window.""" + trainable_params: str = "wte,norm,ln" + """List of comma-separated parameters to train in LongLora.""" diff --git a/litgpt/config.py b/litgpt/config.py index 0f35d43eec..dead50a5b9 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -61,6 +61,8 @@ class Config: rope_base: int = 10000 n_expert: int = 0 n_expert_per_token: int = 0 + use_longlora: bool = False + longlora_n_groups: int = 4 def __post_init__(self): if not self.name: diff --git a/litgpt/data/alpaca.py b/litgpt/data/alpaca.py index fc3d973848..05e3a62422 100644 --- a/litgpt/data/alpaca.py +++ b/litgpt/data/alpaca.py @@ -43,6 +43,7 @@ class Alpaca(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -51,11 +52,16 @@ def __post_init__(self) -> None: self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: self.download_dir.mkdir(parents=True, exist_ok=True) @@ -97,7 +103,9 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) def val_dataloader(self) -> DataLoader: @@ -106,7 +114,9 @@ def val_dataloader(self) -> DataLoader: batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) diff --git a/litgpt/data/base.py b/litgpt/data/base.py index 36ef33fb8a..b4d34aeab9 100644 --- a/litgpt/data/base.py +++ b/litgpt/data/base.py @@ -10,6 +10,7 @@ from litgpt 
import Tokenizer from litgpt.prompts import PromptStyle +from litgpt.utils import find_multiple class DataModule(LightningDataModule): @@ -17,7 +18,7 @@ class DataModule(LightningDataModule): @abstractmethod def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None, **kwargs ) -> None: """All settings that can't be determined at the time of instantiation need to be passed through here before any dataloaders can be accessed. @@ -44,6 +45,7 @@ class SFTDataset(Dataset): ignore_index: The index to use for elements to be ignored in the label. transform: An optional transform to apply to the sample before it gets tokenized. Use this to rename the keys in the dataset to the expected 'instruction' and 'output' keys. + pad_multiple_of: If set, sequences will be padded to a multiple of 'pad_multiple_of'. Returns a dict with two keys: input_ids: The encoded prompt + response @@ -93,18 +95,30 @@ def __getitem__(self, idx: int) -> Dict[str, Tensor]: return {"input_ids": encoded_prompt_and_response.type(torch.int64), "labels": labels.type(torch.int64)} -def get_sft_collate_fn(max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100): +def get_sft_collate_fn( + max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100, pad_multiple_of: Optional[int] = None +): """Returns the collate function for supervised finetuning (needed in the DataLoader). The collate function gets a list of dicts with keys `input_ids` and `labels`. It returns a dict with batched `input_ids` and `labels`. Also pads short sequences to the longest element in the batch. Optionally truncates all sequences to the specified maximum length. 
""" - return partial(_sft_collate_fn, max_seq_length=max_seq_length, pad_id=pad_id, ignore_index=ignore_index) + return partial( + _sft_collate_fn, + max_seq_length=max_seq_length, + pad_id=pad_id, + ignore_index=ignore_index, + pad_multiple_of=pad_multiple_of, + ) def _sft_collate_fn( - samples: List[Dict[str, Tensor]], max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100 + samples: List[Dict[str, Tensor]], + max_seq_length: int = -1, + pad_id: int = 0, + ignore_index: int = -100, + pad_multiple_of: Optional[int] = None, ) -> Dict[str, Tensor]: batched = {} @@ -116,6 +130,19 @@ def _sft_collate_fn( [sample[key] for sample in samples], batch_first=True, padding_value=pad_value ) + # Pad to multiple of 'pad_multiple_of' + if pad_multiple_of is not None and pad_multiple_of > 1: + pad_to = find_multiple(batched[key].shape[1], pad_multiple_of) + pad_to_add = pad_to - batched[key].shape[1] + if pad_to_add > 0: + batched[key] = torch.cat( + ( + batched[key], + torch.full((batched[key].shape[0], pad_to_add, *batched[key].shape[2:]), fill_value=pad_value), + ), + dim=1, + ) + # Truncate if needed if max_seq_length > 0: batched[key] = batched[key][:, :max_seq_length] diff --git a/litgpt/data/deita.py b/litgpt/data/deita.py index c0e52d24f0..b310e3de54 100644 --- a/litgpt/data/deita.py +++ b/litgpt/data/deita.py @@ -36,6 +36,7 @@ class Deita(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -44,11 +45,16 @@ def __post_init__(self) -> None: self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: from datasets import load_dataset @@ -86,7 +92,9 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) def val_dataloader(self) -> DataLoader: @@ -95,7 +103,9 @@ def val_dataloader(self) -> DataLoader: batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) diff --git a/litgpt/data/dolly.py b/litgpt/data/dolly.py index 03d973f9b2..891507e86d 100644 --- a/litgpt/data/dolly.py +++ b/litgpt/data/dolly.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass, field from pathlib import Path -from typing import Union +from 
typing import Optional, Union import torch from torch.utils.data import random_split diff --git a/litgpt/data/flan.py b/litgpt/data/flan.py index a2a5b443ac..dd2ee492e5 100644 --- a/litgpt/data/flan.py +++ b/litgpt/data/flan.py @@ -42,6 +42,7 @@ class FLAN(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -59,11 +60,16 @@ def __post_init__(self): self.subsets = list(supported_subsets) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: self.download_dir.mkdir(parents=True, exist_ok=True) @@ -100,7 +106,9 @@ def _dataloader(self, split: str) -> DataLoader: shuffle=(split == "train"), generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) diff --git a/litgpt/data/json_data.py b/litgpt/data/json_data.py index a40096486d..7d2878a599 100644 --- a/litgpt/data/json_data.py +++ b/litgpt/data/json_data.py @@ -38,6 +38,7 @@ class JSON(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) val_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -61,11 +62,16 @@ def __post_init__(self): self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def setup(self, stage: str = "") -> None: train_data, test_data = self.get_splits() @@ -94,7 +100,9 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) def val_dataloader(self) -> DataLoader: @@ -103,7 +111,9 @@ def val_dataloader(self) -> DataLoader: 
batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) def get_splits(self) -> Tuple: diff --git a/litgpt/data/lima.py b/litgpt/data/lima.py index 8ea3db5ebd..75a2b8d0a2 100644 --- a/litgpt/data/lima.py +++ b/litgpt/data/lima.py @@ -39,6 +39,7 @@ class LIMA(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -53,11 +54,16 @@ def __post_init__(self): self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: from datasets import load_dataset @@ -102,7 +108,9 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) def val_dataloader(self) -> DataLoader: @@ -111,7 +119,9 @@ def val_dataloader(self) -> DataLoader: batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) diff --git a/litgpt/data/lit_data.py b/litgpt/data/lit_data.py index 8347215fbd..5bacd8705f 100644 --- a/litgpt/data/lit_data.py +++ b/litgpt/data/lit_data.py @@ -29,16 +29,22 @@ class LitData(DataModule): batch_size: int = field(init=False, repr=False, default=1) seq_length: int = field(init=False, repr=False, default=2048) + pad_multiple_of: Optional[int] = field(init=False, repr=False, default=None) def __post_init__(self) -> None: if self.split_names is not None and len(self.split_names) != 2: raise ValueError("If provided `split_names` must be a tuple of two strings, for example: ('train', 'val').") def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.batch_size = batch_size self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well + self.pad_multiple_of = pad_multiple_of def train_dataloader(self) -> DataLoader: input_dir = 
os.path.join(self.data_path, self.split_names[0]) if self.split_names else str(self.data_path) diff --git a/litgpt/data/longform.py b/litgpt/data/longform.py index 34fcd29906..fb02ca06fe 100644 --- a/litgpt/data/longform.py +++ b/litgpt/data/longform.py @@ -36,6 +36,7 @@ class LongForm(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) @@ -44,11 +45,16 @@ def __post_init__(self) -> None: self.prompt_style = PromptStyle.from_name(self.prompt_style) def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = -1 if max_seq_length is None else max_seq_length + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: self.download_dir.mkdir(parents=True, exist_ok=True) @@ -80,7 +86,9 @@ def _dataloader(self, split: str) -> DataLoader: shuffle=(split == "train"), generator=torch.Generator().manual_seed(self.seed), num_workers=self.num_workers, - collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index), + collate_fn=get_sft_collate_fn( + max_seq_length=self.max_seq_length, ignore_index=self.ignore_index, pad_multiple_of=self.pad_multiple_of + ), ) diff --git a/litgpt/data/openwebtext.py b/litgpt/data/openwebtext.py index c6cc3151b3..6e5e3f1482 100644 --- a/litgpt/data/openwebtext.py +++ b/litgpt/data/openwebtext.py @@ -28,6 +28,7 @@ class OpenWebText(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, repr=False, init=False) batch_size: int = field(default=1, repr=False, init=False) seq_length: int = field(default=2048, repr=False, init=False) + pad_multiple_of: Optional[int] = field(default=None, repr=False, init=False) def __post_init__(self) -> None: # Could be a remote path (s3://) or a local path @@ -35,11 +36,16 @@ def __post_init__(self) -> None: self.data_path_val = str(self.data_path).rstrip("/") + "/val" def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = 2048 + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = 2048, + pad_multiple_of: Optional[int] = None, ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: from datasets import Dataset, load_dataset diff --git a/litgpt/data/text_files.py b/litgpt/data/text_files.py index 5989937669..8a954f109f 100644 --- a/litgpt/data/text_files.py +++ b/litgpt/data/text_files.py @@ -21,6 +21,7 @@ class TextFiles(DataModule): and provides training and validation dataloaders that return batches of tokens. Every sample is set to a fixed length. 
""" + train_data_path: Path """The path to the data directory used for training that contains .txt files""" val_data_path: Optional[Path] = None @@ -35,6 +36,7 @@ class TextFiles(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) def __post_init__(self) -> None: self.out_path_train = self.train_data_path / "train" @@ -43,10 +45,17 @@ def __post_init__(self) -> None: else: self.out_path_val = Path(self.val_data_path) / "val" - def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: + def connect( + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: int = -1, + pad_multiple_of: Optional[int] = None, + ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: from litdata import optimize diff --git a/litgpt/data/tinyllama.py b/litgpt/data/tinyllama.py index d0267f2ff7..850e1e7914 100644 --- a/litgpt/data/tinyllama.py +++ b/litgpt/data/tinyllama.py @@ -27,6 +27,7 @@ class TinyLlama(DataModule): batch_size: int = field(init=False, repr=False, default=1) seq_length: int = field(init=False, repr=False, default=2048) + pad_multiple_of: Optional[int] = field(init=False, repr=False, default=None) def __post_init__(self): # Could be a remote path (s3://) or a local path @@ -35,10 +36,15 @@ def __post_init__(self): self.starcoder_train = str(self.data_path).rstrip("/") + "/starcoder" def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: Optional[int] = None, + pad_multiple_of: Optional[int] = None, ) -> None: self.batch_size = batch_size self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: for path in (self.slimpajama_train, self.slimpajama_val, self.starcoder_train): diff --git a/litgpt/data/tinystories.py b/litgpt/data/tinystories.py index 632a015e44..bd13e23395 100644 --- a/litgpt/data/tinystories.py +++ b/litgpt/data/tinystories.py @@ -34,15 +34,23 @@ class TinyStories(DataModule): tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) + pad_multiple_of: Optional[int] = field(default=None, init=False, repr=False) def __post_init__(self) -> None: self.data_path_train = self.data_path / "train" self.data_path_val = self.data_path / "val" - def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: + def connect( + self, + tokenizer: Optional[Tokenizer] = None, + batch_size: int = 1, + max_seq_length: int = -1, + pad_multiple_of: Optional[int] = None, + ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well + self.pad_multiple_of = pad_multiple_of def prepare_data(self) -> None: from litdata import optimize diff --git 
a/litgpt/finetune/full.py b/litgpt/finetune/full.py index cf32ae501d..53487ede14 100644 --- a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -6,14 +6,16 @@ from pathlib import Path from pprint import pprint from typing import Dict, List, Literal, Optional, Tuple, Union +import warnings import lightning as L import torch from lightning.fabric.strategies import FSDPStrategy from torch.utils.data import DataLoader from torchmetrics import RunningMean +import yaml -from litgpt.args import EvalArgs, TrainArgs +from litgpt.args import EvalArgs, TrainArgs, LongLoraArgs from litgpt.data import Alpaca, DataModule from litgpt.generate.base import generate from litgpt.model import GPT, Block, Config @@ -25,6 +27,7 @@ choose_logger, chunked_cross_entropy, copy_config_files, + find_multiple, get_default_supported_precision, load_checkpoint, init_out_dir, @@ -52,6 +55,7 @@ def setup( max_seq_length=None, ), eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100), + longlora: LongLoraArgs = LongLoraArgs(use_longlora=False, n_groups=4, context_length=8192, trainable_params=""), optimizer: Union[str, Dict] = "AdamW", logger_name: Literal["wandb", "tensorboard", "csv"] = "csv", seed: int = 1337, @@ -69,6 +73,7 @@ def setup( data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. + longlora: LongLoRA-related arguments. See ``litgpt.args.LongLoraArgs`` for details. optimizer: An optimizer name (such as "AdamW") or config. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. @@ -81,7 +86,9 @@ def setup( out_dir = init_out_dir(out_dir) check_valid_checkpoint_dir(checkpoint_dir) - config = Config.from_file(checkpoint_dir / "model_config.yaml") + config = Config.from_file( + checkpoint_dir / "model_config.yaml", use_longlora=longlora.use_longlora, longlora_n_groups=longlora.n_groups + ) precision = precision or get_default_supported_precision(training=True) logger = choose_logger( @@ -100,7 +107,7 @@ def setup( strategy = "auto" fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger) - fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer) + fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, longlora) def main( @@ -115,11 +122,27 @@ def main( train: TrainArgs, eval: EvalArgs, optimizer: Union[str, Dict], + longlora: LongLoraArgs, ) -> None: validate_args(train, eval) + if resume is True: + resume = max(out_dir.rglob("step-*/*.pth"), key=(lambda p: int(p.parent.name.split("-")[1]))) + if resume: + with open(resume.parent / "hyperparameters.yaml", "r") as f: + hyperparams = yaml.safe_load(f) + longlora_cfg = hyperparams.get("longlora", None) + if longlora_cfg is not None: + longlora.use_longlora = longlora_cfg.get("use_longlora", False) + longlora.n_groups = longlora_cfg.get("n_groups", longlora.n_groups) + longlora.context_length = longlora_cfg.get("context_length", longlora.context_length) + config.use_longlora = longlora.use_longlora + config.longlora_n_groups = longlora.n_groups + validate_longlora_args(config, longlora) tokenizer = Tokenizer(checkpoint_dir) - train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) + train_dataloader, val_dataloader = get_dataloaders( + fabric, data, 
tokenizer, train, pad_multiple_of=longlora.n_groups if longlora.use_longlora else None + ) steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices) lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) @@ -130,6 +153,15 @@ def main( checkpoint_path = checkpoint_dir / "lit_model.pth" with fabric.init_module(empty_init=(devices > 1)): + if longlora.use_longlora and longlora.context_length > config.block_size: + old_block_size = config.block_size + config.block_size = longlora.context_length + old_rope_condense_ratio = config.rope_condense_ratio + config.rope_condense_ratio = longlora.context_length / old_block_size + fabric.print(f"The model context length has been increased from {old_block_size} to {config.block_size}") + fabric.print( + f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" + ) model = GPT(config) fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") @@ -141,8 +173,6 @@ def main( scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps) state = {"model": model, "optimizer": optimizer, "scheduler": scheduler, "iter_num": 0, "step_count": 0} - if resume is True: - resume = max(out_dir.rglob("step-*/*.pth"), key=(lambda p: int(p.parent.name.split("-")[1]))) if resume: fabric.print(f"Resuming training from {resume}") fabric.load(resume, state) @@ -150,7 +180,7 @@ def main( load_checkpoint(fabric, state["model"], checkpoint_path) train_time = time.perf_counter() - fit(fabric, state, train_dataloader, val_dataloader, devices, resume, checkpoint_dir, out_dir, train, eval, data) + fit(fabric, state, train_dataloader, val_dataloader, devices, resume, checkpoint_dir, out_dir, train, eval, longlora, data) fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") if fabric.device.type == "cuda": fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") @@ -183,6 +213,7 @@ def fit( out_dir: Path, train: TrainArgs, eval: EvalArgs, + longlora: LongLoraArgs, data: DataModule, ) -> None: model = state["model"] @@ -190,7 +221,10 @@ def fit( scheduler = state["scheduler"] tokenizer = Tokenizer(checkpoint_dir) longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) - model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) + longest_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) + if longlora.use_longlora: + longest_seq_length = find_multiple(longest_seq_length, longlora.n_groups) + model.max_seq_length = longest_seq_length fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" @@ -338,9 +372,9 @@ def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): def get_dataloaders( - fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs + fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs, pad_multiple_of: Optional[int] = None ) -> Tuple[DataLoader, DataLoader]: - data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) + data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length, pad_multiple_of=pad_multiple_of) with fabric.rank_zero_first(): data.prepare_data() data.setup() @@ -374,3 +408,18 @@ def validate_args(train: TrainArgs, eval: EvalArgs) -> None: 
issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}") if issues: raise ValueError("\n".join(issues)) + + +def validate_longlora_args(config: Config, longlora: LongLoraArgs): + if longlora.use_longlora: + if longlora.context_length <= config.block_size: + warnings.warn( + f"LongLora is disabled because the LongLora context length ({longlora.context_length}) " + f"is less than the model original block size {config.block_size}. " + ) + longlora.use_longlora = False + elif longlora.context_length % longlora.n_groups != 0: + raise ValueError( + f"LongLora context length ({longlora.context_length}) must be a multiple of the number of groups " + f"({longlora.n_groups})." + ) diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index 5f5e12dcf9..0e49c40d31 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -1,11 +1,13 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import dataclasses +from functools import partial import math import os import time from pathlib import Path from pprint import pprint from typing import Dict, List, Literal, Optional, Tuple, Union +import warnings import lightning as L import torch @@ -15,10 +17,17 @@ from torch.utils.data import DataLoader from torchmetrics import RunningMean -from litgpt.args import EvalArgs, TrainArgs +from litgpt.args import EvalArgs, LongLoraArgs, TrainArgs from litgpt.data import Alpaca, DataModule from litgpt.generate.base import generate -from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable +from litgpt.lora import ( + GPT, + Block, + Config, + longlora_filter, + lora_filter, + mark_only_lora_as_trainable, +) from litgpt.prompts import save_prompt_style from litgpt.scripts.merge_lora import merge_lora from litgpt.tokenizer import Tokenizer @@ -28,6 +37,7 @@ choose_logger, chunked_cross_entropy, copy_config_files, + find_multiple, get_default_supported_precision, load_checkpoint, init_out_dir, @@ -65,6 +75,9 @@ def setup( max_seq_length=None, ), eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100), + longlora: LongLoraArgs = LongLoraArgs( + use_longlora=False, n_groups=4, context_length=8192, trainable_params="wte,norm,ln" + ), optimizer: Union[str, Dict] = "AdamW", logger_name: Literal["wandb", "tensorboard", "csv"] = "csv", seed: int = 1337, @@ -90,6 +103,7 @@ def setup( data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details. eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details. + longlora: LongLoRA-related arguments. See ``litgpt.args.LongLoraArgs`` for details. optimizer: An optimizer name (such as "AdamW") or config. logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. 
@@ -112,6 +126,8 @@ def setup( lora_projection=lora_projection, lora_mlp=lora_mlp, lora_head=lora_head, + use_longlora=longlora.use_longlora, + longlora_n_groups=longlora.n_groups, ) precision = precision or get_default_supported_precision(training=True) @@ -142,7 +158,7 @@ def setup( strategy = "auto" fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins) - fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer) + fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, longlora) def main( @@ -156,11 +172,15 @@ def main( train: TrainArgs, eval: EvalArgs, optimizer: Union[str, Dict], + longlora: LongLoraArgs, ) -> None: validate_args(train, eval) + validate_longlora_args(config, longlora) tokenizer = Tokenizer(checkpoint_dir) - train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) + train_dataloader, val_dataloader = get_dataloaders( + fabric, data, tokenizer, train, pad_multiple_of=longlora.n_groups if longlora.use_longlora else None + ) steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices) lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float("inf"))) @@ -171,9 +191,26 @@ def main( checkpoint_path = checkpoint_dir / "lit_model.pth" with fabric.init_module(empty_init=(devices > 1)): + if longlora.use_longlora and longlora.context_length > config.block_size: + old_block_size = config.block_size + config.block_size = longlora.context_length + old_rope_condense_ratio = config.rope_condense_ratio + config.rope_condense_ratio = longlora.context_length / old_block_size + fabric.print(f"The model context length has been increased from {old_block_size} to {config.block_size}") + fabric.print( + f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" + ) + model = GPT(config) mark_only_lora_as_trainable(model) + # Let other layers be trainable + if longlora.use_longlora and longlora.trainable_params != "": + trainable_params = set(longlora.trainable_params.strip().split(",")) + for n, p in model.named_parameters(): + if any(trainable_p_name in n for trainable_p_name in trainable_params): + p.requires_grad = True + fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}") fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}") @@ -203,6 +240,7 @@ def main( out_dir, train, eval, + longlora, data, ) fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") @@ -218,7 +256,7 @@ def main( # Save the final LoRA checkpoint at the end of training save_path = out_dir / "final" / "lit_model.pth.lora" save_path.parent.mkdir(parents=True, exist_ok=True) - save_lora_checkpoint(fabric, model, save_path) + save_lora_checkpoint(fabric, model, save_path, longlora=longlora) if fabric.global_rank == 0: # Copy checkpoint files from original checkpoint dir copy_config_files(checkpoint_dir, save_path.parent) @@ -239,11 +277,15 @@ def fit( out_dir: Path, train: TrainArgs, eval: EvalArgs, + longlora: LongLoraArgs, data: DataModule, ) -> None: tokenizer = Tokenizer(checkpoint_dir) longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset) - model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) + longest_seq_length = min(longest_seq_length, train.max_seq_length or float("inf")) + if longlora.use_longlora: 
+ longest_seq_length = find_multiple(longest_seq_length, longlora.n_groups) + model.max_seq_length = longest_seq_length fabric.print( f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" f" {model.max_seq_length} and context length is {model.config.block_size}" @@ -331,7 +373,7 @@ def fit( if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0: checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.lora" checkpoint_file.parent.mkdir(parents=True, exist_ok=True) - save_lora_checkpoint(fabric, model, checkpoint_file) + save_lora_checkpoint(fabric, model, checkpoint_file, longlora=longlora) if fabric.global_rank == 0: copy_config_files(checkpoint_dir, checkpoint_file.parent) save_hyperparameters(setup, checkpoint_file.parent) @@ -385,9 +427,14 @@ def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): def get_dataloaders( - fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs + fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs, pad_multiple_of: Optional[int] = None ) -> Tuple[DataLoader, DataLoader]: - data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length) + data.connect( + tokenizer=tokenizer, + batch_size=train.micro_batch_size, + max_seq_length=train.max_seq_length, + pad_multiple_of=pad_multiple_of + ) with fabric.rank_zero_first(): data.prepare_data() data.setup() @@ -405,9 +452,22 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: return longest_seq_length, longest_seq_ix -def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None: +def save_lora_checkpoint(fabric: L.Fabric, model: GPT, file_path: Path, longlora: LongLoraArgs) -> None: fabric.print(f"Saving LoRA weights to {str(file_path)!r}") - fabric.save(file_path, {"model": model}, filter={"model": lora_filter}) + fabric.save( + file_path, + {"model": model}, + filter={ + "model": ( + lora_filter + if not longlora.use_longlora + else partial( + longlora_filter, + additional_weights=longlora.trainable_params.strip().split(","), + ) + ) + }, + ) def validate_args(train: TrainArgs, eval: EvalArgs) -> None: @@ -426,3 +486,18 @@ def validate_args(train: TrainArgs, eval: EvalArgs) -> None: issues.append(f"{__file__} requires either epochs or max_steps to be set. This is set in {train}") if issues: raise ValueError("\n".join(issues)) + + +def validate_longlora_args(config: Config, longlora: LongLoraArgs): + if longlora.use_longlora: + if longlora.context_length <= config.block_size: + warnings.warn( + f"LongLora is disabled because the LongLora context length ({longlora.context_length}) " + f"is less than the model original block size {config.block_size}. " + ) + longlora.use_longlora = False + elif longlora.context_length % longlora.n_groups != 0: + raise ValueError( + f"LongLora context length ({longlora.context_length}) must be a multiple of the number of groups " + f"({longlora.n_groups})." 
+ ) \ No newline at end of file diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index 50d6397de6..7415ba4ab0 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -9,6 +9,7 @@ import torch import torch._dynamo.config import torch._inductor.config +import yaml from lightning.fabric.plugins import BitsandbytesPrecision from litgpt import GPT, Config, PromptStyle, Tokenizer @@ -190,6 +191,16 @@ def main( check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") + if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): + with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: + hparams = yaml.safe_load(hparams_file) + longlora_cfg = hparams.get("longlora", None) + use_longlora = False + if longlora_cfg is not None: + use_longlora = longlora_cfg.get("use_longlora", False) + longlora_context_length = longlora_cfg.get("context_length", config.block_size) + else: + use_longlora = False checkpoint_path = checkpoint_dir / "lit_model.pth" @@ -206,6 +217,17 @@ def main( fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): + if use_longlora and longlora_context_length > config.block_size: + old_block_size = config.block_size + config.block_size = longlora_context_length + old_rope_condense_ratio = config.rope_condense_ratio + config.rope_condense_ratio = longlora_context_length / old_block_size + fabric.print( + f"The model context length has been increased from {old_block_size} to {config.block_size}" + ) + fabric.print( + f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" + ) model = GPT(config) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) with fabric.init_tensor(): diff --git a/litgpt/generate/full.py b/litgpt/generate/full.py index 3ac060a3b4..2f95e04f7d 100644 --- a/litgpt/generate/full.py +++ b/litgpt/generate/full.py @@ -8,6 +8,7 @@ import lightning as L import torch from lightning.fabric.plugins import BitsandbytesPrecision +import yaml from litgpt import GPT, Config, PromptStyle, Tokenizer from litgpt.generate.base import generate @@ -75,6 +76,16 @@ def main( check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") + if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): + with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: + hparams = yaml.safe_load(hparams_file) + longlora_cfg = hparams.get("longlora", None) + use_longlora = False + if longlora_cfg is not None: + use_longlora = longlora_cfg.get("use_longlora", False) + longlora_context_length = longlora_cfg.get("context_length", config.block_size) + else: + use_longlora = False checkpoint_path = finetuned_path @@ -91,6 +102,17 @@ def main( fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr) t0 = time.perf_counter() with fabric.init_module(empty_init=True): + if use_longlora and longlora_context_length > config.block_size: + old_block_size = config.block_size + config.block_size = longlora_context_length + old_rope_condense_ratio = config.rope_condense_ratio + config.rope_condense_ratio = longlora_context_length / old_block_size + fabric.print( + f"The model context length has been increased from {old_block_size} to {config.block_size}" + ) + fabric.print( + f"The 
'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" + ) model = GPT(config) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) with fabric.init_tensor(): diff --git a/litgpt/generate/sequentially.py b/litgpt/generate/sequentially.py index d3d5250c30..ec65395336 100644 --- a/litgpt/generate/sequentially.py +++ b/litgpt/generate/sequentially.py @@ -16,6 +16,7 @@ from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.utilities.init import _materialize_meta_tensors from typing_extensions import Type +import yaml import litgpt.generate.base as generate_base from litgpt import GPT, Config, Tokenizer @@ -174,6 +175,16 @@ def main( check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") + if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): + with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: + hparams = yaml.safe_load(hparams_file) + longlora_cfg = hparams.get("longlora", None) + use_longlora = False + if longlora_cfg is not None: + use_longlora = longlora_cfg.get("use_longlora", False) + longlora_context_length = longlora_cfg.get("context_length", config.block_size) + else: + use_longlora = False checkpoint_path = checkpoint_dir / "lit_model.pth" @@ -188,6 +199,17 @@ def main( # which means that the weights will get quantized on cuda:0 on checkpoint load. we need to load and then convert # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): + if use_longlora and longlora_context_length > config.block_size: + old_block_size = config.block_size + config.block_size = longlora_context_length + old_rope_condense_ratio = config.rope_condense_ratio + config.rope_condense_ratio = longlora_context_length / old_block_size + fabric.print( + f"The model context length has been increased from {old_block_size} to {config.block_size}" + ) + fabric.print( + f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" + ) model = GPT(config) print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) diff --git a/litgpt/generate/tp.py b/litgpt/generate/tp.py index d8439a220e..0d6d166f48 100644 --- a/litgpt/generate/tp.py +++ b/litgpt/generate/tp.py @@ -14,6 +14,7 @@ from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.utilities import rank_zero_only from torch.distributed._functional_collectives import all_reduce +import yaml import litgpt.generate.base as generate_base from litgpt import GPT, Config, Tokenizer @@ -153,6 +154,16 @@ def main( check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") + if (hyperparams_dir := (checkpoint_dir / "hyperparameters.yaml")).is_file(): + with open(hyperparams_dir, "r", encoding="utf-8") as hparams_file: + hparams = yaml.safe_load(hparams_file) + longlora_cfg = hparams.get("longlora", None) + use_longlora = False + if longlora_cfg is not None: + use_longlora = longlora_cfg.get("use_longlora", False) + longlora_context_length = longlora_cfg.get("context_length", config.block_size) + else: + use_longlora = False model_file = "lit_model.pth" checkpoint_path = checkpoint_dir / model_file @@ -168,6 +179,17 @@ def main( # which means that the weights will get quantized on cuda:0 on checkpoint load. 
we need to load and then convert # still, use init_tensor for the precision with fabric.init_tensor(), torch.device("meta"): + if use_longlora and longlora_context_length > config.block_size: + old_block_size = config.block_size + config.block_size = longlora_context_length + old_rope_condense_ratio = config.rope_condense_ratio + config.rope_condense_ratio = longlora_context_length / old_block_size + fabric.print( + f"The model context length has been increased from {old_block_size} to {config.block_size}" + ) + fabric.print( + f"The 'rope_condense_ratio' has been adapted from {old_rope_condense_ratio} to {config.rope_condense_ratio}" + ) model = GPT(config) fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) diff --git a/litgpt/lora.py b/litgpt/lora.py index 7c4ae423e0..7ed7ec2c69 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -45,7 +45,7 @@ import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union import torch import torch.nn as nn @@ -467,6 +467,10 @@ def lora_filter(key: str, value: Any) -> bool: return "lora_" in key +def longlora_filter(key: str, value: Any, additional_weights: Sequence[str] = ["lora_"]) -> bool: + return any(x in key for x in additional_weights + ["lora_"]) + + @dataclass class Config(BaseConfig): """ diff --git a/litgpt/model.py b/litgpt/model.py index fe71c60b80..9e0c131101 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -212,6 +212,15 @@ def forward( ) -> torch.Tensor: B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + if input_pos is None and self.config.use_longlora: + if T % self.config.longlora_n_groups != 0: + raise ValueError( + f"sequence length {T} should be divisible by the number of groups {self.config.longlora_n_groups}." 
+ ) + longlora_group_size = T // self.config.longlora_n_groups + else: + longlora_group_size = 0 + qkv = self.attn(x) # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) @@ -243,10 +252,23 @@ def forward( if not isinstance(self.kv_cache, KVCache): raise TypeError("You need to call `gpt.set_kv_cache()`") k, v = self.kv_cache(input_pos, k, v) + elif longlora_group_size > 0: + q = roll_and_group(q, B, T, longlora_group_size, q.shape[1], self.config.head_size) + k = roll_and_group(k, B, T, longlora_group_size, k.shape[1], self.config.head_size) + v = roll_and_group(v, B, T, longlora_group_size, v.shape[1], self.config.head_size) y = self.scaled_dot_product_attention(q, k, v, mask) - y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side + if input_pos is None and longlora_group_size > 0: + # shift back and unroll + n_heads = y.shape[2] + y = y.reshape(B, T, n_heads, self.config.head_size) # (B, T, nh, hs) + y0, y1 = y.split(n_heads // 2, dim=2) + y1 = y1.roll(longlora_group_size // 2, dims=1) + y = torch.cat((y0, y1), dim=2) + + # re-assemble all head outputs side by side + y = y.reshape(B, T, self.config.head_size * self.config.n_head) # output projection return self.proj(y) @@ -284,6 +306,21 @@ def build_kv_cache( return KVCache(k_shape, v_shape, device=device, dtype=dtype) +def roll_and_group( + qkv: torch.Tensor, bsz: int, q_len: int, group_size: int, num_heads: int, head_dim: int +) -> torch.Tensor: + # Split, roll and recompose to avoid the following error: + # RuntimeError: Output 0 of SliceBackward0 is a view and is being modified inplace. + # This view is the output of a function that returns multiple views. + # Such functions do not allow the output views to be modified inplace. + # You should replace the inplace operation by an out-of-place one. 
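+    # The lines below implement the grouped, shifted attention used by LongLoRA
+    # (S^2-Attn): split the heads into two halves and roll the second half by
+    # -group_size // 2 along the sequence axis, so the two halves see group
+    # boundaries offset by half a group, then fold the sequence into
+    # q_len // group_size groups on the batch axis so that attention is only
+    # computed within each group. For example, with bsz=1, q_len=8,
+    # group_size=2, num_heads=4 and head_dim=16, a (1, 4, 8, 16) input
+    # comes out of this block with shape (4, 4, 2, 16).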
+ qkv0, qkv1 = qkv.split(num_heads // 2, dim=1) + qkv1 = qkv1.roll(-group_size // 2, dims=2) + qkv = torch.cat((qkv0, qkv1), dim=1) + qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2) + return qkv + + class GptNeoxMLP(nn.Module): def __init__(self, config: Config) -> None: super().__init__() diff --git a/tests/test_full.py b/tests/test_full.py index 74bc10f22e..993b50bb7a 100644 --- a/tests/test_full.py +++ b/tests/test_full.py @@ -12,6 +12,7 @@ import litgpt.finetune.full as module from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca +from litgpt.utils import CLI @mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) @@ -69,3 +70,90 @@ def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): assert f"Resuming training from {out_dir / 'step-000006' / 'lit_model.pth'}" in logs assert logs.count("(step)") == 2 assert out_dir / "step-000008" in set(out_dir.iterdir()) + + +@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) +def test_full_longlora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): + model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8) + (fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config)) + monkeypatch.setattr(module, "load_checkpoint", Mock()) + + tokenizer_mock = Mock() + tokenizer_mock.return_value = tokenizer_mock + tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) + monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) + + out_dir = tmp_path / "out" + setup_kwargs = dict( + data=Alpaca(download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0), + checkpoint_dir=fake_checkpoint_dir, + out_dir=out_dir, + precision="32-true", + train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1), + eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1), + ) + stdout = StringIO() + with redirect_stdout(stdout), mock.patch( + # Needed to save_hyperparameters function saves correctly LongLora params to be used when resuming + "sys.argv", + [ + "full.py", + "--data=litgpt.data.Alpaca", + "--data.download_dir=" + str(alpaca_path.parent), + "--data.file_name=" + str(alpaca_path.name), + "--data.val_split_fraction=0.5", + "--data.num_workers=0", + "--checkpoint_dir=" + str(fake_checkpoint_dir), + "--out_dir=" + str(out_dir), + "--precision=32-true", + "--train.global_batch_size=1", + "--train.save_interval=2", + "--train.epochs=1", + "--train.max_steps=6", + "--train.micro_batch_size=1", + "--eval.interval=2", + "--eval.max_iters=2", + "--eval.max_new_tokens=1", + "--longlora.use_longlora=True", + "--longlora.n_groups=4", + "--longlora.context_length=256", + ], + ): + CLI(module.setup) + + out_dir_contents = set(os.listdir(out_dir)) + checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"} + assert checkpoint_dirs.issubset(out_dir_contents) + assert all((out_dir / p).is_dir() for p in checkpoint_dirs) + for checkpoint_dir in checkpoint_dirs: + assert set(os.listdir(out_dir / checkpoint_dir)) == { + "lit_model.pth", + "model_config.yaml", + "tokenizer_config.json", + "tokenizer.json", + "hyperparameters.yaml", + "prompt_style.yaml", + } + assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file() + + logs = stdout.getvalue() + assert logs.count("(step)") == 6 + assert logs.count("val loss") == 4 # 3 validations + 1 final validation + assert logs.count("Final evaluation") == 1 + 
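+    # The fake checkpoint declares block_size=128 and the CLI passes
+    # --longlora.context_length=256, so setup is expected to report the context
+    # extension from 128 to 256 and a rope_condense_ratio of 256 / 128 = 2.0;
+    # the assertions below check the captured stdout for those messages.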
assert "of trainable parameters: 1,888" in logs + assert "The model context length has been increased from 128 to 256" in logs + assert "The 'rope_condense_ratio' has been adapted from 1 to 2.0" in logs + + # Resume training and do 2 steps more + setup_kwargs["train"].max_steps = 8 + setup_kwargs["resume"] = True + stdout = StringIO() + with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py"]): + module.setup(**setup_kwargs) + logs = stdout.getvalue() + assert f"Resuming training from {out_dir / 'step-000006' / 'lit_model.pth'}" in logs + assert logs.count("(step)") == 2 + assert out_dir / "step-000008" in set(out_dir.iterdir()) + assert "The model context length has been increased from 128 to 256" in logs + assert "The 'rope_condense_ratio' has been adapted from 1 to 2.0" in logs + diff --git a/tests/test_lora.py b/tests/test_lora.py index 8f4edba90a..3ef29e181d 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -19,7 +19,7 @@ import litgpt.config as config_module import litgpt.finetune.lora as module -from litgpt.args import EvalArgs, TrainArgs +from litgpt.args import EvalArgs, LongLoraArgs, TrainArgs from litgpt.data import Alpaca from litgpt.lora import CausalSelfAttention as LoRACausalSelfAttention from litgpt.lora import Config, LoRALinear, LoRAQKVLinear, lora_filter, mark_only_lora_as_trainable, merge_lora_weights @@ -226,6 +226,70 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): assert "of trainable parameters: 512" in logs +@mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}) +def test_longlora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path): + model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8) + (fake_checkpoint_dir / "model_config.yaml").write_text(yaml.dump(model_config)) + monkeypatch.setattr(module, "load_checkpoint", Mock()) + monkeypatch.setattr(module, "merge_lora", Mock()) + + tokenizer_mock = Mock() + tokenizer_mock.return_value = tokenizer_mock + tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1]) + monkeypatch.setattr(module, "Tokenizer", tokenizer_mock) + + out_dir = tmp_path / "out" + stdout = StringIO() + with redirect_stdout(stdout), mock.patch("sys.argv", ["lora.py"]): + module.setup( + data=Alpaca( + download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0 + ), + checkpoint_dir=fake_checkpoint_dir, + out_dir=out_dir, + precision="32-true", + train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1), + eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1), + longlora=LongLoraArgs(use_longlora=True, context_length=256), + ) + + out_dir_contents = set(os.listdir(out_dir)) + checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"} + assert checkpoint_dirs.issubset(out_dir_contents) + assert all((out_dir / p).is_dir() for p in checkpoint_dirs) + for checkpoint_dir in checkpoint_dirs: + assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == { + "lit_model.pth.lora", + "model_config.yaml", + "tokenizer_config.json", + "tokenizer.json", + "hyperparameters.yaml", + "prompt_style.yaml", + } + lora_ckpt = torch.load(out_dir / checkpoint_dir / "lit_model.pth.lora")["model"] + lora_ckpt_keys = lora_ckpt.keys() + assert all( + param in lora_ckpt_keys + for param in [ + "transformer.wte.weight", + "transformer.h.0.norm_1.weight", + "transformer.h.0.norm_2.weight", + "transformer.h.1.norm_1.weight", + "transformer.h.1.norm_2.weight", + 
"transformer.ln_f.weight", + ] + ) + assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file() + + logs = stdout.getvalue() + assert logs.count("(step)") == 6 + assert logs.count("val loss") == 4 # 3 validations + 1 final validation + assert logs.count("Final evaluation") == 1 + assert "of trainable parameters: 656" in logs + assert "The model context length has been increased from 128 to 256" in logs + assert "The 'rope_condense_ratio' has been adapted from 1 to 2.0" in logs + + def test_lora_init_when_linear_overridden(): class MyLinear(torch.nn.Linear): def __init__(self, *args, **kwargs):