Skip to content

Commit a17c6b2

Browse files
authored
major refactor
1 parent cf7663a commit a17c6b2

File tree

1 file changed

+104
-50
lines changed
  • src/segmentation_skeleton_metrics/utils

1 file changed

+104
-50
lines changed

src/segmentation_skeleton_metrics/utils/util.py

Lines changed: 104 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@
99
1010
"""
1111

12+
from botocore import UNSIGNED
13+
from botocore.client import Config
1214
from random import sample
1315
from google.cloud import storage
14-
from io import BytesIO
15-
from xlwt import Workbook
16+
from io import BytesIO, StringIO
1617
from zipfile import ZipFile
1718

1819
import boto3
19-
from botocore import UNSIGNED
20-
from botocore.client import Config
2120
import os
2221
import pandas as pd
2322
import shutil
@@ -55,6 +54,11 @@ def rmdir(path):
5554
shutil.rmtree(path)
5655

5756

57+
def rm_file(path):
58+
if os.path.exists(path):
59+
os.remove(path)
60+
61+
5862
def list_dir(directory, extension=None):
5963
"""
6064
Lists filenames in the given directory. If "extension" is provided,
@@ -78,6 +82,24 @@ def list_dir(directory, extension=None):
7882
return [f for f in os.listdir(directory) if f.endswith(extension)]
7983

8084

85+
def list_files_in_zip(zip_content):
86+
"""
87+
Lists all files in a zip file stored in a GCS bucket.
88+
89+
Parameters
90+
----------
91+
zip_content : str
92+
Content stored in a ZIP archive in the form of a string of bytes.
93+
94+
Returns
95+
-------
96+
List[str]
97+
Filenames in a ZIP archive file.
98+
"""
99+
with ZipFile(BytesIO(zip_content), "r") as zip_file:
100+
return zip_file.namelist()
101+
102+
81103
def list_paths(directory, extension=None):
82104
"""
83105
Lists paths of files in the given directory. If "extension" is provided,
@@ -154,40 +176,81 @@ def update_txt(path, text):
154176
file.write(text + "\n")
155177

156178

157-
# -- GCS Utils --
158-
def is_gcs_path(path):
179+
# --- Graph Utils ---
180+
def get_leafs(graph):
159181
"""
160-
Checks if the path is a GCS path.
182+
Gets all leafs nodes in the given graph.
161183
162184
Parameters
163185
----------
164-
path : str
165-
Path to be checked.
186+
graph : networkx.Graph
187+
Graph to be searched.
166188
167189
Returns
168190
-------
169-
bool
170-
Indication of whether the path is a GCS path.
191+
List[int]
192+
Leaf nodes in the given graph.
171193
"""
172-
return path.startswith("gs://")
194+
return [node for node in graph.nodes if graph.degree[node] == 1]
173195

174196

175-
def list_files_in_zip(zip_content):
197+
def search_branching_node(graph, kdtree, root, radius=100):
176198
"""
177-
Lists all files in a zip file stored in a GCS bucket.
199+
Searches for a branching node within distance "radius" from the given
200+
root node.
178201
179202
Parameters
180203
----------
181-
zip_content : str
182-
Content stored in a ZIP archive in the form of a string of bytes.
204+
graph : networkx.Graph
205+
Graph to be searched.
206+
kdtree : scipy.spatial.KDTree
207+
KDTree containing voxel coordinates from a ground truth tracing.
208+
root : int
209+
Root of search.
210+
radius : float, optional
211+
Distance to search from root. Default is 100.
183212
184213
Returns
185214
-------
186-
List[str]
187-
Filenames in a ZIP archive file.
215+
int
216+
Root node or closest branching node within distance "radius".
217+
"""
218+
queue = list([(root, 0)])
219+
visited = set({root})
220+
while queue:
221+
# Visit node
222+
i, d_i = queue.pop()
223+
xyz_i = graph.get_xyz(i)
224+
if graph.degree[i] > 2:
225+
dist, _ = kdtree.query(xyz_i)
226+
if dist < 16:
227+
return i
228+
229+
# Update queue
230+
for j in graph.neighbors(i):
231+
d_j = d_i + graph.physical_dist(i, j)
232+
if j not in visited and d_j < radius:
233+
queue.append((j, d_j))
234+
visited.add(j)
235+
return root
236+
237+
238+
# -- GCS Utils --
239+
def is_gcs_path(path):
188240
"""
189-
with ZipFile(BytesIO(zip_content), "r") as zip_file:
190-
return zip_file.namelist()
241+
Checks if the path is a GCS path.
242+
243+
Parameters
244+
----------
245+
path : str
246+
Path to be checked.
247+
248+
Returns
249+
-------
250+
bool
251+
Indication of whether the path is a GCS path.
252+
"""
253+
return path.startswith("gs://")
191254

192255

193256
def list_gcs_filenames(bucket_name, prefix, extension):
@@ -499,37 +562,28 @@ def sample_once(my_container):
499562
return sample(my_container, 1)[0]
500563

501564

502-
def save_results(path, stats):
565+
def to_zipped_point(zip_writer, filename, xyz):
503566
"""
504-
Saves the evaluation results generated from skeleton-based metrics to an
505-
Excel file.
567+
Writes a point to an SWC file format, which is then stored in a ZIP
568+
archive.
506569
507570
Parameters
508571
----------
509-
path : str
510-
Path where the Excel file will be saved.
511-
stats : dict
512-
Dictionary where the keys are SWC IDs (as strings) and the values
513-
are dictionaries containing metrics as keys and their respective
514-
values.
515-
"""
516-
# Initialize
517-
wb = Workbook()
518-
sheet = wb.add_sheet("Results")
519-
sheet.write(0, 0, "swc_id")
520-
521-
# Label rows and columns
522-
swc_ids = list(stats.keys())
523-
for i, swc_id in enumerate(swc_ids):
524-
sheet.write(i + 1, 0, swc_id)
525-
526-
metrics = list(stats[swc_id].keys())
527-
for i, metric in enumerate(metrics):
528-
sheet.write(0, i + 1, metric)
529-
530-
# Write stats
531-
for i, swc_id in enumerate(swc_ids):
532-
for j, metric in enumerate(metrics):
533-
sheet.write(i + 1, j + 1, round(stats[swc_id][metric], 4))
534-
535-
wb.save(path)
572+
zip_writer : zipfile.ZipFile
573+
A ZipFile object that will store the generated SWC file.
574+
filename : str
575+
Filename of SWC file.
576+
xyz : ArrayLike
577+
Point to be written to SWC file.
578+
"""
579+
with StringIO() as text_buffer:
580+
# Preamble
581+
text_buffer.write("# COLOR 1.0 0.0 0.0")
582+
text_buffer.write("\n" + "# id, type, z, y, x, r, pid")
583+
584+
# Write entry
585+
x, y, z = tuple(xyz)
586+
text_buffer.write("\n" + f"1 2 {x} {y} {z} 10 -1")
587+
588+
# Finish
589+
zip_writer.writestr(filename, text_buffer.getvalue())

0 commit comments

Comments
 (0)