Skip to content

Commit cbdfc34

Browse files
author
anna-grim
committed
added documentation
1 parent f7045ce commit cbdfc34

File tree

4 files changed

+113
-25
lines changed

4 files changed

+113
-25
lines changed

src/segmentation_skeleton_metrics/skeleton_graph.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class SkeletonGraph(nx.Graph):
3939
A 3D array that contains a voxel coordinate for each node.
4040
4141
"""
42+
4243
colors = [
4344
"# COLOR 1.0 0.0 1.0", # pink
4445
"# COLOR 0.0 1.0 1.0", # cyan
@@ -375,8 +376,20 @@ def to_zipped_swc(self, zip_writer):
375376
zip_writer.writestr(self.filename, text_buffer.getvalue())
376377

377378
def get_color(self):
379+
"""
380+
Gets the display color of the skeleton to be written to an SWC file.
381+
382+
Parameters
383+
----------
384+
None
385+
386+
Returns
387+
-------
388+
str
389+
String representing the color in the format "# COLOR R G B".
390+
391+
"""
378392
if self.is_groundtruth:
379-
return "# COLOR 1.0 1.0 1.0"
393+
return "# COLOR 1.0 1.0 1.0"
380394
else:
381395
return util.sample_once(SkeletonGraph.colors)
382-

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,6 @@ class SkeletonMetric:
4949
(7) Expected Run Length (ERL)
5050
(8) Normalized ERL
5151
52-
Class attributes
53-
----------------
54-
merge_dist : float
55-
...
56-
min_label_cnt : int
57-
...
58-
5952
"""
6053

6154
def __init__(
@@ -146,7 +139,7 @@ def __init__(
146139
"Edge Accuracy",
147140
"ERL",
148141
"Normalized ERL",
149-
"GT Run Length"
142+
"GT Run Length",
150143
]
151144
self.metrics = pd.DataFrame(index=row_names, columns=col_names)
152145

@@ -561,7 +554,7 @@ def is_fragment_merge(self, key, label, kdtree):
561554
562555
"""
563556
for fragment_graph in self.find_graph_from_label(label):
564-
if fragment_graph.run_length < 10**6:
557+
if fragment_graph.run_length < 10 ** 6:
565558
# Search for leaf far from ground truth
566559
visited = set()
567560
for leaf in gutil.get_leafs(fragment_graph):
@@ -579,8 +572,11 @@ def is_fragment_merge(self, key, label, kdtree):
579572
)
580573
else:
581574
segment_id = util.get_segment_id(fragment_graph.filename)
575+
run_length = fragment_graph.run_length
582576
self.merged_labels.add((key, segment_id, -1))
583-
print(f"Skipping {segment_id} - run_length={fragment_graph.run_length}")
577+
print(
578+
f"Skipping {segment_id} - run_length={run_length}"
579+
)
584580

585581
def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
586582
for _, node in nx.dfs_edges(fragment_graph, source=source):
@@ -605,7 +601,9 @@ def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
605601
"Segment_ID": segment_id,
606602
"GroundTruth_ID": key,
607603
"Voxel": tuple([int(t) for t in voxel]),
608-
"World": tuple([float(round(t, 2)) for t in xyz]),
604+
"World": tuple(
605+
[float(round(t, 2)) for t in xyz]
606+
),
609607
}
610608
)
611609

@@ -615,8 +613,8 @@ def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
615613
fragment_graph, self.merge_writer
616614
)
617615
gutil.write_graph(
618-
self.gt_graphs[key], self.merge_writer
619-
)
616+
self.gt_graphs[key], self.merge_writer
617+
)
620618
return
621619

622620
def is_valid_merge(self, graph, kdtree, root):
@@ -643,7 +641,7 @@ def is_valid_merge(self, graph, kdtree, root):
643641
queue.append((j, d_j))
644642
visited.add(j)
645643
return True if n_hits > 16 else False
646-
644+
647645
def process_merge_sites(self):
648646
if self.merge_sites:
649647
# Remove duplicates

src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,7 @@ def read_from_paths(self, paths):
178178
for path in paths:
179179
filename = os.path.basename(path)
180180
if self.confirm_read(filename):
181-
threads.append(
182-
executor.submit(self.read_from_path, path)
183-
)
181+
threads.append(executor.submit(self.read_from_path, path))
184182

185183
# Store results
186184
swc_dicts = deque()
@@ -217,9 +215,7 @@ def read_from_zips(self, zip_dir):
217215
processes = list()
218216
for f in zip_names:
219217
zip_path = os.path.join(zip_dir, f)
220-
processes.append(
221-
executor.submit(self.read_from_zip, zip_path)
222-
)
218+
processes.append(executor.submit(self.read_from_zip, zip_path))
223219

224220
# Store results
225221
swc_dicts = deque()
@@ -318,6 +314,22 @@ def read_from_gcs(self, gcs_dict):
318314
raise Exception(f"GCS Pointer is invalid -{gcs_dict}-")
319315

320316
def read_from_gcs_swcs(self, bucket_name, swc_paths):
317+
"""
318+
Reads SWC files stored in a GCS bucket.
319+
320+
Parameters
321+
----------
322+
gcs_dict : dict
323+
Dictionary with the keys "bucket_name" and "path" that specify
324+
where the SWC files are located in a GCS bucket.
325+
326+
Returns
327+
-------
328+
Dequeue[dict]
329+
List of dictionaries whose keys and values are the attribute
330+
names and values from an SWC file.
331+
332+
"""
321333
pbar = tqdm(total=len(swc_paths), desc="Read SWCs")
322334
with ThreadPoolExecutor() as executor:
323335
# Assign threads
@@ -337,6 +349,22 @@ def read_from_gcs_swcs(self, bucket_name, swc_paths):
337349
return swc_dicts
338350

339351
def read_from_gcs_swc(self, bucket_name, path):
352+
"""
353+
Reads a single SWC file stored in a GCS bucket.
354+
355+
Parameters
356+
----------
357+
gcs_dict : dict
358+
Dictionary with the keys "bucket_name" and "path" that specify
359+
where a single SWC file is located in a GCS bucket.
360+
361+
Returns
362+
-------
363+
dict
364+
Dictionaries whose keys and values are the attribute names and
365+
values from an SWC file.
366+
367+
"""
340368
# Initialize cloud reader
341369
client = storage.Client()
342370
bucket = client.bucket(bucket_name)
@@ -348,6 +376,22 @@ def read_from_gcs_swc(self, bucket_name, path):
348376
return self.parse(content, filename)
349377

350378
def read_from_gcs_zips(self, bucket_name, zip_paths):
379+
"""
380+
Reads SWC files stored in a ZIP archives stored in a GCS bucket.
381+
382+
Parameters
383+
----------
384+
zip_content : bytes
385+
Content of a ZIP archive.
386+
387+
Returns
388+
-------
389+
Dequeue[dict]
390+
List of dictionaries whose keys and values are the attribute
391+
names and values from an SWC file.
392+
393+
394+
"""
351395
swc_dicts = deque()
352396
for zip_path in tqdm(zip_paths, desc="Read SWCs"):
353397
swc_dicts.extend(self.read_from_gcs_zip(bucket_name, zip_path))

src/segmentation_skeleton_metrics/utils/util.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def update_txt(path, text):
167167
168168
"""
169169
print(text)
170-
with open(path, 'a') as file:
170+
with open(path, "a") as file:
171171
file.write(text + "\n")
172172

173173

@@ -248,14 +248,47 @@ def list_gcs_subdirectories(bucket_name, prefix):
248248
return subdirs
249249

250250

251-
def read_txt_from_gcs(bucket_name, file_name):
251+
def read_txt_from_gcs(bucket_name, filename):
252+
"""
253+
Reads a txt file stored in a GCS bucket.
254+
255+
Parameters
256+
----------
257+
bucket_name : str
258+
Name of bucket to be read from.
259+
filename : str
260+
Name of txt file to be read.
261+
262+
Returns
263+
-------
264+
str
265+
Contents of txt file.
266+
267+
"""
252268
client = storage.Client()
253269
bucket = client.bucket(bucket_name)
254-
blob = bucket.blob(file_name)
270+
blob = bucket.blob(filename)
255271
return blob.download_as_text()
256272

257273

258274
def upload_directory_to_gcs(bucket_name, source_dir, destination_dir):
275+
"""
276+
Uploads the contents of a local directory to a GCS bucket.
277+
278+
Parameters
279+
----------
280+
bucket_name : str
281+
Name of bucket to be read from.
282+
source_dir : str
283+
Path to the local directory whose contents should be uploaded.
284+
destination_dir : str
285+
Prefix path in the GCS bucket under which the files will be stored.
286+
287+
Returns
288+
-------
289+
None
290+
291+
"""
259292
client = storage.Client()
260293
bucket = client.bucket(bucket_name)
261294
for root, _, files in os.walk(source_dir):

0 commit comments

Comments
 (0)