Skip to content

Commit 4b2e02e

Browse files
carmoccarasbt
authored andcommitted
Automated formatting (#1126)
1 parent 85bce0d commit 4b2e02e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+353
-324
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
 
2020

21-
⚡ LitGPT is a hackable [implementation](litgpt/model.py) of state-of-the-art open-source large language models released under the **Apache 2.0 license**.
21+
⚡ LitGPT is a hackable [implementation](litgpt/model.py) of state-of-the-art open-source large language models released under the **Apache 2.0 license**.
2222

2323
 
2424
## LitGPT supports
@@ -141,7 +141,7 @@ For added convenience, you can also manually override config file setting via th
141141

142142

143143
```bash
144-
litgpt finetune lora
144+
litgpt finetune lora
145145
--config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml \
146146
--lora_r 4
147147
```
@@ -150,7 +150,7 @@ You can browse the available configuration files [here](https://github.com/Light
150150

151151
 
152152

153-
> [!TIP]
153+
> [!TIP]
154154
> **Run large models on smaller consumer devices:**
155155
> We support 4-bit quantization (as in QLoRA), (bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq) and 8-bit quantization (bnb.int8) for inference by following [this guide](tutorials/quantize.md).
156156
@@ -314,7 +314,7 @@ We welcome all individual contributors, regardless of their level of experience
314314
315315
 
316316
317-
> [!TIP]
317+
> [!TIP]
318318
> Unsure about contributing? Check out our [How to Contribute to LitGPT](https://lightning.ai/pages/community/tutorial/how-to-contribute-to-litgpt/) guide.
319319
320320
If you have general questions about building with LitGPT, please [join our Discord](https://discord.gg/VptPCZkGNa).

litgpt/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131
def _new_parser(**kwargs: Any) -> "ArgumentParser":
32-
from jsonargparse import ArgumentParser, ActionConfigFile
32+
from jsonargparse import ActionConfigFile, ArgumentParser
3333

3434
parser = ArgumentParser(**kwargs)
3535
parser.add_argument(
@@ -80,7 +80,7 @@ def main() -> None:
8080
"merge_lora": {"help": "Merges the LoRA weights with the base model.", "fn": merge_lora_fn},
8181
}
8282

83-
from jsonargparse import set_docstring_parse_options, set_config_read_mode
83+
from jsonargparse import set_config_read_mode, set_docstring_parse_options
8484

8585
set_docstring_parse_options(attribute_docstrings=True)
8686
set_config_read_mode(urls_enabled=True)

litgpt/chat/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
import torch
1010
from lightning.fabric.plugins import BitsandbytesPrecision
1111

12-
from litgpt.generate.base import next_token
1312
from litgpt import GPT, Config, PromptStyle, Tokenizer
14-
from litgpt.prompts import load_prompt_style, has_prompt_style
13+
from litgpt.generate.base import next_token
14+
from litgpt.prompts import has_prompt_style, load_prompt_style
1515
from litgpt.scripts.merge_lora import merge_lora
1616
from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint
1717

@@ -159,7 +159,9 @@ def main(
159159
model = fabric.setup_module(model)
160160

161161
tokenizer = Tokenizer(checkpoint_dir)
162-
prompt_style = load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
162+
prompt_style = (
163+
load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
164+
)
163165
stop_tokens = prompt_style.stop_tokens(tokenizer)
164166

165167
print(f"Now chatting with {config.name}.\nTo exit, press 'Enter' on an empty prompt.\n")

litgpt/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
22

3-
import yaml
43
from copy import deepcopy
54
from dataclasses import dataclass, field
65
from pathlib import Path
76
from typing import Any, Literal, Optional, Type, Union
87

98
import torch
9+
import yaml
1010
from typing_extensions import Self
1111

1212
import litgpt.model

litgpt/data/alpaca.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,7 @@ def __post_init__(self) -> None:
5050
self.prompt_style = PromptStyle.from_name(self.prompt_style)
5151

5252
def connect(
53-
self,
54-
tokenizer: Optional[Tokenizer] = None,
55-
batch_size: int = 1,
56-
max_seq_length: Optional[int] = None
53+
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
5754
) -> None:
5855
self.tokenizer = tokenizer
5956
self.batch_size = batch_size
@@ -71,7 +68,7 @@ def setup(self, stage: str = "") -> None:
7168
train_data, test_data = random_split(
7269
data,
7370
[1.0 - self.val_split_fraction, self.val_split_fraction],
74-
generator=torch.Generator().manual_seed(self.seed)
71+
generator=torch.Generator().manual_seed(self.seed),
7572
)
7673
train_data, test_data = list(train_data), list(test_data)
7774

@@ -99,7 +96,7 @@ def train_dataloader(self) -> DataLoader:
9996
shuffle=True,
10097
generator=torch.Generator().manual_seed(self.seed),
10198
num_workers=self.num_workers,
102-
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index)
99+
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
103100
)
104101

105102
def val_dataloader(self) -> DataLoader:
@@ -108,7 +105,7 @@ def val_dataloader(self) -> DataLoader:
108105
batch_size=self.batch_size,
109106
shuffle=False,
110107
num_workers=self.num_workers,
111-
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index)
108+
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
112109
)
113110

114111

litgpt/data/alpaca_2k.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,3 @@ def setup(self, stage: str = "") -> None:
5050
mask_prompt=self.mask_prompt,
5151
ignore_index=self.ignore_index,
5252
)
53-

litgpt/data/base.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@ class DataModule(LightningDataModule):
1717

1818
@abstractmethod
1919
def connect(
20-
self,
21-
tokenizer: Optional[Tokenizer] = None,
22-
batch_size: int = 1,
23-
max_seq_length: Optional[int] = None
20+
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
2421
) -> None:
2522
"""All settings that can't be determined at the time of instantiation need to be passed through here
2623
before any dataloaders can be accessed.
@@ -53,6 +50,7 @@ class SFTDataset(Dataset):
5350
labels: Same as input_ids, unless ``mask_prompt=True`` in which case the 'prompt' part is replaced with
5451
the ``ignore_index``.
5552
"""
53+
5654
def __init__(
5755
self,
5856
data: List[Dict[str, str]],
@@ -61,7 +59,7 @@ def __init__(
6159
max_seq_length: int = -1,
6260
mask_prompt: bool = True,
6361
ignore_index: int = -100,
64-
transform: Optional[Callable[[Any], Any]] = None
62+
transform: Optional[Callable[[Any], Any]] = None,
6563
) -> None:
6664
self.data = data
6765
self.tokenizer = tokenizer
@@ -84,9 +82,7 @@ def __getitem__(self, idx: int) -> Dict[str, Tensor]:
8482
prompt_and_response = prompt + example["output"]
8583
encoded_prompt = self.tokenizer.encode(prompt, max_length=self.max_seq_length)
8684
encoded_prompt_and_response = self.tokenizer.encode(
87-
prompt_and_response,
88-
eos=True,
89-
max_length=self.max_seq_length,
85+
prompt_and_response, eos=True, max_length=self.max_seq_length
9086
)
9187

9288
# The labels are the full prompt with response, but with the prompt masked out

litgpt/data/deita.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,7 @@ def __post_init__(self) -> None:
4545
self.prompt_style = PromptStyle.from_name(self.prompt_style)
4646

4747
def connect(
48-
self,
49-
tokenizer: Optional[Tokenizer] = None,
50-
batch_size: int = 1,
51-
max_seq_length: Optional[int] = None
48+
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
5249
) -> None:
5350
self.tokenizer = tokenizer
5451
self.batch_size = batch_size
@@ -99,7 +96,7 @@ def val_dataloader(self) -> DataLoader:
9996
batch_size=self.batch_size,
10097
shuffle=False,
10198
num_workers=self.num_workers,
102-
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index)
99+
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
103100
)
104101

105102

litgpt/data/dolly.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def setup(self, stage: str = "") -> None:
5050
train_data, test_data = random_split(
5151
data,
5252
[1.0 - self.val_split_fraction, self.val_split_fraction],
53-
generator=torch.Generator().manual_seed(self.seed)
53+
generator=torch.Generator().manual_seed(self.seed),
5454
)
5555
train_data, test_data = list(train_data), list(test_data)
5656

litgpt/data/flan.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,7 @@ def __post_init__(self):
5959
self.subsets = list(supported_subsets)
6060

6161
def connect(
62-
self,
63-
tokenizer: Optional[Tokenizer] = None,
64-
batch_size: int = 1,
65-
max_seq_length: Optional[int] = None
62+
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
6663
) -> None:
6764
self.tokenizer = tokenizer
6865
self.batch_size = batch_size
@@ -103,7 +100,7 @@ def _dataloader(self, split: str) -> DataLoader:
103100
shuffle=(split == "train"),
104101
generator=torch.Generator().manual_seed(self.seed),
105102
num_workers=self.num_workers,
106-
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index)
103+
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),
107104
)
108105

109106

0 commit comments

Comments
 (0)