from pathlib import Path
import numpy as np
from tqdm import tqdm
-import wandb
+import wandb
+from wandb import login
from datetime import datetime
import os
from ..config.training_config import TrainingConfig
@@ -53,6 +53,7 @@ def __init__(
        self.hub_manager = hub_manager
        self.use_wandb = use_wandb
        self.wandb_config = wandb_config or {}
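+        # W&B API key is read from wandb_config; it may be absent when W&B logging is disabled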
+        self.wandb_token = self.wandb_config.get('API_KEY')

        # Set device
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
@@ -141,8 +142,10 @@ def lr_lambda(current_step: int) -> float:

    def _setup_wandb(self):
        """Setup Weights & Biases logging."""
-        if not wandb.api.api_key:
-            self.logger.log_warning("Weights & Biases API key not found. Disabling W&B logging.")
+        if wandb.login(key=self.wandb_token, relogin=True):
+            self.logger.log_info("Logged in to Weights & Biases")
+        else:
+            self.logger.log_error("Failed to log in to Weights & Biases")
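+            # Keep training but disable W&B logging for the rest of the run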
            self.use_wandb = False
            return

@@ -162,45 +165,78 @@ def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        outputs = self.model(**batch)
        return outputs.loss

-    def _train_step(self, batch: Dict[str, torch.Tensor]) -> float:
-        """Perform a single training step."""
-        self.model.train()
-
-        # Clear gradients
-        self.optimizer.zero_grad()
-
-        # Forward pass with mixed precision
-        if self.scaler is not None:
-            with torch.cuda.amp.autocast():
-                loss = self._compute_loss(batch)
-
+    def train_step(self, batch, scaler):
+        """Single training step."""
+        try:
+            # Move batch to device
+            batch = {k: v.to(self.device) for k, v in batch.items()}
+
+            # Forward pass with modern autocast
+            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                outputs = self.model(**batch)
+                loss = outputs.loss
+
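+            # The forward pass above assumes a CUDA device and runs in float16 under autocast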
            # Backward pass with gradient scaling
-            self.scaler.scale(loss).backward()
+            scaler.scale(loss).backward()

-            # Gradient clipping
-            self.scaler.unscale_(self.optimizer)
-            torch.nn.utils.clip_grad_norm_(
-                self.model.parameters(),
-                self.config.max_grad_norm
-            )
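+            # Unscale gradients before clipping so the norm is computed on true (unscaled) values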
+            if self.config.max_grad_norm is not None:
+                scaler.unscale_(self.optimizer)
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
+
+            scaler.step(self.optimizer)
+            scaler.update()

-            # Optimizer step with gradient scaling
-            self.scaler.step(self.optimizer)
-            self.scaler.update()
-        else:
-            loss = self._compute_loss(batch)
-            loss.backward()
+            self.optimizer.zero_grad()

-            # Gradient clipping
-            torch.nn.utils.clip_grad_norm_(
-                self.model.parameters(),
-                self.config.max_grad_norm
-            )
+            return loss.item()

-            self.optimizer.step()
+        except Exception as e:
+            self.logger.log_error(f"Error in training step: {str(e)}")
+            raise
+
+    def train(self):
+        """Train the model."""
+        try:
+            self.logger.log_info("Starting training")

-        return loss.item()
-
+            # Disable model caching when using gradient checkpointing
+            if hasattr(self.model.config, 'gradient_checkpointing') and self.model.config.gradient_checkpointing:
+                self.model.config.use_cache = False
+                self.logger.log_info("Disabled model caching due to gradient checkpointing")
+
+            scaler = torch.amp.GradScaler('cuda')
+
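+            # A single GradScaler instance is reused across epochs and passed into train_step()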
+            for epoch in range(self.config.num_epochs):
+                self.model.train()
+                total_loss = 0
+
+                # Training loop
+                with tqdm(total=len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.config.num_epochs}") as pbar:
+                    for step, batch in enumerate(self.train_dataloader):
+                        loss = self.train_step(batch, scaler)
+                        total_loss += loss
+
+                        # Update progress bar
+                        pbar.update(1)
+                        pbar.set_postfix({'loss': f'{loss:.4f}'})
+
+                        if self.config.save_steps > 0 and (step + 1) % self.config.save_steps == 0:
+                            self._save_checkpoint(epoch, step)
+
+                # Epoch end processing
+                avg_loss = total_loss / len(self.train_dataloader)
+                self.logger.log_info(f"Epoch {epoch + 1} - Average loss: {avg_loss:.4f}")
+
+                if self.config.save_epochs > 0 and (epoch + 1) % self.config.save_epochs == 0:
+                    self._save_checkpoint(epoch)
+
+                if self.config.eval_epochs > 0 and (epoch + 1) % self.config.eval_epochs == 0:
+                    self._evaluate()
+
+        except Exception as e:
+            self.logger.log_error(f"Training error: {str(e)}")
+            raise
+
    def _evaluate(self) -> Dict[str, float]:
        """Evaluate the model on the validation set."""
        if self.eval_dataloader is None:
@@ -219,96 +255,6 @@ def _evaluate(self) -> Dict[str, float]:
        avg_loss = total_loss / num_batches
        return {"eval_loss": avg_loss}

-    def train(self):
-        """Train the model."""
-        self.logger.log_info("Starting training")
-
-        for epoch in range(self.config.num_epochs):
-            self.epoch = epoch
-            self.logger.log_info(f"Epoch {epoch + 1}/{self.config.num_epochs}")
-
-            # Training loop
-            total_loss = 0
-            num_batches = 0
-
-            progress_bar = tqdm(self.train_dataloader, desc="Training")
-            for batch in progress_bar:
-                # Training step
-                loss = self._train_step(batch)
-                total_loss += loss
-                num_batches += 1
-
-                # Update learning rate
-                if self.scheduler is not None and not isinstance(self.scheduler, ReduceLROnPlateau):
-                    self.scheduler.step()
-
-                # Log metrics
-                if self.global_step % self.config.logging_steps == 0:
-                    avg_loss = total_loss / num_batches
-                    metrics = {
-                        "train_loss": avg_loss,
-                        "learning_rate": self.optimizer.param_groups[0]["lr"],
-                        "epoch": epoch + 1,
-                        "step": self.global_step
-                    }
-
-                    self.logger.log_metrics(metrics)
-                    if self.use_wandb:
-                        wandb.log(metrics)
-
-                # Evaluation and checkpointing
-                if self.eval_dataloader is not None and self.global_step % self.config.eval_steps == 0:
-                    eval_metrics = self._evaluate()
-                    self.logger.log_metrics(eval_metrics)
-                    if self.use_wandb:
-                        wandb.log(eval_metrics)
-
-                    # Update learning rate scheduler if using ReduceLROnPlateau
-                    if isinstance(self.scheduler, ReduceLROnPlateau):
-                        self.scheduler.step(eval_metrics["eval_loss"])
-
-                    # Early stopping and checkpointing
-                    if eval_metrics["eval_loss"] < self.best_metric - self.config.early_stopping_threshold:
-                        self.best_metric = eval_metrics["eval_loss"]
-                        self.patience_counter = 0
-
-                        # Save checkpoint locally
-                        if self.checkpoint_manager is not None:
-                            self.checkpoint_manager.save_checkpoint(
-                                self.model,
-                                self.optimizer,
-                                self.scheduler,
-                                self.global_step,
-                                self.epoch,
-                                eval_metrics
-                            )
-
-                        # Push to hub if configured
-                        if self.hub_manager is not None and self.hub_manager.is_logged_in():
-                            try:
-                                self.hub_manager.push_model(
-                                    self.model,
-                                    commit_message=f"Checkpoint at step {self.global_step} with eval_loss {eval_metrics['eval_loss']:.4f}"
-                                )
-                                self.logger.log_info("Model pushed to hub successfully")
-                            except Exception as e:
-                                self.logger.log_error(f"Failed to push model to hub: {str(e)}")
-                    else:
-                        self.patience_counter += 1
-                        if self.patience_counter >= self.config.early_stopping_patience:
-                            self.logger.log_info("Early stopping triggered")
-                            return
-
-                self.global_step += 1
-
-            # End of epoch
-            avg_loss = total_loss / num_batches
-            self.logger.log_info(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")
-
-        self.logger.log_info("Training completed")
-        if self.use_wandb:
-            wandb.finish()
-
    def save_model(self, output_dir: Union[str, Path]):
        """Save the model and training state."""
        output_dir = Path(output_dir)
@@ -345,4 +291,4 @@ def load_model(self, input_dir: Union[str, Path]):
        self.optimizer.load_state_dict(training_state["optimizer_state_dict"])
        if self.scheduler and training_state["scheduler_state_dict"]:
            self.scheduler.load_state_dict(training_state["scheduler_state_dict"])
-        self.best_metric = training_state["best_metric"]
+        self.best_metric = training_state["best_metric"]