Commit 99688d2

Fix mermaid error via Claude
1 parent 9a5028c commit 99688d2

2 files changed (+46, -53 lines)

docs/03_training_loop___trainer___.md

Lines changed: 27 additions & 33 deletions
@@ -103,44 +103,38 @@ sequenceDiagram
     participant Model
     participant OptimizerScheduler as Optimizer / Scheduler
     participant LoggerEarlyStop as Logger (W&B) / EarlyStopping
-
-    RunFunc->>+Trainer: trainer.train(dl_train, dl_val, epochs)
+
+    RunFunc->>Trainer: trainer.train(dl_train, dl_val, epochs)
+
     loop For each Epoch (1 to 'epochs')
-        Trainer->>+Trainer: train_epoch(dl_train) # Ask Trainer to train one cycle
-        Trainer->>+DataLoaderTrain: Get next training batch (x, y)
-        DataLoaderTrain-->>-Trainer: Return batch
-        Trainer->>+Model: Forward pass: model(x)
-        Model-->>-Trainer: Return prediction y_pred
-        Trainer->>Trainer: Calculate loss(y_pred, y)
-        Trainer->>+OptimizerScheduler: Backpropagate loss & Update weights (optimizer.step())
-        OptimizerScheduler-->>-Trainer: Weights updated
-        Trainer->>DataLoaderTrain: Repeat for all batches in dl_train
-        Trainer-->>-Trainer: Return average train_loss for epoch
-
-        Trainer->>+Trainer: val_epoch(dl_val) # Ask Trainer to validate
-        Trainer->>+DataLoaderVal: Get next validation batch (x, y)
-        DataLoaderVal-->>-Trainer: Return batch
-        Trainer->>+Model: Forward pass: model(x) (No gradient tracking)
-        Model-->>-Trainer: Return prediction y_pred
-        Trainer->>Trainer: Calculate loss(y_pred, y)
-        Trainer->>DataLoaderVal: Repeat for all batches in dl_val
-        Trainer-->>-Trainer: Return average val_loss for epoch
-
-        Trainer->>+LoggerEarlyStop: Log metrics (train_loss, val_loss, lr)
-        LoggerEarlyStop-->>-Trainer: Metrics logged
-        Trainer->>+LoggerEarlyStop: Check EarlyStopping(val_loss)
+        Trainer->>DataLoaderTrain: Get training batches
+        DataLoaderTrain-->>Trainer: Return batches
+        Trainer->>Model: Forward pass
+        Model-->>Trainer: Return predictions
+        Trainer->>Trainer: Calculate loss
+        Trainer->>OptimizerScheduler: Backpropagate & update weights
+        OptimizerScheduler-->>Trainer: Weights updated
+
+        Trainer->>DataLoaderVal: Get validation batches
+        DataLoaderVal-->>Trainer: Return batches
+        Trainer->>Model: Forward pass (no gradients)
+        Model-->>Trainer: Return predictions
+        Trainer->>Trainer: Calculate validation loss
+
+        Trainer->>LoggerEarlyStop: Log metrics
+        LoggerEarlyStop-->>Trainer: Metrics logged
+        Trainer->>LoggerEarlyStop: Check early stopping
+
         alt Early Stop Triggered
-            LoggerEarlyStop-->>-Trainer: Stop = True
-            Trainer->>Trainer: Break epoch loop
+            LoggerEarlyStop-->>Trainer: Stop = True
         else Continue Training
-            LoggerEarlyStop-->>-Trainer: Stop = False
+            LoggerEarlyStop-->>Trainer: Stop = False
+            Trainer->>OptimizerScheduler: Adjust learning rate
+            OptimizerScheduler-->>Trainer: LR adjusted
         end
-        Trainer->>+OptimizerScheduler: Adjust Learning Rate (scheduler.step())
-        OptimizerScheduler-->>-Trainer: LR adjusted
-    end
-    alt Training Finished Normally or Early Stopped
-        Trainer-->>-RunFunc: Return final_val_loss
     end
+
+    Trainer-->>RunFunc: Return final_val_loss
 ```

This diagram shows the cycle: for each epoch, the `Trainer` calls `train_epoch` (which iterates through training batches, performs forward/backward passes, and updates weights) and `val_epoch` (which iterates through validation batches and calculates loss without updating weights). After each epoch, it logs metrics, checks for early stopping, and adjusts the learning rate.
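
To make the updated flow concrete, here is a minimal sketch of such a train/validate loop in PyTorch. It is not code from this repository: only the `Trainer` role, the `train_epoch`/`val_epoch` split, the W&B-style logging, EarlyStopping, and `scheduler.step()` come from the diagram and the surrounding docs; the helper objects (`logger.log`, `early_stopper.step`) and all signatures are assumptions made for illustration. Note that, as in the revised diagram, the learning rate is adjusted only when early stopping has not triggered.

```python
# Minimal sketch of the epoch loop described by the diagram (not repository code;
# logger and early_stopper interfaces are assumed for illustration).
import torch


class Trainer:
    def __init__(self, model, optimizer, scheduler, criterion,
                 logger, early_stopper, device="cpu"):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.logger = logger                # assumed: thin wrapper around wandb.log
        self.early_stopper = early_stopper  # assumed: .step(val_loss) -> bool (stop?)
        self.device = device

    def train_epoch(self, dl_train):
        # One pass over the training data: forward, loss, backward, optimizer step.
        self.model.train()
        total = 0.0
        for x, y in dl_train:
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(x), y)
            loss.backward()
            self.optimizer.step()
            total += loss.item()
        return total / len(dl_train)

    @torch.no_grad()
    def val_epoch(self, dl_val):
        # One pass over the validation data without gradient tracking.
        self.model.eval()
        total = 0.0
        for x, y in dl_val:
            x, y = x.to(self.device), y.to(self.device)
            total += self.criterion(self.model(x), y).item()
        return total / len(dl_val)

    def train(self, dl_train, dl_val, epochs):
        val_loss = float("inf")
        for epoch in range(epochs):
            train_loss = self.train_epoch(dl_train)
            val_loss = self.val_epoch(dl_val)
            lr = self.optimizer.param_groups[0]["lr"]
            self.logger.log({"epoch": epoch, "train_loss": train_loss,
                             "val_loss": val_loss, "lr": lr})
            if self.early_stopper.step(val_loss):  # Stop = True -> leave the epoch loop
                break
            self.scheduler.step()                  # adjust LR only when training continues
        return val_loss
```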

docs/06_pruning_strategy___pflpruner___.md

Lines changed: 19 additions & 20 deletions
@@ -97,38 +97,37 @@ sequenceDiagram
     participant Trainer as Trainer
     participant PFLPruner as PFLPruner Instance
     participant TrialState as TrialState
-
-    OptunaStudy->>+ObjectiveFunc: Start Trial N
-    ObjectiveFunc->>+Trainer: trainer.train(..., trial=N, pruner=pruner_instance)
+
+    OptunaStudy->>ObjectiveFunc: Start Trial N
+    ObjectiveFunc->>Trainer: trainer.train(..., trial=N, pruner=pruner_instance)
+
     loop For each Epoch (within trainer.train)
         Trainer->>Trainer: Run train_epoch()
         Trainer->>Trainer: Run val_epoch()
-        Trainer->>+PFLPruner: pruner.report(trial_id=N, epoch=E, value=val_loss) # PFLPruner activation starts
-        PFLPruner->>+TrialState: Update loss history for Trial N
-        TrialState-->>-PFLPruner: Done
+        Trainer->>PFLPruner: pruner.report(trial_id=N, epoch=E, value=val_loss)
+        PFLPruner->>TrialState: Update loss history for Trial N
+        TrialState-->>PFLPruner: Done
         PFLPruner->>PFLPruner: Calculate Predicted Final Loss (PFL)
         PFLPruner->>PFLPruner: Compare PFL with Top-K finished trials
-        PFLPruner->>PFLPruner: Check if should_prune() is True?
+        PFLPruner->>PFLPruner: pruner.should_prune() ?
+
         alt Pruning conditions met
-            # Return True in response to the report call and deactivate PFLPruner
-            PFLPruner-->>-Trainer: Return True
+            PFLPruner-->>Trainer: Return True
             Trainer->>Trainer: Raise optuna.TrialPruned exception
-            Trainer-->>-ObjectiveFunc: Exception caught # Deactivate Trainer
-            ObjectiveFunc-->>-OptunaStudy: Report Trial N as Pruned # Deactivate ObjectiveFunc
-            # The loop may stop here (standard Mermaid has no explicit 'break')
+            Trainer-->>ObjectiveFunc: Exception caught
+            ObjectiveFunc-->>OptunaStudy: Report Trial N as Pruned
         else Pruning conditions NOT met
-            # Return False in response to the report call and deactivate PFLPruner
-            PFLPruner-->>-Trainer: Return False
+            PFLPruner-->>Trainer: Return False
             Trainer->>Trainer: Continue to next epoch...
         end
     end
-    alt Trial Finishes Normally (or Early Stopping outside pruning)
-        # Deactivate Trainer after the loop ends (normal finish or Early Stopping)
-        Trainer-->>-ObjectiveFunc: Return final_val_loss
-        ObjectiveFunc->>+PFLPruner: pruner.complete_trial(trial_id=N)
+
+    alt Trial Finishes Normally (or Early Stopping)
+        Trainer-->>ObjectiveFunc: Return final_val_loss
+        ObjectiveFunc->>PFLPruner: pruner.complete_trial(trial_id=N)
         PFLPruner->>PFLPruner: Update Top-K completed trials if necessary
-        PFLPruner-->>-ObjectiveFunc: Done # Deactivate PFLPruner
-        ObjectiveFunc-->>-OptunaStudy: Report Trial N result (final_val_loss) # Deactivate ObjectiveFunc
+        PFLPruner-->>ObjectiveFunc: Done
+        ObjectiveFunc-->>OptunaStudy: Report Trial N result (final_val_loss)
     end
 ```

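The diagram above describes a report / should_prune / complete_trial handshake between the objective function, the `Trainer`, and the `PFLPruner`. Below is a rough sketch of how an Optuna objective might drive that handshake. It is not this repository's code: only `pruner.report(...)`, `pruner.should_prune()`, `pruner.complete_trial(...)`, and `optuna.TrialPruned` appear in the diagram; every other name and signature is an assumption for illustration.

```python
# Illustrative Optuna objective wired to a pruner with the report / should_prune /
# complete_trial protocol shown in the diagram (names outside the diagram are assumed).
import optuna


def objective(trial, trainer, pruner, dl_train, dl_val, epochs):
    try:
        # Inside trainer.train(), each epoch is expected to end with roughly:
        #   pruner.report(trial_id=trial.number, epoch=epoch, value=val_loss)
        #   if pruner.should_prune():
        #       raise optuna.TrialPruned()
        final_val_loss = trainer.train(dl_train, dl_val, epochs,
                                       trial=trial, pruner=pruner)
    except optuna.TrialPruned:
        # Optuna records the trial as pruned when this exception propagates.
        raise
    # Trial finished normally (or via early stopping): let the pruner update its
    # Top-K completed trials so future PFL comparisons can use this result.
    pruner.complete_trial(trial_id=trial.number)
    return final_val_loss
```

The point the diagram makes is that `complete_trial` runs only for trials that finish without being pruned, so the pruner's Top-K reference set contains completed trials only.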