
Commit 33b7c61

First commit.
1 parent 2851a64 commit 33b7c61

4 files changed: +85 -179 lines changed

quantllm/config/dataset_config.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@ class DatasetConfig:
     """Configuration for dataset loading and processing."""
 
     # Dataset identification
-    dataset_name_or_path: str
+    dataset_name: str
     dataset_type: str = "huggingface"  # huggingface, local, custom
     dataset_revision: str = "main"
     dataset_split: Optional[str] = None
@@ -56,7 +56,7 @@ def __post_init__(self):
     def to_dict(self) -> Dict[str, Any]:
         """Convert configuration to dictionary."""
         return {
-            "dataset_name_or_path": self.dataset_name_or_path,
+            "dataset_name": self.dataset_name,
             "dataset_type": self.dataset_type,
             "dataset_revision": self.dataset_revision,
             "dataset_split": self.dataset_split,
@@ -89,7 +89,7 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> 'DatasetConfig':
 
     def validate(self) -> bool:
         """Validate configuration values."""
-        if not self.dataset_name_or_path:
+        if not self.dataset_name:
             raise ValueError("Dataset name or path is required")
 
         if self.dataset_type not in ["huggingface", "local", "custom"]:
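
For context, a minimal sketch of how the renamed field is used after this change; only the fields visible in the diff are assumed, and the dataset name is a placeholder.

```python
from quantllm.config.dataset_config import DatasetConfig

# The field is now `dataset_name` rather than `dataset_name_or_path`.
config = DatasetConfig(dataset_name="imdb", dataset_type="huggingface", dataset_split="train")
config.validate()                                     # raises ValueError if dataset_name is empty
restored = DatasetConfig.from_dict(config.to_dict())  # round-trips via the "dataset_name" key
```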

quantllm/hub/hub_manager.py

Lines changed: 9 additions & 9 deletions
@@ -12,16 +12,16 @@ def __init__(
         """Initialize HubManager and login to Hugging Face."""
         self.model_id = model_id
         self.organization = organization
-        self.api = HfApi()
+        self.token = token
 
-        if token:
-            try:
-                login(token=token)
-                print(f"Successfully logged in to Hugging Face Hub")
-                self.token = token
-            except Exception as e:
-                print(f"Login failed: {str(e)}")
-                raise
+    def login(self):
+        """Login to Hugging Face Hub."""
+        try:
+            self.api = HfApi(token=self.token)
+            print("Successfully logged in to Hugging Face Hub.")
+        except Exception as e:
+            print(f"Error logging in: {str(e)}")
+            raise
 
     def push_model(
         self,
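
With this change, authentication moves out of the constructor into an explicit login() call. A minimal sketch of the intended call order follows; the model id, organization, and token are placeholders, and the constructor arguments are inferred from the diff.

```python
from quantllm.hub.hub_manager import HubManager

# Hypothetical values; the token is stored in __init__ and only applied when login() is called.
manager = HubManager(model_id="my-org/my-model", organization="my-org", token="hf_xxx")
manager.login()  # builds the HfApi client with the stored token
```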

quantllm/trainer/trainer.py

Lines changed: 73 additions & 127 deletions
@@ -7,7 +7,7 @@
 from pathlib import Path
 import numpy as np
 from tqdm import tqdm
-import wandb
+from wandb import wandb, login
 from datetime import datetime
 import os
 from ..config.training_config import TrainingConfig
@@ -53,6 +53,7 @@ def __init__(
         self.hub_manager = hub_manager
         self.use_wandb = use_wandb
         self.wandb_config = wandb_config or {}
+        self.wandb_token = self.wandb_config['API_KEY']
 
         # Set device
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
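
The API key stored here feeds the wandb.login call in _setup_wandb below. For reference, a standalone sketch of that login-then-init flow; the project name, key, and config values are placeholders.

```python
import wandb

# Placeholder key; wandb.login returns True when the key is accepted.
if wandb.login(key="your-wandb-api-key", relogin=True):
    run = wandb.init(project="quantllm-training", config={"learning_rate": 2e-5})
    run.log({"train_loss": 0.0})
    run.finish()
```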
@@ -141,8 +142,10 @@ def lr_lambda(current_step: int) -> float:
 
     def _setup_wandb(self):
         """Setup Weights & Biases logging."""
-        if not wandb.api.api_key:
-            self.logger.log_warning("Weights & Biases API key not found. Disabling W&B logging.")
+        if wandb.login(key=self.wandb_token, relogin=True):
+            self.logger.log_info("Logged in to Weights & Biases")
+        else:
+            self.logger.log_error("Failed to log in to Weights & Biases")
             self.use_wandb = False
             return
 
@@ -162,45 +165,78 @@ def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         outputs = self.model(**batch)
         return outputs.loss
 
-    def _train_step(self, batch: Dict[str, torch.Tensor]) -> float:
-        """Perform a single training step."""
-        self.model.train()
-
-        # Clear gradients
-        self.optimizer.zero_grad()
-
-        # Forward pass with mixed precision
-        if self.scaler is not None:
-            with torch.cuda.amp.autocast():
-                loss = self._compute_loss(batch)
-
+    def train_step(self, batch, scaler):
+        """Single training step."""
+        try:
+            # Move batch to device
+            batch = {k: v.to(self.device) for k, v in batch.items()}
+
+            # Forward pass with modern autocast
+            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                outputs = self.model(**batch)
+                loss = outputs.loss
+
             # Backward pass with gradient scaling
-            self.scaler.scale(loss).backward()
+            scaler.scale(loss).backward()
 
-            # Gradient clipping
-            self.scaler.unscale_(self.optimizer)
-            torch.nn.utils.clip_grad_norm_(
-                self.model.parameters(),
-                self.config.max_grad_norm
-            )
+            if self.config.max_grad_norm is not None:
+                scaler.unscale_(self.optimizer)
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
+
+            scaler.step(self.optimizer)
+            scaler.update()
 
-            # Optimizer step with gradient scaling
-            self.scaler.step(self.optimizer)
-            self.scaler.update()
-        else:
-            loss = self._compute_loss(batch)
-            loss.backward()
+            self.optimizer.zero_grad()
 
-            # Gradient clipping
-            torch.nn.utils.clip_grad_norm_(
-                self.model.parameters(),
-                self.config.max_grad_norm
-            )
+            return loss.item()
 
-            self.optimizer.step()
+        except Exception as e:
+            print(f"Error in training step: {str(e)}")
+            raise
+
+    def train(self):
+        """Train the model."""
+        try:
+            self.logger.log_info("Starting training")
 
-        return loss.item()
-
+            # Disable model caching when using gradient checkpointing
+            if hasattr(self.model.config, 'gradient_checkpointing') and self.model.config.gradient_checkpointing:
+                self.model.config.use_cache = False
+                self.logger.log_info("Disabled model caching due to gradient checkpointing")
+
+            scaler = torch.cuda.amp.GradScaler()
+
+            for epoch in range(self.config.num_epochs):
+                self.model.train()
+                total_loss = 0
+
+                # Training loop
+                with tqdm(total=len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.config.num_epochs}") as pbar:
+                    for step, batch in enumerate(self.train_dataloader):
+                        loss = self.train_step(batch, scaler)
+                        total_loss += loss
+
+                        # Update progress bar
+                        pbar.update(1)
+                        pbar.set_postfix({'loss': f'{loss:.4f}'})
+
+                        if self.config.save_steps > 0 and (step + 1) % self.config.save_steps == 0:
+                            self._save_checkpoint(epoch, step)
+
+                # Epoch end processing
+                avg_loss = total_loss / len(self.train_dataloader)
+                self.logger.log_info(f"Epoch {epoch + 1} - Average loss: {avg_loss:.4f}")
+
+                if self.config.save_epochs > 0 and (epoch + 1) % self.config.save_epochs == 0:
+                    self._save_checkpoint(epoch)
+
+                if self.config.eval_epochs > 0 and (epoch + 1) % self.config.eval_epochs == 0:
+                    self._evaluate()
+
+        except Exception as e:
+            self.logger.log_error(f"Training error: {str(e)}")
+            raise
+
     def _evaluate(self) -> Dict[str, float]:
         """Evaluate the model on the validation set."""
         if self.eval_dataloader is None:
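
For reference, a self-contained sketch of the scaled mixed-precision update order the new train_step follows (scale, backward, unscale, clip, step, update). The toy model and data are placeholders; autocast and scaling are simply disabled when CUDA is unavailable.

```python
import torch

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x = torch.randn(8, 16, device=device)
y = torch.randn(8, 1, device=device)

with torch.amp.autocast(device_type=device, dtype=torch.float16, enabled=use_cuda):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()        # backward on the scaled loss
scaler.unscale_(optimizer)           # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)               # skips the update if gradients overflowed
scaler.update()
optimizer.zero_grad()
print(f"loss: {loss.item():.4f}")
```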
@@ -219,96 +255,6 @@ def _evaluate(self) -> Dict[str, float]:
         avg_loss = total_loss / num_batches
         return {"eval_loss": avg_loss}
 
-    def train(self):
-        """Train the model."""
-        self.logger.log_info("Starting training")
-
-        for epoch in range(self.config.num_epochs):
-            self.epoch = epoch
-            self.logger.log_info(f"Epoch {epoch + 1}/{self.config.num_epochs}")
-
-            # Training loop
-            total_loss = 0
-            num_batches = 0
-
-            progress_bar = tqdm(self.train_dataloader, desc="Training")
-            for batch in progress_bar:
-                # Training step
-                loss = self._train_step(batch)
-                total_loss += loss
-                num_batches += 1
-
-                # Update learning rate
-                if self.scheduler is not None and not isinstance(self.scheduler, ReduceLROnPlateau):
-                    self.scheduler.step()
-
-                # Log metrics
-                if self.global_step % self.config.logging_steps == 0:
-                    avg_loss = total_loss / num_batches
-                    metrics = {
-                        "train_loss": avg_loss,
-                        "learning_rate": self.optimizer.param_groups[0]["lr"],
-                        "epoch": epoch + 1,
-                        "step": self.global_step
-                    }
-
-                    self.logger.log_metrics(metrics)
-                    if self.use_wandb:
-                        wandb.log(metrics)
-
-                # Evaluation and checkpointing
-                if self.eval_dataloader is not None and self.global_step % self.config.eval_steps == 0:
-                    eval_metrics = self._evaluate()
-                    self.logger.log_metrics(eval_metrics)
-                    if self.use_wandb:
-                        wandb.log(eval_metrics)
-
-                    # Update learning rate scheduler if using ReduceLROnPlateau
-                    if isinstance(self.scheduler, ReduceLROnPlateau):
-                        self.scheduler.step(eval_metrics["eval_loss"])
-
-                    # Early stopping and checkpointing
-                    if eval_metrics["eval_loss"] < self.best_metric - self.config.early_stopping_threshold:
-                        self.best_metric = eval_metrics["eval_loss"]
-                        self.patience_counter = 0
-
-                        # Save checkpoint locally
-                        if self.checkpoint_manager is not None:
-                            self.checkpoint_manager.save_checkpoint(
-                                self.model,
-                                self.optimizer,
-                                self.scheduler,
-                                self.global_step,
-                                self.epoch,
-                                eval_metrics
-                            )
-
-                        # Push to hub if configured
-                        if self.hub_manager is not None and self.hub_manager.is_logged_in():
-                            try:
-                                self.hub_manager.push_model(
-                                    self.model,
-                                    commit_message=f"Checkpoint at step {self.global_step} with eval_loss {eval_metrics['eval_loss']:.4f}"
-                                )
-                                self.logger.log_info("Model pushed to hub successfully")
-                            except Exception as e:
-                                self.logger.log_error(f"Failed to push model to hub: {str(e)}")
-                    else:
-                        self.patience_counter += 1
-                        if self.patience_counter >= self.config.early_stopping_patience:
-                            self.logger.log_info("Early stopping triggered")
-                            return
-
-                self.global_step += 1
-
-            # End of epoch
-            avg_loss = total_loss / num_batches
-            self.logger.log_info(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")
-
-        self.logger.log_info("Training completed")
-        if self.use_wandb:
-            wandb.finish()
-
     def save_model(self, output_dir: Union[str, Path]):
         """Save the model and training state."""
         output_dir = Path(output_dir)
@@ -345,4 +291,4 @@ def load_model(self, input_dir: Union[str, Path]):
         self.optimizer.load_state_dict(training_state["optimizer_state_dict"])
         if self.scheduler and training_state["scheduler_state_dict"]:
             self.scheduler.load_state_dict(training_state["scheduler_state_dict"])
-        self.best_metric = training_state["best_metric"]
+        self.best_metric = training_state["best_metric"]
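
For orientation, a minimal sketch of the training-state round-trip that load_model reads back, using a toy model so it runs standalone. The file name and the save-side bundling are assumptions; only the keys restored in the hunk above come from the source.

```python
import torch

# Toy optimizer/scheduler just to make the round-trip concrete.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
best_metric = 0.123

# Hypothetical save side: bundle the pieces that load_model reads back.
training_state = {
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
    "best_metric": best_metric,
}
torch.save(training_state, "training_state.pt")

# Load side, mirroring the keys used in the hunk above.
training_state = torch.load("training_state.pt")
optimizer.load_state_dict(training_state["optimizer_state_dict"])
if scheduler and training_state["scheduler_state_dict"]:
    scheduler.load_state_dict(training_state["scheduler_state_dict"])
best_metric = training_state["best_metric"]
```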

setup.py

Lines changed: 0 additions & 40 deletions
This file was deleted.
