33import argparse
44import json
55from pathlib import Path
6- from typing import Any
76
87import torch
98import yaml
109from torch import amp , nn , optim
1110
1211from src .common import DEVICE , get_dataloader , get_model , train_epoch , validate
12+ from src .early_stopping import EarlyStopping
1313
1414
1515def main () -> None :
1616 """Execute the fine-tuning pipeline."""
17+ choices = ["baseline" , "pipeline1" , "pipeline2" , "pipeline3" ]
1718 parser = argparse .ArgumentParser ()
19+ parser .add_argument ("--pipeline" , choices = choices , type = str , required = True )
1820 parser .add_argument ("--model" , type = str , required = True )
19- parser .add_argument ("--config" , type = str , required = True )
20- parser .add_argument ("--output_dir" , type = str , required = True )
21-
21+ # parser.add_argument("--config", type=str, required=True)
22+ # parser.add_argument("--output_dir", type=str, required=True)
2223 args = parser .parse_args ()
2324
24- with open (args .config ) as conf_file :
25- config : dict [str , Any ] = yaml .safe_load (conf_file )
25+ params_path = Path (args .pipeline ) / "params.yaml"
26+ with open (params_path ) as f :
27+ config = yaml .safe_load (f )
2628
27- out_dir = Path (args .output_dir )
29+ out_dir = Path (args .pipeline / args . model ) / "finetuned"
2830 out_dir .mkdir (parents = True , exist_ok = True )
2931
3032 t_loader = get_dataloader (
@@ -43,6 +45,11 @@ def main() -> None:
4345 model .load_state_dict (torch .load (weights_path , map_location = DEVICE ))
4446 model .to (DEVICE )
4547
48+ best_model_path = out_dir / "model.pth"
49+ early_stopper = EarlyStopping (
50+ alpha = config ["train" ].get ("alpha" , 5.0 ), path = str (best_model_path )
51+ )
52+
4653 # Unfreeze layers
4754 for param in model .parameters ():
4855 param .requires_grad = True
@@ -65,6 +72,15 @@ def main() -> None:
6572 print (
6673 f"Epoch { epoch + 1 } /{ epochs } | T-Loss: { t_loss :.4f} | V-Loss: { v_loss :.4f} "
6774 )
75+
76+ early_stopper (v_loss , epoch + 1 , model )
77+ if early_stopper .stop :
78+ print (
79+ f"Stopping at epoch { epoch + 1 } . "
80+ f"Best model was at epoch { early_stopper .best_epoch } "
81+ )
82+ break
83+
6884 print (f"Model { args .model } fine-tuned successfully!" )
6985
7086 # Saving the model