    DatasetSplitter,
    FineTuningTrainer,
    ModelEvaluator,
    HubManager,
    CheckpointManager,
    # Loader/preprocessor classes used below; assumed to be exported
    # from the package root as well.
    ModelLoader,
    DatasetLoader,
    DatasetPreprocessor,
)
import os

from quantllm.finetune import TrainingLogger
from quantllm.config import (
    DatasetConfig,
    ModelConfig,
    TrainingConfig,
)

# DataLoader is assumed to be the standard PyTorch class.
from torch.utils.data import DataLoader

# Initialize logger
logger = TrainingLogger()

# 1. Initialize hub manager first
hub_manager = HubManager(
    model_id="your-username/llama-2-imdb",
    token=os.getenv("HF_TOKEN")
)

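# Note: HubManager reads the Hugging Face token from the HF_TOKEN environment
# variable here; export it (or run `huggingface-cli login`) before running this
# script. `use_wandb=True` further down likewise expects a one-time `wandb login`.
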
# 2. Model Configuration and Loading
model_config = ModelConfig(
    model_name="meta-llama/Llama-3.2-3B",
    load_in_4bit=True,
    use_lora=True,
    hub_manager=hub_manager
)

model_loader = ModelLoader(model_config)
model = model_loader.get_model()
tokenizer = model_loader.get_tokenizer()

# 3. Dataset Configuration and Loading
dataset_config = DatasetConfig(
    dataset_name_or_path="imdb",
    dataset_type="huggingface",
    text_column="text",
    label_column="label",
    max_length=512,
    train_size=0.8,
    val_size=0.1,
    test_size=0.1,
    hub_manager=hub_manager
)

# Load and prepare dataset
dataset_loader = DatasetLoader(logger)
dataset = dataset_loader.load_hf_dataset(dataset_config)

# Split dataset
dataset_splitter = DatasetSplitter(logger)
train_dataset, val_dataset, test_dataset = dataset_splitter.train_val_test_split(
    dataset,
    train_size=dataset_config.train_size,
    val_size=dataset_config.val_size,
    test_size=dataset_config.test_size
)

# 4. Dataset Preprocessing
preprocessor = DatasetPreprocessor(tokenizer, logger)
train_dataset, val_dataset, test_dataset = preprocessor.tokenize_dataset(
    train_dataset, val_dataset, test_dataset,
    max_length=dataset_config.max_length,
    text_column=dataset_config.text_column,
    label_column=dataset_config.label_column
)

# Create data loaders
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=4
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=4
)

# 5. Training Configuration
training_config = TrainingConfig(
    learning_rate=2e-4,
    num_epochs=3,
    batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=100,
    logging_steps=50,
    eval_steps=200,
    save_steps=500,
    early_stopping_patience=3,
    early_stopping_threshold=0.01
)

# Initialize checkpoint manager
checkpoint_manager = CheckpointManager(
    output_dir="./checkpoints",
    save_total_limit=3
)

# 6. Initialize Trainer
trainer = FineTuningTrainer(
    model=model,
    training_config=training_config,
    train_dataloader=train_dataloader,
    eval_dataloader=val_dataloader,
    logger=logger,
    checkpoint_manager=checkpoint_manager,
    hub_manager=hub_manager,
    use_wandb=True,
    wandb_config={
        "project": "quantllm-imdb",
        "name": "llama-2-imdb-finetuning"
    }
)

# 7. Train the model
trainer.train()

# 8. Evaluate on test set
evaluator = ModelEvaluator(
    model=model,
    eval_dataloader=test_dataloader,
    metrics=[
        lambda preds, labels, _: (preds.argmax(dim=-1) == labels).float().mean().item()  # Accuracy
    ],
    logger=logger
)

test_metrics = evaluator.evaluate()

# 9. Save final model
trainer.save_model("./final_model")

# 10. Push to Hub if logged in
if hub_manager.is_logged_in():
    hub_manager.push_model(
        model,
        commit_message=f"Final model with test accuracy: {test_metrics.get('accuracy', 0):.4f}"
    )
```

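For anything beyond a quick sanity check, a named metric function is easier to reuse and test than the inline lambda above. The sketch below makes the same assumptions the lambda does: `ModelEvaluator` calls each metric with raw logits, integer labels, and a third positional argument that accuracy does not need.

```python
def accuracy(preds, labels, _):
    # Same computation as the inline lambda: argmax over the class dimension,
    # compare against the gold labels, and average the hit rate.
    return (preds.argmax(dim=-1) == labels).float().mean().item()

evaluator = ModelEvaluator(
    model=model,
    eval_dataloader=test_dataloader,
    metrics=[accuracy],
    logger=logger
)
```

Any callable with that three-argument signature can be added to `metrics`, so further metrics slot in the same way.
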
### ⚙️ Advanced Usage

Create a config file (e.g., `config.yaml`):
```yaml
model:
  model_name: "meta-llama/Llama-3.2-3B"
  load_in_4bit: true
  use_lora: true
  lora_config:
    r: 16
    lora_alpha: 32
    target_modules: ["q_proj", "v_proj"]

dataset:
  dataset_name_or_path: "imdb"
  text_column: "text"
  label_column: "label"
  max_length: 512
  train_size: 0.8
  val_size: 0.1
  test_size: 0.1

training:
  learning_rate: 2e-4
  num_epochs: 3
  batch_size: 4
  gradient_accumulation_steps: 4
  warmup_steps: 100
  logging_steps: 50
  eval_steps: 200
  save_steps: 500
  early_stopping_patience: 3
  early_stopping_threshold: 0.01
```
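How the file is consumed is up to you. The snippet below is a minimal sketch rather than QuantLLM's own loader, and it assumes every key in each section is a valid keyword argument of the matching config class:

```python
import yaml  # PyYAML

from quantllm.config import DatasetConfig, ModelConfig, TrainingConfig

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

# Assumption: each top-level section maps 1:1 onto the config class kwargs,
# including the nested lora_config mapping.
model_config = ModelConfig(**cfg["model"])
dataset_config = DatasetConfig(**cfg["dataset"])
training_config = TrainingConfig(**cfg["training"])
```
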
## 📚 Documentation
### Model Configuration
```python
model_config = ModelConfig(
    model_name="meta-llama/Llama-3.2-3B",
    load_in_4bit=True,
    use_lora=True,
    hub_manager=hub_manager
)
```
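Combining `load_in_4bit=True` with `use_lora=True` follows the QLoRA recipe: the base weights stay frozen in 4-bit precision while only the small LoRA adapter matrices are trained, which keeps memory requirements low enough for a single GPU.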

### Dataset Configuration

```python
dataset_config = DatasetConfig(
    dataset_name_or_path="imdb",
    dataset_type="huggingface",
    text_column="text",
    label_column="label",
    max_length=512,
    train_size=0.8,
    val_size=0.1,
    test_size=0.1,
    hub_manager=hub_manager
)
```
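The split arguments are fractions of the full dataset; here they sum to 1.0 (0.8 + 0.1 + 0.1), so every example lands in exactly one of the train, validation, or test splits.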

### Training Configuration

```python
training_config = TrainingConfig(
    learning_rate=2e-4,
    num_epochs=3,
    batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=100,
    logging_steps=50,
    eval_steps=200,
    save_steps=500,
    early_stopping_patience=3,
    early_stopping_threshold=0.01
)
```
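With `batch_size=4` and `gradient_accumulation_steps=4`, gradients from four micro-batches are accumulated before each optimizer step, for an effective batch size of 4 × 4 = 16.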