Skip to content

Commit 331f324

Browse files
lucasalvaaSimoCimmiMorganVitiello
committed
Fix ruff format
Co-authored-by: SimoCimmi <simonecimmino2004@gmail.com> Co-authored-by: Morgan Vitiello <morgan.vitiello06@gmail.com>
1 parent c342566 commit 331f324

File tree

4 files changed

+13
-16
lines changed

4 files changed

+13
-16
lines changed

src/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def get_model(model_name: str, num_classes: int) -> nn.Module:
6868

6969
return model.to(DEVICE)
7070

71+
7172
def validate(model: nn.Module, loader: DataLoader, criterion: nn.Module) -> float:
7273
"""Calculate average loss on the validation set.
7374
@@ -90,6 +91,7 @@ def validate(model: nn.Module, loader: DataLoader, criterion: nn.Module) -> floa
9091
running_loss += loss.item() * images.size(0)
9192
return running_loss / len(loader.dataset)
9293

94+
9395
def train_epoch(
9496
model: nn.Module,
9597
loader: DataLoader,

src/evaluate.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
def evaluate(
17-
model: torch.nn.Module, loader: DataLoader
17+
model: torch.nn.Module, loader: DataLoader
1818
) -> Tuple[float, float, float, List[int], List[int]]:
1919
"""Evaluate model and return metrics and predictions.
2020
@@ -46,7 +46,7 @@ def evaluate(
4646
size = len(loader.dataset)
4747

4848
# average='macro' è standard per il multiclasse (media non pesata delle classi)
49-
precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
49+
precision = precision_score(all_labels, all_preds, average="macro", zero_division=0)
5050

5151
return top1 / size, top3 / size, precision, all_labels, all_preds
5252

@@ -56,9 +56,7 @@ def main() -> None:
5656
parser = argparse.ArgumentParser()
5757
parser.add_argument("--model", type=str, required=True)
5858
parser.add_argument("--config", type=str, required=True)
59-
6059
parser.add_argument("--output_dir", type=str, required=True)
61-
6260
parser.add_argument("--model_path", type=str, default=None)
6361

6462
args = parser.parse_args()
@@ -81,18 +79,17 @@ def main() -> None:
8179
weights_path = Path(args.model_path) if args.model_path else out_dir / "model.pth"
8280
print(f"[*] Loading weights from: {weights_path}")
8381

84-
model.load_state_dict(torch.load(weights_path, map_location=DEVICE,
85-
weights_only=True))
82+
model.load_state_dict(
83+
torch.load(weights_path, map_location=DEVICE, weights_only=True)
84+
)
8685

8786
t1, t3, prec, labels, preds = evaluate(model, test_loader)
8887

8988
# Save Metrics
9089
with open(out_dir / "metrics.json", "w") as f:
91-
json.dump({
92-
"top1": t1 * 100,
93-
"top3": t3 * 100,
94-
"precision": prec * 100
95-
}, f, indent=4)
90+
json.dump(
91+
{"top1": t1 * 100, "top3": t3 * 100, "precision": prec * 100}, f, indent=4
92+
)
9693

9794
import csv
9895

@@ -103,7 +100,7 @@ def main() -> None:
103100

104101
output_path = out_dir / "cm_data.csv"
105102

106-
# Scrittura del file CSV
103+
# CSV file for confusion matrix
107104
with open(output_path, "w", newline="", encoding="utf-8") as f:
108105
fieldnames = ["actual", "predicted"]
109106
writer = csv.DictWriter(f, fieldnames=fieldnames)

src/finetune.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def main() -> None:
2727
out_dir = Path(args.output_dir)
2828
out_dir.mkdir(parents=True, exist_ok=True)
2929

30-
3130
t_loader = get_dataloader(
3231
data_path=Path(config["finetuning"]["data_path"]),
3332
batch_size=config["finetuning"]["batch_size"],
@@ -63,9 +62,7 @@ def main() -> None:
6362
history.append({"epoch": epoch + 1, "train_loss": t_loss, "val_loss": v_loss})
6463

6564
print(
66-
f"Epoch {epoch + 1}/{epochs} | "
67-
f"T-Loss: {t_loss:.4f} | "
68-
f"V-Loss: {v_loss:.4f}"
65+
f"Epoch {epoch + 1}/{epochs} | T-Loss: {t_loss:.4f} | V-Loss: {v_loss:.4f}"
6966
)
7067
print(f"Model {args.model} fine-tuned successfully!")
7168

src/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
4141
output = torch.where(index, x_m, x)
4242
return functional.cross_entropy(self.s * output, target)
4343

44+
4445
def main() -> None:
4546
"""Entry point for training."""
4647
parser = argparse.ArgumentParser()

0 commit comments

Comments
 (0)