
Commit 5720639

First commit.
1 parent a7b52f1 commit 5720639

29 files changed: +2498 −63 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
```diff
@@ -7,6 +7,8 @@ __pycache__/
 *.so
 *.dylib
 
+upcoming.md
+
 logs
 
 *.pypirc
```

README.md

Lines changed: 133 additions & 63 deletions
````diff
@@ -45,57 +45,153 @@ from quantllm import (
     DatasetSplitter,
     FineTuningTrainer,
     ModelEvaluator,
-    TrainingConfig,
+    HubManager,
+    CheckpointManager,
+)
+import os
+from quantllm.finetune import TrainingLogger
+from quantllm.config import (
+    DatasetConfig,
     ModelConfig,
-    DatasetConfig
+    TrainingConfig,
 )
 
 # Initialize logger
-from quantllm.finetune import TrainingLogger
 logger = TrainingLogger()
 
-# 1. Dataset Configuration and Loading
+# 1. Initialize hub manager first
+hub_manager = HubManager(
+    model_id="your-username/llama-2-imdb",
+    token=os.getenv("HF_TOKEN")
+)
+
+# 2. Model Configuration and Loading
+model_config = ModelConfig(
+    model_name="meta-llama/Llama-3.2-3B",
+    load_in_4bit=True,
+    use_lora=True,
+    hub_manager=hub_manager
+)
+
+model_loader = ModelLoader(model_config)
+model = model_loader.get_model()
+tokenizer = model_loader.get_tokenizer()
+
+# 3. Dataset Configuration and Loading
 dataset_config = DatasetConfig(
     dataset_name_or_path="imdb",
     dataset_type="huggingface",
     text_column="text",
-    label_column="label"
+    label_column="label",
+    max_length=512,
+    train_size=0.8,
+    val_size=0.1,
+    test_size=0.1,
+    hub_manager=hub_manager
 )
 
+# Load and prepare dataset
 dataset_loader = DatasetLoader(logger)
-dataset = dataset_loader.load_hf_dataset(dataset_config.dataset_name_or_path)
+dataset = dataset_loader.load_hf_dataset(dataset_config)
+
+# Split dataset
+dataset_splitter = DatasetSplitter(logger)
+train_dataset, val_dataset, test_dataset = dataset_splitter.train_val_test_split(
+    dataset,
+    train_size=dataset_config.train_size,
+    val_size=dataset_config.val_size,
+    test_size=dataset_config.test_size
+)
 
-# 2. Model Configuration and Loading
-model_config = ModelConfig(
-    model_name_or_path="meta-llama/Llama-2-7b-hf",
-    load_in_4bit=True,
-    use_lora=True
+# 4. Dataset Preprocessing
+preprocessor = DatasetPreprocessor(tokenizer, logger)
+train_dataset, val_dataset, test_dataset = preprocessor.tokenize_dataset(
+    train_dataset, val_dataset, test_dataset,
+    max_length=dataset_config.max_length,
+    text_column=dataset_config.text_column,
+    label_column=dataset_config.label_column
 )
 
-model_loader = ModelLoader(
-    model_name=model_config.model_name_or_path,
-    quantization="4bit" if model_config.load_in_4bit else None,
-    use_lora=model_config.use_lora
+# Create data loaders
+train_dataloader = DataLoader(
+    train_dataset,
+    batch_size=4,
+    shuffle=True,
+    num_workers=4
+)
+val_dataloader = DataLoader(
+    val_dataset,
+    batch_size=4,
+    shuffle=False,
+    num_workers=4
+)
+test_dataloader = DataLoader(
+    test_dataset,
+    batch_size=4,
+    shuffle=False,
+    num_workers=4
 )
-model = model_loader.get_model()
-tokenizer = model_loader.get_tokenizer()
 
-# 3. Training Configuration
+# 5. Training Configuration
 training_config = TrainingConfig(
     learning_rate=2e-4,
     num_epochs=3,
-    batch_size=4
+    batch_size=4,
+    gradient_accumulation_steps=4,
+    warmup_steps=100,
+    logging_steps=50,
+    eval_steps=200,
+    save_steps=500,
+    early_stopping_patience=3,
+    early_stopping_threshold=0.01
 )
 
-# 4. Initialize and Run Trainer
+# Initialize checkpoint manager
+checkpoint_manager = CheckpointManager(
+    output_dir="./checkpoints",
+    save_total_limit=3
+)
+
+# 6. Initialize Trainer
 trainer = FineTuningTrainer(
     model=model,
     training_config=training_config,
     train_dataloader=train_dataloader,
     eval_dataloader=val_dataloader,
-    logger=logger
+    logger=logger,
+    checkpoint_manager=checkpoint_manager,
+    hub_manager=hub_manager,
+    use_wandb=True,
+    wandb_config={
+        "project": "quantllm-imdb",
+        "name": "llama-2-imdb-finetuning"
+    }
 )
+
+# 7. Train the model
 trainer.train()
+
+# 8. Evaluate on test set
+evaluator = ModelEvaluator(
+    model=model,
+    eval_dataloader=test_dataloader,
+    metrics=[
+        lambda preds, labels, _: (preds.argmax(dim=-1) == labels).float().mean().item() # Accuracy
+    ],
+    logger=logger
+)
+
+test_metrics = evaluator.evaluate()
+
+# 9. Save final model
+trainer.save_model("./final_model")
+
+# 10. Push to Hub if logged in
+if hub_manager.is_logged_in():
+    hub_manager.push_model(
+        model,
+        commit_message=f"Final model with test accuracy: {test_metrics.get('accuracy', 0):.4f}"
+    )
 ```
 
 ### ⚙️ Advanced Usage
@@ -105,7 +201,7 @@ trainer.train()
 Create a config file (e.g., `config.yaml`):
 ```yaml
 model:
-  model_name_or_path: "meta-llama/Llama-2-7b-hf"
+  model_name: "meta-llama/Llama-3.2-3B"
   load_in_4bit: true
   use_lora: true
   lora_config:
@@ -118,45 +214,21 @@ dataset:
   text_column: "text"
   label_column: "label"
   max_length: 512
+  train_size: 0.8
+  val_size: 0.1
+  test_size: 0.1
 
 training:
   learning_rate: 2e-4
   num_epochs: 3
   batch_size: 4
   gradient_accumulation_steps: 4
-```
-
-#### Hub Integration
-
-```python
-from quantllm.hub import HubManager
-
-hub_manager = HubManager(
-    model_id="your-username/llama-2-imdb",
-    token=os.getenv("HF_TOKEN")
-)
-
-if hub_manager.is_logged_in():
-    hub_manager.push_model(
-        model,
-        commit_message="Trained model with custom configuration"
-    )
-```
-
-#### Evaluation
-
-```python
-from quantllm.finetune import ModelEvaluator
-
-evaluator = ModelEvaluator(
-    model=model,
-    eval_dataloader=test_dataloader,
-    metrics=[
-        lambda preds, labels, _: (preds.argmax(dim=-1) == labels).float().mean().item()
-    ]
-)
-
-metrics = evaluator.evaluate()
+  warmup_steps: 100
+  logging_steps: 50
+  eval_steps: 200
+  save_steps: 500
+  early_stopping_patience: 3
+  early_stopping_threshold: 0.01
 ```
 
 ## 📚 Documentation
@@ -165,14 +237,10 @@ metrics = evaluator.evaluate()
 
 ```python
 model_config = ModelConfig(
-    model_name_or_path="meta-llama/Llama-2-7b-hf",
+    model_name="meta-llama/Llama-3.2-3B",
     load_in_4bit=True,
     use_lora=True,
-    lora_config={
-        "r": 16,
-        "lora_alpha": 32,
-        "target_modules": ["q_proj", "v_proj"]
-    }
+    hub_manager=hub_manager
 )
 ```
 
@@ -187,7 +255,8 @@ dataset_config = DatasetConfig(
     max_length=512,
     train_size=0.8,
     val_size=0.1,
-    test_size=0.1
+    test_size=0.1,
+    hub_manager=hub_manager
 )
 ```
 
@@ -203,7 +272,8 @@ training_config = TrainingConfig(
     logging_steps=50,
     eval_steps=200,
     save_steps=500,
-    early_stopping_patience=3
+    early_stopping_patience=3,
+    early_stopping_threshold=0.01
 )
 ```
 
````

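One gap worth flagging in the new README quick start: the three loaders are built with bare `DataLoader(...)` calls that are never imported, and the `batch_size`/`shuffle`/`num_workers` arguments match PyTorch's `torch.utils.data.DataLoader`. Below is a minimal sketch of that data-preparation step written against stock `datasets`/`transformers`/PyTorch APIs, assuming (not confirmed by this commit) that quantllm's `DatasetLoader`, `DatasetPreprocessor`, and `DataLoader` usage wrap roughly this flow; the gated Llama tokenizer also presumes Hugging Face Hub access.

```python
# Hedged sketch, not quantllm's actual implementation: dataset -> tokenized
# fixed-length tensors -> PyTorch DataLoader, mirroring the README steps.
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")  # gated repo; requires Hub access
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token

raw = load_dataset("imdb", split="train[:1%]")  # small slice for illustration

def tokenize(batch):
    # Pad to a fixed length so the default collate function can stack tensors.
    return tokenizer(batch["text"], truncation=True, max_length=512, padding="max_length")

tokenized = raw.map(tokenize, batched=True, remove_columns=["text"])
tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])

train_dataloader = DataLoader(tokenized, batch_size=4, shuffle=True, num_workers=4)
```

Fixed-length padding keeps PyTorch's default collate function happy; dynamic padding via a data collator would be the usual refinement.
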
examples/basic_usage.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -0,0 +1,31 @@
+from quantllm import QuantizedLLM
+
+def main():
+    # Initialize model
+    model = QuantizedLLM(
+        model_name="meta-llama/Llama-2-7b-hf",
+        quantization="4bit",
+        use_lora=True,
+        push_to_hub=False # Set to True if you want to push to Hub
+    )
+
+    # Load dataset
+    print("Loading dataset...")
+    model.load_dataset("imdb", split="train[:1000]") # Using a small subset for demo
+
+    # Fine-tune
+    print("Starting fine-tuning...")
+    model.finetune(
+        epochs=1,
+        batch_size=4,
+        learning_rate=2e-4
+    )
+
+    # Save checkpoint
+    print("Saving checkpoint...")
+    model.save_checkpoint("checkpoints/demo_checkpoint")
+
+    print("Done!")
+
+if __name__ == "__main__":
+    main()
```
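
A practical note on the new example: it loads the gated `meta-llama/Llama-2-7b-hf` checkpoint in 4-bit, which with the usual transformers + bitsandbytes stack presumes a CUDA GPU, the `bitsandbytes` package, and Hugging Face credentials. A hypothetical pre-flight check (not part of this commit) along these lines can catch a missing requirement before the download starts:

```python
# Hypothetical pre-flight check before `python examples/basic_usage.py`.
# Assumptions: 4-bit loading typically goes through bitsandbytes on a CUDA GPU,
# and the gated Llama-2 weights need a Hugging Face token (e.g. HF_TOKEN).
import importlib.util
import os

import torch

assert torch.cuda.is_available(), "4-bit quantized loading typically expects a CUDA GPU"
assert importlib.util.find_spec("bitsandbytes") is not None, "pip install bitsandbytes"
assert os.getenv("HF_TOKEN"), "export HF_TOKEN=<token with access to meta-llama/Llama-2-7b-hf>"
print("Environment looks ready for examples/basic_usage.py")
```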
