Skip to content

Commit 7debe0c

Browse files
authored
Update README.md
1 parent 9147553 commit 7debe0c

File tree

1 file changed

+41
-34
lines changed

1 file changed

+41
-34
lines changed

README.md

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -67,40 +67,47 @@ Here is a simple example of evaluating a predicted segmentation.
6767

6868
```python
6969

70-
import torch.nn as nn
71-
import torch.optim as optim
72-
73-
from supervoxel_loss.loss import SuperVoxelLoss2D
74-
75-
76-
# Initialization
77-
model = UNet()
78-
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
79-
80-
loss_switch_epoch = 10
81-
voxel_loss = nn.BCEWithLogitsLoss()
82-
supervoxel_loss = SuperVoxelLoss2D(alpha=0.5, beta=0.5, threshold=0)
83-
84-
# Main
85-
for epoch in range(n_epochs):
86-
# Set loss function based on the current epoch
87-
if epoch < loss_switch_epoch:
88-
loss_function = voxel_loss
89-
else:
90-
loss_function = supervoxel_loss
91-
92-
# Training loop
93-
for inputs, targets in dataloader:
94-
# Forward pass
95-
preds = model(inputs)
96-
97-
# Compute loss
98-
loss = loss_function(preds, targets)
99-
100-
# Backward pass
101-
optimizer.zero_grad()
102-
loss.backward()
103-
optimizer.step()
70+
import numpy as np
71+
from xlwt import Workbook
72+
73+
from segmentation_skeleton_metrics.skeleton_metric import SkeletonMetric
74+
from segmentation_skeleton_metrics.utils.img_util import TiffReader
75+
76+
77+
def evaluate():
78+
# Initializations
79+
segmentation = TiffReader(segmentation_path)
80+
skeleton_metric = SkeletonMetric(
81+
groundtruth_pointer,
82+
segmentation,
83+
fragments_pointer=fragments_pointer,
84+
output_dir=output_dir,
85+
)
86+
full_results, avg_results = skeleton_metric.run()
87+
88+
# Report results
89+
print(f"\nAveraged Results...")
90+
for key in avg_results.keys():
91+
print(f" {key}: {round(avg_results[key], 4)}")
92+
93+
print(f"\nTotal Results...")
94+
print("# splits:", np.sum(list(skeleton_metric.split_cnt.values())))
95+
print("# merges:", np.sum(list(skeleton_metric.merge_cnt.values())))
96+
97+
# Save results
98+
path = f"{output_dir}/evaluation_results.xls"
99+
save_results(path, full_results)
100+
101+
102+
if __name__ == "__main__":
103+
# Initializations
104+
output_dir = "./"
105+
segmentation_path = "./pred_labels.tif"
106+
fragments_pointer = "./pred_swcs.zip"
107+
groundtruth_pointer = "./target_swcs.zip"
108+
109+
# Run
110+
evaluate()
104111

105112
```
106113

0 commit comments

Comments
 (0)