
Commit 85f7f7b

Merge pull request #116 from georgian-io/feature/format-lint-action
[Workflow] Automatically Lint and Format PRs
2 parents 613a30e + b76f810 commit 85f7f7b

17 files changed: +198 −313 lines
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+name: Ruff
+on: pull_request
+jobs:
+  lint:
+    name: Lint, Format, and Commit
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: chartboost/ruff-action@v1
+        name: Lint
+        with:
+          version: 0.3.5
+          args: "check --output-format=full --statistics"
+      - uses: chartboost/ruff-action@v1
+        name: Format
+        with:
+          version: 0.3.5
+          args: "format --check"

README.md

Lines changed: 7 additions & 0 deletions
@@ -255,3 +255,10 @@ If you would like to contribute to this project, we recommend following the "for
 5. Submit a **Pull request** so that we can review your changes
 
 NOTE: Be sure to merge the latest from "upstream" before making a pull request!
+
+### Checklist Before Pull Request (Optional)
+
+1. Use `ruff check --fix` to check and fix lint errors
+2. Use `ruff format` to apply formatting
+
+NOTE: Ruff linting and formatting checks run via a GitHub Action when a PR is raised. Before raising a PR, it is good practice to check and fix lint errors and apply formatting.
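The checklist above can also be scripted locally before opening a PR. Below is a minimal sketch using Python's subprocess module; the script name is illustrative and not part of this commit, and it assumes Ruff is installed in the active environment:

```python
# run_ruff_checks.py - illustrative helper, not part of this commit.
# Runs the same Ruff commands the checklist describes: "ruff check --fix"
# to lint (and auto-fix) and "ruff format" to apply formatting.
import subprocess
import sys


def main() -> int:
    commands = [
        ["ruff", "check", "--fix", "."],
        ["ruff", "format", "."],
    ]
    exit_code = 0
    for cmd in commands:
        result = subprocess.run(cmd)
        exit_code = exit_code or result.returncode
    return exit_code


if __name__ == "__main__":
    sys.exit(main())
```

Run from the repository root, this applies the same fixes that the CI workflow later verifies with `check` and `format --check`.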

llmtune/data/dataset_generator.py

Lines changed: 5 additions & 9 deletions
@@ -1,10 +1,10 @@
 import os
-from os.path import join, exists
+import pickle
+import re
 from functools import partial
+from os.path import exists, join
 from typing import Tuple, Union
-import pickle
 
-import re
 from datasets import Dataset
 
 from llmtune.data.ingestor import Ingestor, get_ingestor
@@ -61,12 +61,8 @@ def _format_one_prompt(self, example, is_test: bool = False):
         return example
 
     def _format_prompts(self):
-        self.dataset["train"] = self.dataset["train"].map(
-            partial(self._format_one_prompt, is_test=False)
-        )
-        self.dataset["test"] = self.dataset["test"].map(
-            partial(self._format_one_prompt, is_test=True)
-        )
+        self.dataset["train"] = self.dataset["train"].map(partial(self._format_one_prompt, is_test=False))
+        self.dataset["test"] = self.dataset["test"].map(partial(self._format_one_prompt, is_test=True))
 
     def get_dataset(self) -> Tuple[Dataset, Dataset]:
         self._train_test_split()
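The reflowed lines above keep the same pattern: Dataset.map applied to a functools.partial of the prompt formatter. A small self-contained sketch of that pattern (illustrative only, with a made-up formatting function; not code from this repository):

```python
# Illustrative sketch of Dataset.map + functools.partial (not from this repo).
from functools import partial

from datasets import Dataset


def format_one_prompt(example, is_test: bool = False):
    # Toy formatter: prepend a split-specific tag to each example.
    tag = "TEST" if is_test else "TRAIN"
    example["text"] = f"[{tag}] {example['text']}"
    return example


ds = Dataset.from_dict({"text": ["hello", "world"]})
train = ds.map(partial(format_one_prompt, is_test=False))
test = ds.map(partial(format_one_prompt, is_test=True))
print(train[0]["text"], "|", test[0]["text"])
```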

llmtune/data/ingestor.py

Lines changed: 3 additions & 6 deletions
@@ -1,9 +1,8 @@
+import csv
 from abc import ABC, abstractmethod
-from functools import partial
 
 import ijson
-import csv
-from datasets import Dataset, load_dataset, concatenate_datasets
+from datasets import Dataset, concatenate_datasets, load_dataset
 
 
 def get_ingestor(data_type: str):
@@ -14,9 +13,7 @@ def get_ingestor(data_type: str):
     elif data_type == "huggingface":
         return HuggingfaceIngestor
     else:
-        raise ValueError(
-            f"'type' must be one of 'json', 'csv', or 'huggingface', you have {data_type}"
-        )
+        raise ValueError(f"'type' must be one of 'json', 'csv', or 'huggingface', you have {data_type}")
 
 
 class Ingestor(ABC):
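For context, the single-line ValueError above sits inside a small factory function that maps a string key to an ingestor class. A minimal sketch of the same idea (hypothetical placeholder classes and a mapping-based lookup, not the repository's implementation):

```python
# Illustrative factory sketch (hypothetical classes, not from this repo).
class JsonIngestor:
    pass


class CsvIngestor:
    pass


class HuggingfaceIngestor:
    pass


def get_ingestor(data_type: str):
    ingestors = {"json": JsonIngestor, "csv": CsvIngestor, "huggingface": HuggingfaceIngestor}
    if data_type not in ingestors:
        raise ValueError(f"'type' must be one of 'json', 'csv', or 'huggingface', you have {data_type}")
    return ingestors[data_type]


print(get_ingestor("csv"))  # <class '__main__.CsvIngestor'>
```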

llmtune/finetune/lora.py

Lines changed: 13 additions & 20 deletions
@@ -1,31 +1,26 @@
-from os.path import join, exists
-from typing import Tuple
-
-import torch
+from os.path import join
 
 import bitsandbytes as bnb
+import torch
 from datasets import Dataset
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    prepare_model_for_kbit_training,
+)
 from transformers import (
-    AutoTokenizer,
     AutoModelForCausalLM,
-    BitsAndBytesConfig,
-    TrainingArguments,
     AutoTokenizer,
+    BitsAndBytesConfig,
     ProgressCallback,
-)
-from peft import (
-    prepare_model_for_kbit_training,
-    get_peft_model,
-    LoraConfig,
+    TrainingArguments,
 )
 from trl import SFTTrainer
-from rich.console import Console
 
-
-from llmtune.pydantic_models.config_model import Config
-from llmtune.utils.save_utils import DirectoryHelper
 from llmtune.finetune.generics import Finetune
+from llmtune.pydantic_models.config_model import Config
 from llmtune.ui.rich_ui import RichUI
+from llmtune.utils.save_utils import DirectoryHelper
 
 
 class LoRAFinetune(Finetune):
@@ -99,9 +94,7 @@ def _inject_lora(self):
         self.model = get_peft_model(self.model, self._lora_config)
 
         if not self.config.accelerate:
-            self.optimizer = bnb.optim.Adam8bit(
-                self.model.parameters(), lr=self._training_args.learning_rate
-            )
+            self.optimizer = bnb.optim.Adam8bit(self.model.parameters(), lr=self._training_args.learning_rate)
             self.lr_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer)
         if self.config.accelerate:
             self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
@@ -132,7 +125,7 @@ def finetune(self, train_dataset: Dataset):
             **self._sft_args.model_dump(),
         )
 
-        trainer_stats = self._trainer.train()
+        self._trainer.train()
 
     def save_model(self) -> None:
         self._trainer.model.save_pretrained(self._weights_path)
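The single-line change in `_inject_lora` constructs a bitsandbytes 8-bit Adam optimizer and pairs it with a constant learning-rate scheduler. A stand-alone sketch of that pairing (illustrative only; assumes a CUDA-capable environment with torch and bitsandbytes installed, and uses a toy linear model rather than the repository's LoRA model):

```python
# Illustrative sketch (not from this repo): bitsandbytes 8-bit Adam with a
# constant LR scheduler, mirroring the non-accelerate branch shown above.
# Assumes a CUDA-capable environment with torch and bitsandbytes installed.
import bitsandbytes as bnb
import torch

model = torch.nn.Linear(128, 128).cuda()  # toy stand-in for the LoRA-wrapped model
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=2e-4)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

loss = model(torch.randn(4, 128, device="cuda")).sum()
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
```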

llmtune/inference/lora.py

Lines changed: 14 additions & 32 deletions
@@ -1,20 +1,18 @@
+import csv
 import os
 from os.path import join
 from threading import Thread
-import csv
 
-from transformers import TextIteratorStreamer
-from rich.text import Text
+import torch
 from datasets import Dataset
-from transformers import AutoTokenizer, BitsAndBytesConfig
 from peft import AutoPeftModelForCausalLM
-import torch
-
+from rich.text import Text
+from transformers import AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
 
-from llmtune.pydantic_models.config_model import Config
-from llmtune.utils.save_utils import DirectoryHelper
 from llmtune.inference.generics import Inference
+from llmtune.pydantic_models.config_model import Config
 from llmtune.ui.rich_ui import RichUI
+from llmtune.utils.save_utils import DirectoryHelper
 
 
 # TODO: Add type hints please!
@@ -35,9 +33,7 @@ def __init__(
         self.device_map = self.config.model.device_map
         self._weights_path = dir_helper.save_paths.weights
 
-        self.model, self.tokenizer = self._get_merged_model(
-            dir_helper.save_paths.weights
-        )
+        self.model, self.tokenizer = self._get_merged_model(dir_helper.save_paths.weights)
 
     def _get_merged_model(self, weights_path: str):
         # purge VRAM
@@ -47,20 +43,14 @@ def _get_merged_model(self, weights_path: str):
         dtype = (
             torch.float16
             if self.config.training.training_args.fp16
-            else (
-                torch.bfloat16
-                if self.config.training.training_args.bf16
-                else torch.float32
-            )
+            else (torch.bfloat16 if self.config.training.training_args.bf16 else torch.float32)
         )
 
         self.model = AutoPeftModelForCausalLM.from_pretrained(
             weights_path,
             torch_dtype=dtype,
             device_map=self.device_map,
-            quantization_config=(
-                BitsAndBytesConfig(**self.config.model.bitsandbytes.model_dump())
-            ),
+            quantization_config=(BitsAndBytesConfig(**self.config.model.bitsandbytes.model_dump())),
         )
 
         """TODO: figure out multi-gpu
@@ -70,9 +60,7 @@ def _get_merged_model(self, weights_path: str):
 
         model = self.model.merge_and_unload()
 
-        tokenizer = AutoTokenizer.from_pretrained(
-            self._weights_path, device_map=self.device_map
-        )
+        tokenizer = AutoTokenizer.from_pretrained(self._weights_path, device_map=self.device_map)
 
         return model, tokenizer
 
@@ -83,13 +71,11 @@ def infer_all(self):
 
         # inference loop
         for idx, (prompt, label) in enumerate(zip(prompts, labels)):
-            RichUI.inference_ground_truth_display(
-                f"Generating on test set: {idx+1}/{len(prompts)}", prompt, label
-            )
+            RichUI.inference_ground_truth_display(f"Generating on test set: {idx+1}/{len(prompts)}", prompt, label)
 
             try:
                 result = self.infer_one(prompt)
-            except:
+            except Exception:
                 continue
             results.append((prompt, label, result))
 
@@ -103,9 +89,7 @@ def infer_all(self):
             writer.writerow(row)
 
     def infer_one(self, prompt: str) -> str:
-        input_ids = self.tokenizer(
-            prompt, return_tensors="pt", truncation=True
-        ).input_ids.cuda()
+        input_ids = self.tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
 
         # stream processor
         streamer = TextIteratorStreamer(
@@ -115,9 +99,7 @@ def infer_one(self, prompt: str) -> str:
             timeout=60,  # 60 sec timeout for generation; to handle OOM errors
         )
 
-        generation_kwargs = dict(
-            input_ids=input_ids, streamer=streamer, **self.config.inference.model_dump()
-        )
+        generation_kwargs = dict(input_ids=input_ids, streamer=streamer, **self.config.inference.model_dump())
 
         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
         thread.start()
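The generation path above streams tokens by running model.generate in a background thread and reading from a TextIteratorStreamer. A self-contained sketch of that pattern (illustrative only; the tiny checkpoint name is an example, not the model this repository targets):

```python
# Illustrative sketch (not from this repo): streaming generation with
# TextIteratorStreamer by running generate() in a background thread.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "sshleifer/tiny-gpt2"  # example checkpoint; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=60)

thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=20),
)
thread.start()

text = "".join(chunk for chunk in streamer)  # consume chunks as they arrive
thread.join()
print(text)
```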

llmtune/pydantic_models/config_model.py

Lines changed: 23 additions & 57 deletions
@@ -1,27 +1,21 @@
-from typing import Literal, Union, List, Dict, Optional
-from pydantic import BaseModel, FilePath, validator, Field
-
-from huggingface_hub.utils import validate_repo_id
+from typing import List, Literal, Optional, Union
 
 import torch
+from pydantic import BaseModel, Field, FilePath, validator
+
 
 # TODO: Refactor this into multiple files...
 HfModelPath = str
 
+
 class QaConfig(BaseModel):
-    llm_tests: Optional[List[str]] = Field([], description = "list of tests that needs to be connected")
-
+    llm_tests: Optional[List[str]] = Field([], description="list of tests that needs to be connected")
+
 
 class DataConfig(BaseModel):
-    file_type: Literal["json", "csv", "huggingface"] = Field(
-        None, description="File type"
-    )
-    path: Union[FilePath, HfModelPath] = Field(
-        None, description="Path to the file or HuggingFace model"
-    )
-    prompt: str = Field(
-        None, description="Prompt for the model. Use {} brackets for column name"
-    )
+    file_type: Literal["json", "csv", "huggingface"] = Field(None, description="File type")
+    path: Union[FilePath, HfModelPath] = Field(None, description="Path to the file or HuggingFace model")
+    prompt: str = Field(None, description="Prompt for the model. Use {} brackets for column name")
     prompt_stub: str = Field(
         None,
         description="Stub for the prompt; this is injected during training. Use {} brackets for column name",
@@ -48,9 +42,7 @@ class DataConfig(BaseModel):
 
 
 class BitsAndBytesConfig(BaseModel):
-    load_in_8bit: Optional[bool] = Field(
-        False, description="Enable 8-bit quantization with LLM.int8()"
-    )
+    load_in_8bit: Optional[bool] = Field(False, description="Enable 8-bit quantization with LLM.int8()")
     llm_int8_threshold: Optional[float] = Field(
         6.0, description="Outlier threshold for outlier detection in 8-bit quantization"
     )
@@ -61,9 +53,7 @@ class BitsAndBytesConfig(BaseModel):
         False,
         description="Enable splitting model parts between int8 on GPU and fp32 on CPU",
     )
-    llm_int8_has_fp16_weight: Optional[bool] = Field(
-        False, description="Run LLM.int8() with 16-bit main weights"
-    )
+    llm_int8_has_fp16_weight: Optional[bool] = Field(False, description="Run LLM.int8() with 16-bit main weights")
 
     load_in_4bit: Optional[bool] = Field(
         True,
@@ -86,14 +76,10 @@ class ModelConfig(BaseModel):
         "NousResearch/Llama-2-7b-hf",
         description="Path to the model (huggingface repo or local path)",
     )
-    device_map: Optional[str] = Field(
-        "auto", description="device onto which to load the model"
-    )
+    device_map: Optional[str] = Field("auto", description="device onto which to load the model")
 
     quantize: Optional[bool] = Field(False, description="Flag to enable quantization")
-    bitsandbytes: BitsAndBytesConfig = Field(
-        None, description="Bits and Bytes configuration"
-    )
+    bitsandbytes: BitsAndBytesConfig = Field(None, description="Bits and Bytes configuration")
 
     # @validator("hf_model_ckpt")
     # def validate_model(cls, v, **kwargs):
@@ -116,22 +102,12 @@ def set_device_map_to_none(cls, v, values, **kwargs):
 
 class LoraConfig(BaseModel):
     r: Optional[int] = Field(8, description="Lora rank")
-    task_type: Optional[str] = Field(
-        "CAUSAL_LM", description="Base Model task type during training"
-    )
+    task_type: Optional[str] = Field("CAUSAL_LM", description="Base Model task type during training")
 
-    lora_alpha: Optional[int] = Field(
-        16, description="The alpha parameter for Lora scaling"
-    )
-    bias: Optional[str] = Field(
-        "none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'"
-    )
-    lora_dropout: Optional[float] = Field(
-        0.1, description="The dropout probability for Lora layers"
-    )
-    target_modules: Optional[List[str]] = Field(
-        None, description="The names of the modules to apply Lora to"
-    )
+    lora_alpha: Optional[int] = Field(16, description="The alpha parameter for Lora scaling")
+    bias: Optional[str] = Field("none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'")
+    lora_dropout: Optional[float] = Field(0.1, description="The dropout probability for Lora layers")
+    target_modules: Optional[List[str]] = Field(None, description="The names of the modules to apply Lora to")
     fan_in_fan_out: Optional[bool] = Field(
         False,
         description="Flag to indicate if the layer to replace stores weight like (fan_in, fan_out)",
@@ -140,9 +116,7 @@ class LoraConfig(BaseModel):
         None,
         description="List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint",
    )
-    layers_to_transform: Optional[Union[List[int], int]] = Field(
-        None, description="The layer indexes to transform"
-    )
+    layers_to_transform: Optional[Union[List[int], int]] = Field(None, description="The layer indexes to transform")
     layers_pattern: Optional[str] = Field(None, description="The layer pattern name")
     # rank_pattern: Optional[Dict[str, int]] = Field(
     #     {}, description="The mapping from layer names or regexp expression to ranks"
@@ -155,15 +129,9 @@ class LoraConfig(BaseModel):
 # TODO: Get comprehensive Args!
 class TrainingArgs(BaseModel):
     num_train_epochs: Optional[int] = Field(1, description="Number of training epochs")
-    per_device_train_batch_size: Optional[int] = Field(
-        1, description="Batch size per training device"
-    )
-    gradient_accumulation_steps: Optional[int] = Field(
-        1, description="Number of steps for gradient accumulation"
-    )
-    gradient_checkpointing: Optional[bool] = Field(
-        True, description="Flag to enable gradient checkpointing"
-    )
+    per_device_train_batch_size: Optional[int] = Field(1, description="Batch size per training device")
+    gradient_accumulation_steps: Optional[int] = Field(1, description="Number of steps for gradient accumulation")
+    gradient_checkpointing: Optional[bool] = Field(True, description="Flag to enable gradient checkpointing")
     optim: Optional[str] = Field("paged_adamw_32bit", description="Optimizer")
     logging_steps: Optional[int] = Field(100, description="Number of logging steps")
     learning_rate: Optional[float] = Field(2.0e-4, description="Learning rate")
@@ -172,9 +140,7 @@
     fp16: Optional[bool] = Field(False, description="Flag to enable fp16")
     max_grad_norm: Optional[float] = Field(0.3, description="Maximum gradient norm")
     warmup_ratio: Optional[float] = Field(0.03, description="Warmup ratio")
-    lr_scheduler_type: Optional[str] = Field(
-        "constant", description="Learning rate scheduler type"
-    )
+    lr_scheduler_type: Optional[str] = Field("constant", description="Learning rate scheduler type")
 
 
 # TODO: Get comprehensive Args!
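The reformatted fields above all follow the same pydantic idiom: a typed attribute whose default and description are packed into one Field call. A minimal sketch of that idiom (hypothetical class name, assuming pydantic v2 as implied by the `model_dump` calls elsewhere in the diff):

```python
# Illustrative sketch (hypothetical class, not from this repo): single-line
# pydantic Field declarations with defaults and descriptions.
from typing import List, Optional

from pydantic import BaseModel, Field


class ExampleLoraConfig(BaseModel):
    r: Optional[int] = Field(8, description="LoRA rank")
    lora_alpha: Optional[int] = Field(16, description="The alpha parameter for LoRA scaling")
    lora_dropout: Optional[float] = Field(0.1, description="The dropout probability for LoRA layers")
    target_modules: Optional[List[str]] = Field(None, description="The names of the modules to apply LoRA to")


print(ExampleLoraConfig().model_dump())
# {'r': 8, 'lora_alpha': 16, 'lora_dropout': 0.1, 'target_modules': None}
```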

0 commit comments
