@@ -66,48 +66,41 @@ Note: Misalignments between the ground truth graphs and prediction segmentation
6666Here is a simple example of evaluating a predicted segmentation.
6767
6868``` python
69- import numpy as np
70- from xlwt import Workbook
71-
72- from segmentation_skeleton_metrics.skeleton_metric import SkeletonMetric
73- from segmentation_skeleton_metrics.utils.img_util import TiffReader
74-
75-
76- def evaluate ():
77- # Initializations
78- segmentation = TiffReader(segmentation_path)
79- skeleton_metric = SkeletonMetric(
80- groundtruth_pointer,
81- segmentation,
82- fragments_pointer = fragments_pointer,
83- output_dir = output_dir,
84- )
85- full_results, avg_results = skeleton_metric.run()
86-
87- # Report results
88- print (f " \n Averaged Results... " )
89- for key in avg_results.keys():
90- print (f " { key} : { round (avg_results[key], 4 )} " )
91-
92- print (f " \n Total Results... " )
93- print (" # splits:" , np.sum(list (skeleton_metric.split_cnt.values())))
94- print (" # merges:" , np.sum(list (skeleton_metric.merge_cnt.values())))
95-
96- # Save results
97- path = f " { output_dir} /evaluation_results.xls "
98- save_results(path, full_results)
99-
100-
101- if __name__ == " __main__" :
102- # Initializations
103- output_dir = " ./"
104- segmentation_path = " ./pred_labels.tif"
105- fragments_pointer = " ./pred_swcs.zip"
106- groundtruth_pointer = " ./target_swcs.zip"
107-
108- # Run
109- evaluate()
110-
69+
70+ import torch.nn as nn
71+ import torch.optim as optim
72+
73+ from supervoxel_loss.loss import SuperVoxelLoss2D
74+
75+
76+ # Initialization
77+ model = UNet()
78+ optimizer = optim.AdamW(model.parameters(), lr = 1e-4 )
79+
80+ loss_switch_epoch = 10
81+ voxel_loss = nn.BCEWithLogitsLoss()
82+ supervoxel_loss = SuperVoxelLoss2D(alpha = 0.5 , beta = 0.5 , threshold = 0 )
83+
84+ # Main
85+ for epoch in range (n_epochs):
86+ # Set loss function based on the current epoch
87+ if epoch < loss_switch_epoch:
88+ loss_function = voxel_loss
89+ else :
90+ loss_function = supervoxel_loss
91+
92+ # Training loop
93+ for inputs, targets in dataloader:
94+ # Forward pass
95+ preds = model(inputs)
96+
97+ # Compute loss
98+ loss = loss_function(preds, targets)
99+
100+ # Backward pass
101+ optimizer.zero_grad()
102+ loss.backward()
103+ optimizer.step()
111104
112105```
113106
0 commit comments