1414
1515
1616def evaluate (
17- model : torch .nn .Module , loader : DataLoader
17+ model : torch .nn .Module , loader : DataLoader
1818) -> Tuple [float , float , float , List [int ], List [int ]]:
1919 """Evaluate model and return metrics and predictions.
2020
@@ -46,7 +46,7 @@ def evaluate(
4646 size = len (loader .dataset )
4747
4848 # average='macro' è standard per il multiclasse (media non pesata delle classi)
49- precision = precision_score (all_labels , all_preds , average = ' macro' , zero_division = 0 )
49+ precision = precision_score (all_labels , all_preds , average = " macro" , zero_division = 0 )
5050
5151 return top1 / size , top3 / size , precision , all_labels , all_preds
5252
@@ -56,9 +56,7 @@ def main() -> None:
5656 parser = argparse .ArgumentParser ()
5757 parser .add_argument ("--model" , type = str , required = True )
5858 parser .add_argument ("--config" , type = str , required = True )
59-
6059 parser .add_argument ("--output_dir" , type = str , required = True )
61-
6260 parser .add_argument ("--model_path" , type = str , default = None )
6361
6462 args = parser .parse_args ()
@@ -81,18 +79,17 @@ def main() -> None:
8179 weights_path = Path (args .model_path ) if args .model_path else out_dir / "model.pth"
8280 print (f"[*] Loading weights from: { weights_path } " )
8381
84- model .load_state_dict (torch .load (weights_path , map_location = DEVICE ,
85- weights_only = True ))
82+ model .load_state_dict (
83+ torch .load (weights_path , map_location = DEVICE , weights_only = True )
84+ )
8685
8786 t1 , t3 , prec , labels , preds = evaluate (model , test_loader )
8887
8988 # Save Metrics
9089 with open (out_dir / "metrics.json" , "w" ) as f :
91- json .dump ({
92- "top1" : t1 * 100 ,
93- "top3" : t3 * 100 ,
94- "precision" : prec * 100
95- }, f , indent = 4 )
90+ json .dump (
91+ {"top1" : t1 * 100 , "top3" : t3 * 100 , "precision" : prec * 100 }, f , indent = 4
92+ )
9693
9794 import csv
9895
@@ -103,7 +100,7 @@ def main() -> None:
103100
104101 output_path = out_dir / "cm_data.csv"
105102
106- # Scrittura del file CSV
103+ # CSV file for confusion matrix
107104 with open (output_path , "w" , newline = "" , encoding = "utf-8" ) as f :
108105 fieldnames = ["actual" , "predicted" ]
109106 writer = csv .DictWriter (f , fieldnames = fieldnames )
0 commit comments