1 change: 1 addition & 0 deletions requirements.txt
@@ -52,6 +52,7 @@ multidict==6.4.3
nest-asyncio==1.6.0
networkx==3.2.1
numpy==2.2.4
+onnx==1.18.0
pandas==2.2.3
packaging==25.0
parso==0.8.4
63 changes: 48 additions & 15 deletions src/main.py
@@ -50,7 +50,8 @@ def main():
class_weights = config["training"]["class_weights"] if "class_weights" in config["training"] else None

if "class_weights" in config["training"] and ("oversample" in config["training"] or "undersample" in config["training"]):
-raise Exception("Can't use class weights and resampling at the same time.")
+raise Exception(
+    "Can't use class weights and resampling at the same time.")
weight_decay = config["training"]["weight_decay"]
model_name = config["model"]["name"]
image_size = 299 if model_name == "inception_v3" else 224
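
Note: for orientation, here is a minimal sketch of the config shape this block reads, written as the Python dict that loading `src/config.yaml` would produce. All values (and any key not referenced in the diff) are illustrative assumptions:

```python
# Hypothetical config, as loaded from src/config.yaml. Values are
# placeholders; "class_weights" and "oversample"/"undersample" are
# mutually exclusive -- main() raises if both are present.
config = {
    "model": {"name": "inception_v3"},  # any other name uses 224px inputs
    "training": {
        "class_weights": [1.0, 2.5, 0.8],  # one weight per class
        # "oversample": {"oversample_factor": 3.0,
        #                "oversample_threshold": 500},
        "weight_decay": 1e-4,
        "early_stopping": {"monitor": "val_loss",
                           "patience": 5, "mode": "min"},
    },
}
```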
@@ -65,7 +66,8 @@ def main():
"minority_transform": torchvision.transforms.Compose([
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.RandomVerticalFlip(),
-torchvision.transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(1, 1.2), shear=10),
+torchvision.transforms.RandomAffine(
+    degrees=30, translate=(0.1, 0.1), scale=(1, 1.2), shear=10),
]),
"oversample_factor": config["training"]["oversample"]["oversample_factor"],
"oversample_threshold": config["training"]["oversample"]["oversample_threshold"]
@@ -78,7 +80,8 @@ def main():
elif "curriculum_learning" in config["training"]:
dataset_module = CurriculumLearningDataset
dataset_args = {
-"indices": [0] # The list cannot be empty, since the dataloder doesn't accept empty dataset
+# The list cannot be empty, since the dataloader doesn't accept an empty dataset
+"indices": [0]
}

datamodule = ForestDataModule(
@@ -97,7 +100,8 @@ def main():
gamma=config["training"]["gamma"],
freeze=freeze,
transform=transforms,
-weight=torch.tensor(class_weights, dtype=torch.float) if class_weights is not None else None,
+weight=torch.tensor(
+    class_weights, dtype=torch.float) if class_weights is not None else None,
learning_rate=learning_rate,
weight_decay=weight_decay
)
@@ -111,7 +115,8 @@ def main():
callbacks.append(EarlyStopping(monitor=config["training"]["early_stopping"]['monitor'],
patience=config["training"]["early_stopping"]['patience'],
mode=config["training"]["early_stopping"]['mode']))
-checkpoint_dir = config["training"].get("checkpoint_dir", "checkpoints/")
+checkpoint_dir = config["training"].get(
+    "checkpoint_dir", "checkpoints/")
callbacks.append(ModelCheckpoint(monitor='val_loss',
mode='min',
save_top_k=1,
@@ -142,7 +147,7 @@ def main():

wandb_api_key = os.environ.get('WANDB_API_KEY')
wandb.login(key=wandb_api_key)
-wandb.init(project="ghost-irim", name=run_name)
+wandb_run = wandb.init(project="ghost-irim", name=run_name)

# Log config.yaml to wandb
wandb.save("src/config.yaml")
@@ -177,7 +182,8 @@ def main():
break

if not best_ckpt_path:
-raise ValueError("No ModelCheckpoint callback found or no best checkpoint available.")
+raise ValueError(
+    "No ModelCheckpoint callback found or no best checkpoint available.")

trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)
# Callbacks' service
@@ -186,7 +192,8 @@ def main():
train_metrics = callback.train_metrics
val_metrics = callback.val_metrics
plot_metrics(train_metrics, val_metrics)
-wandb.log({'Accuracy and Loss Curves': wandb.Image('src/plots/acc_loss_curves.png')})
+wandb.log({'Accuracy and Loss Curves': wandb.Image(
+    'src/plots/acc_loss_curves.png')})

# Logging plots
preds = model.predictions
@@ -199,27 +206,53 @@ def main():

# Log metrics per class and classnames
metrics_per_class = calculate_metrics_per_class(targets, preds)
-accs = [metrics_per_class[key]['accuracy'] for key in metrics_per_class.keys()]
-precs = [metrics_per_class[key]['precision'] for key in metrics_per_class.keys()]
-recs = [metrics_per_class[key]['recall'] for key in metrics_per_class.keys()]
+accs = [metrics_per_class[key]['accuracy']
+        for key in metrics_per_class.keys()]
+precs = [metrics_per_class[key]['precision']
+         for key in metrics_per_class.keys()]
+recs = [metrics_per_class[key]['recall']
+        for key in metrics_per_class.keys()]
f1s = [metrics_per_class[key]['f1'] for key in metrics_per_class.keys()]
ious = [metrics_per_class[key]['IoU'] for key in metrics_per_class.keys()]
names_and_labels = [[key, value] for key, value in label_map.items()]
-logged_metrics = [[name, label, acc, prec, rec, f1, iou] for [name, label], acc, prec, rec, f1, iou in zip(names_and_labels, accs, precs, recs, f1s, ious)]
+logged_metrics = [[name, label, acc, prec, rec, f1, iou] for [
+    name, label], acc, prec, rec, f1, iou in zip(names_and_labels, accs, precs, recs, f1s, ious)]

-training_table = wandb.Table(columns=['Class name', 'Label', 'Accuracy', 'Precision', 'Recall', 'F1-score', 'IoU'], data=logged_metrics)
+training_table = wandb.Table(columns=[
+    'Class name', 'Label', 'Accuracy', 'Precision', 'Recall', 'F1-score', 'IoU'], data=logged_metrics)
wandb.log({'Classes': training_table})

# Log confusion matrix, precision-recall curve and roc-auc curve
get_confusion_matrix(preds, targets, class_names=list(label_map.keys()))
get_roc_auc_curve(preds, targets, class_names=list(label_map.keys()))
-get_precision_recall_curve(preds, targets, class_names=list(label_map.keys()))
+get_precision_recall_curve(
+    preds, targets, class_names=list(label_map.keys()))

-filenames = ['confusion_matrix.png', 'precision_recall_curve.png', 'roc_auc_curve.png']
+filenames = ['confusion_matrix.png',
+             'precision_recall_curve.png', 'roc_auc_curve.png']
titles = ['Confusion Matrix', 'Precision-Recall Curve', 'ROC AUC Curve']
for filename, title in zip(filenames, titles):
wandb.log({title: wandb.Image(f'src/plots/{filename}')})

+onnx_filepath = "tile_processing_model.onnx"
+model.to_onnx(
+    onnx_filepath,
+    export_params=True,
+    input_names=["input"],    # Name of the input tensor
+    output_names=["output"],  # Name of the output tensor
+    dynamic_axes={
+        'input': {0: 'batch_size'},  # Make batch size dynamic
+        'output': {0: 'batch_size'}
+    }
+)
+# Log the exported model as a versioned artifact on the current run.
+artifact = wandb.Artifact(
+    name=f"onnx-model-{wandb_run.id}",
+    type="model",
+    description=f"ONNX model from run {wandb_run.id}",
+    metadata=dict(wandb_run.config)
+)
+artifact.add_file(onnx_filepath)
+wandb_run.log_artifact(artifact)

wandb.finish()
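
Note: to sanity-check the export and the dynamic batch axis, something like the sketch below works. It assumes `onnxruntime` is installed (it is not among the pinned requirements) and a non-Inception model, i.e. 224x224 inputs:

```python
import numpy as np
import onnxruntime as ort

# Run the exported model at two batch sizes to confirm the
# dynamic 'batch_size' axis was exported correctly.
session = ort.InferenceSession("tile_processing_model.onnx")
for batch_size in (1, 4):
    dummy = np.random.rand(batch_size, 3, 224, 224).astype(np.float32)
    (logits,) = session.run(["output"], {"input": dummy})
    print(batch_size, logits.shape)
```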


1 change: 1 addition & 0 deletions unix-requirements.txt
@@ -50,6 +50,7 @@ multidict==6.4.3
nest-asyncio==1.6.0
networkx==3.2.1
numpy==2.2.4
+onnx==1.18.0
pandas==2.2.3
packaging==25.0
parso==0.8.4