@@ -67,40 +67,47 @@ Here is a simple example of evaluating a predicted segmentation.
6767
6868``` python
6969
70- import torch.nn as nn
71- import torch.optim as optim
72-
73- from supervoxel_loss.loss import SuperVoxelLoss2D
74-
75-
76- # Initialization
77- model = UNet()
78- optimizer = optim.AdamW(model.parameters(), lr = 1e-4 )
79-
80- loss_switch_epoch = 10
81- voxel_loss = nn.BCEWithLogitsLoss()
82- supervoxel_loss = SuperVoxelLoss2D(alpha = 0.5 , beta = 0.5 , threshold = 0 )
83-
84- # Main
85- for epoch in range (n_epochs):
86- # Set loss function based on the current epoch
87- if epoch < loss_switch_epoch:
88- loss_function = voxel_loss
89- else :
90- loss_function = supervoxel_loss
91-
92- # Training loop
93- for inputs, targets in dataloader:
94- # Forward pass
95- preds = model(inputs)
96-
97- # Compute loss
98- loss = loss_function(preds, targets)
99-
100- # Backward pass
101- optimizer.zero_grad()
102- loss.backward()
103- optimizer.step()
70+ import numpy as np
71+ from xlwt import Workbook
72+
73+ from segmentation_skeleton_metrics.skeleton_metric import SkeletonMetric
74+ from segmentation_skeleton_metrics.utils.img_util import TiffReader
75+
76+
77+ def evaluate ():
78+ # Initializations
79+ segmentation = TiffReader(segmentation_path)
80+ skeleton_metric = SkeletonMetric(
81+ groundtruth_pointer,
82+ segmentation,
83+ fragments_pointer = fragments_pointer,
84+ output_dir = output_dir,
85+ )
86+ full_results, avg_results = skeleton_metric.run()
87+
88+ # Report results
89+ print (f " \n Averaged Results... " )
90+ for key in avg_results.keys():
91+ print (f " { key} : { round (avg_results[key], 4 )} " )
92+
93+ print (f " \n Total Results... " )
94+ print (" # splits:" , np.sum(list (skeleton_metric.split_cnt.values())))
95+ print (" # merges:" , np.sum(list (skeleton_metric.merge_cnt.values())))
96+
97+ # Save results
98+ path = f " { output_dir} /evaluation_results.xls "
99+ save_results(path, full_results)
100+
101+
102+ if __name__ == " __main__" :
103+ # Initializations
104+ output_dir = " ./"
105+ segmentation_path = " ./pred_labels.tif"
106+ fragments_pointer = " ./pred_swcs.zip"
107+ groundtruth_pointer = " ./target_swcs.zip"
108+
109+ # Run
110+ evaluate()
104111
105112```
106113
0 commit comments