Skip to content

Commit 8041ecc

Browse files
authored
Update README.md
1 parent a8e5eb0 commit 8041ecc

File tree

1 file changed

+35
-42
lines changed

1 file changed

+35
-42
lines changed

README.md

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -66,48 +66,41 @@ Note: Misalignments between the ground truth graphs and prediction segmentation
6666
Here is a simple example of evaluating a predicted segmentation.
6767

6868
```python
69-
import numpy as np
70-
from xlwt import Workbook
71-
72-
from segmentation_skeleton_metrics.skeleton_metric import SkeletonMetric
73-
from segmentation_skeleton_metrics.utils.img_util import TiffReader
74-
75-
76-
def evaluate():
77-
# Initializations
78-
segmentation = TiffReader(segmentation_path)
79-
skeleton_metric = SkeletonMetric(
80-
groundtruth_pointer,
81-
segmentation,
82-
fragments_pointer=fragments_pointer,
83-
output_dir=output_dir,
84-
)
85-
full_results, avg_results = skeleton_metric.run()
86-
87-
# Report results
88-
print(f"\nAveraged Results...")
89-
for key in avg_results.keys():
90-
print(f" {key}: {round(avg_results[key], 4)}")
91-
92-
print(f"\nTotal Results...")
93-
print("# splits:", np.sum(list(skeleton_metric.split_cnt.values())))
94-
print("# merges:", np.sum(list(skeleton_metric.merge_cnt.values())))
95-
96-
# Save results
97-
path = f"{output_dir}/evaluation_results.xls"
98-
save_results(path, full_results)
99-
100-
101-
if __name__ == "__main__":
102-
# Initializations
103-
output_dir = "./"
104-
segmentation_path = "./pred_labels.tif"
105-
fragments_pointer = "./pred_swcs.zip"
106-
groundtruth_pointer = "./target_swcs.zip"
107-
108-
# Run
109-
evaluate()
110-
69+
70+
import torch.nn as nn
71+
import torch.optim as optim
72+
73+
from supervoxel_loss.loss import SuperVoxelLoss2D
74+
75+
76+
# Initialization
77+
model = UNet()
78+
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
79+
80+
loss_switch_epoch = 10
81+
voxel_loss = nn.BCEWithLogitsLoss()
82+
supervoxel_loss = SuperVoxelLoss2D(alpha=0.5, beta=0.5, threshold=0)
83+
84+
# Main
85+
for epoch in range(n_epochs):
86+
# Set loss function based on the current epoch
87+
if epoch < loss_switch_epoch:
88+
loss_function = voxel_loss
89+
else:
90+
loss_function = supervoxel_loss
91+
92+
# Training loop
93+
for inputs, targets in dataloader:
94+
# Forward pass
95+
preds = model(inputs)
96+
97+
# Compute loss
98+
loss = loss_function(preds, targets)
99+
100+
# Backward pass
101+
optimizer.zero_grad()
102+
loss.backward()
103+
optimizer.step()
111104

112105
```
113106

0 commit comments

Comments
 (0)