@@ -103,44 +103,38 @@ sequenceDiagram
     participant Model
     participant OptimizerScheduler as Optimizer / Scheduler
     participant LoggerEarlyStop as Logger (W&B) / EarlyStopping
-
-    RunFunc->>+Trainer: trainer.train(dl_train, dl_val, epochs)
+
+    RunFunc->>Trainer: trainer.train(dl_train, dl_val, epochs)
+
     loop For each Epoch (1 to 'epochs')
-        Trainer->>+Trainer: train_epoch(dl_train) # Ask Trainer to train one cycle
-        Trainer->>+DataLoaderTrain: Get next training batch (x, y)
-        DataLoaderTrain-->>-Trainer: Return batch
-        Trainer->>+Model: Forward pass: model(x)
-        Model-->>-Trainer: Return prediction y_pred
-        Trainer->>Trainer: Calculate loss(y_pred, y)
-        Trainer->>+OptimizerScheduler: Backpropagate loss & Update weights (optimizer.step())
-        OptimizerScheduler-->>-Trainer: Weights updated
-        Trainer->>DataLoaderTrain: Repeat for all batches in dl_train
-        Trainer-->>-Trainer: Return average train_loss for epoch
-
-        Trainer->>+Trainer: val_epoch(dl_val) # Ask Trainer to validate
-        Trainer->>+DataLoaderVal: Get next validation batch (x, y)
-        DataLoaderVal-->>-Trainer: Return batch
-        Trainer->>+Model: Forward pass: model(x) (No gradient tracking)
-        Model-->>-Trainer: Return prediction y_pred
-        Trainer->>Trainer: Calculate loss(y_pred, y)
-        Trainer->>DataLoaderVal: Repeat for all batches in dl_val
-        Trainer-->>-Trainer: Return average val_loss for epoch
-
-        Trainer->>+LoggerEarlyStop: Log metrics (train_loss, val_loss, lr)
-        LoggerEarlyStop-->>-Trainer: Metrics logged
-        Trainer->>+LoggerEarlyStop: Check EarlyStopping(val_loss)
+        Trainer->>DataLoaderTrain: Get training batches
+        DataLoaderTrain-->>Trainer: Return batches
+        Trainer->>Model: Forward pass
+        Model-->>Trainer: Return predictions
+        Trainer->>Trainer: Calculate loss
+        Trainer->>OptimizerScheduler: Backpropagate & update weights
+        OptimizerScheduler-->>Trainer: Weights updated
+
+        Trainer->>DataLoaderVal: Get validation batches
+        DataLoaderVal-->>Trainer: Return batches
+        Trainer->>Model: Forward pass (no gradients)
+        Model-->>Trainer: Return predictions
+        Trainer->>Trainer: Calculate validation loss
+
+        Trainer->>LoggerEarlyStop: Log metrics
+        LoggerEarlyStop-->>Trainer: Metrics logged
+        Trainer->>LoggerEarlyStop: Check early stopping
+
         alt Early Stop Triggered
-            LoggerEarlyStop-->>-Trainer: Stop = True
-            Trainer->>Trainer: Break epoch loop
+            LoggerEarlyStop-->>Trainer: Stop = True
         else Continue Training
-            LoggerEarlyStop-->>-Trainer: Stop = False
+            LoggerEarlyStop-->>Trainer: Stop = False
+            Trainer->>OptimizerScheduler: Adjust learning rate
+            OptimizerScheduler-->>Trainer: LR adjusted
         end
-        Trainer->>+OptimizerScheduler: Adjust Learning Rate (scheduler.step())
-        OptimizerScheduler-->>-Trainer: LR adjusted
-    end
-    alt Training Finished Normally or Early Stopped
-        Trainer-->>-RunFunc: Return final_val_loss
     end
+
+    Trainer-->>RunFunc: Return final_val_loss
 ```
 
 This diagram shows the cycle: for each epoch, the `Trainer` calls `train_epoch` (which iterates through the training batches, performing forward/backward passes and updating weights) and `val_epoch` (which iterates through the validation batches and computes the loss without updating weights). After each epoch, it logs metrics, checks for early stopping, and adjusts the learning rate.
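To make the control flow concrete, here is a minimal, framework-free Python sketch of the loop the diagram describes. The names (`train_epoch`, `val_epoch`, `scheduler_step`, `EarlyStopping`) are hypothetical stand-ins for the project's actual methods, and the model/optimizer work is stubbed out so only the per-epoch control flow remains:

```python
class EarlyStopping:
    """Signal a stop when val_loss has not improved for `patience` epochs."""
    def __init__(self, patience=2):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def __call__(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def train(train_epoch, val_epoch, scheduler_step, epochs, patience=2):
    """Run the epoch loop: train, validate, log, early-stop, adjust LR."""
    early_stop = EarlyStopping(patience)
    val_loss = None
    for epoch in range(epochs):
        train_loss = train_epoch()   # forward/backward + optimizer.step()
        val_loss = val_epoch()       # forward only, no gradient tracking
        print(f"epoch {epoch}: train={train_loss:.3f} val={val_loss:.3f}")
        if early_stop(val_loss):     # Stop = True -> break the epoch loop
            break
        scheduler_step()             # Stop = False -> adjust learning rate
    return val_loss                  # final_val_loss back to the caller


# Toy run: validation loss falls, then plateaus and triggers early stop.
fake_val = iter([0.9, 0.5, 0.52, 0.53, 0.1])
final = train(
    train_epoch=lambda: 0.4,
    val_epoch=lambda: next(fake_val),
    scheduler_step=lambda: None,
    epochs=10,
)
```

In the real PyTorch Trainer, `train_epoch` would presumably iterate over `dl_train` calling `loss.backward()` and `optimizer.step()`, while `val_epoch` would run under `torch.no_grad()`; the structure of the outer loop is the same.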