Skip to content

Commit 7173df3

Browse files
Author: will
Commit message: "lint and format volara and tests"
Parent: a682d68 · Commit: 7173df3

File tree

5 files changed

+23
-36
lines changed

5 files changed

+23
-36
lines changed

tests/test_dataset.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1+
from pathlib import Path
2+
3+
import numpy as np
14
import pytest
25
import yaml
3-
import numpy as np
46
import zarr
5-
from pathlib import Path
67
from funlib.geometry import Coordinate
7-
8-
from volara.datasets import Raw, Affs, LSD, Labels
98
from pydantic import ValidationError
109

10+
from volara.datasets import LSD, Affs, Labels, Raw
11+
1112

1213
@pytest.fixture
1314
def zarr_store(tmp_path):
@@ -77,7 +78,7 @@ def test_affs_serialization(tmp_path):
7778
yaml_data = f"""
7879
dataset_type: affs
7980
store: {tmp_path / "affs.zarr"}
80-
neighborhood:
81+
neighborhood:
8182
- [0, 1, 0]
8283
- [0, 0, 1]
8384
"""
@@ -109,15 +110,15 @@ def test_lazy_channel_slicing(zarr_store):
109110
assert data.shape == (10, 10)
110111
assert np.all(data == 1.0) # Original data was 1s
111112

112-
ds = Raw(store=zarr_store, channels=[0, [0,2,4,6,8]])
113+
ds = Raw(store=zarr_store, channels=[0, [0, 2, 4, 6, 8]])
113114

114115
arr = ds.array()
115116
data = arr[:]
116117

117118
assert data.shape == (5, 10)
118119
assert np.all(data == 1.0) # Original data was 1s
119120

120-
ds = Raw(store=zarr_store, channels=[0, 0, [0,2,4,6,8]])
121+
ds = Raw(store=zarr_store, channels=[0, 0, [0, 2, 4, 6, 8]])
121122

122123
arr = ds.array()
123124
data = arr[:]

volara/blockwise/blockwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def process_func(block):
399399
yield task
400400

401401
def get_benchmark_logger(self) -> BenchmarkLogger:
402-
benchmark_db_path = Path("volara_benchmark_logs/benchmark.db")
402+
_benchmark_db_path = Path("volara_benchmark_logs/benchmark.db")
403403
return BenchmarkLogger(
404404
None,
405405
task=self.task_name,

volara/blockwise/graph_mws.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
1-
from contextlib import contextmanager
2-
from typing import Annotated, Literal
3-
import tempfile
4-
import itertools
51
import functools
2+
import itertools
3+
import tempfile
4+
from contextlib import contextmanager
65
from pathlib import Path
7-
import time
8-
from pprint import pprint
6+
from typing import Annotated, Literal
97

108
import daisy
119
import mwatershed as mws
12-
import numpy as np
1310
import networkx as nx
11+
import numpy as np
1412
from funlib.geometry import Coordinate, Roi
1513
from pydantic import Field
1614

@@ -20,7 +18,6 @@
2018
from ..dbs import PostgreSQL, SQLite
2119
from ..utils import PydanticCoordinate
2220
from .blockwise import BlockwiseTask
23-
from ..datasets import Labels
2421

2522
DB = Annotated[
2623
PostgreSQL | SQLite,
@@ -313,7 +310,6 @@ def process_block(block: daisy.Block):
313310
both_sides=True,
314311
)
315312

316-
t1 = time.time()
317313
edges = []
318314
inputs = set(
319315
node
@@ -391,15 +387,13 @@ def process_block(block: daisy.Block):
391387
in_group = new_seg_to_frag_mapping.setdefault(int(out_frag), set())
392388
in_group.add(int(in_frag))
393389

394-
t1 = time.time()
395390
# save the lut to a temporary file for this block
396391
block_lut = LUT(
397392
path=f"{tmp_path}/{'-'.join([str(o) for o in block.write_roi.offset])}-lut"
398393
)
399394
lut = np.array([inputs, outputs])
400395
block_lut.save(lut, edges=edges)
401396

402-
t1 = time.time()
403397
# read luts and existing super fragments in neighboring blocks
404398
existing_luts = [
405399
LUT(
@@ -420,7 +414,6 @@ def process_block(block: daisy.Block):
420414
+ [self.lut],
421415
).load()
422416

423-
t1 = time.time()
424417
frag_seg_mapping: dict[int, int] = {
425418
int(k): int(v) for k, v in total_lut.T
426419
}
@@ -429,10 +422,8 @@ def process_block(block: daisy.Block):
429422
in_group = seg_frag_mapping.setdefault(out_frag, set())
430423
in_group.add(in_frag)
431424

432-
t1 = time.time()
433425
out_graph = out_rag_provider.read_graph(block.read_roi)
434426

435-
t1 = time.time()
436427
for out_seg, in_frags in new_seg_to_frag_mapping.items():
437428
agglomerated_attrs = {
438429
"size": sum(
@@ -454,7 +445,6 @@ def process_block(block: daisy.Block):
454445

455446
out_graph.add_node(int(out_seg), **agglomerated_attrs)
456447

457-
t1 = time.time()
458448
edges_to_agglomerate = {}
459449
for u, v in graph.edges():
460450
if (
@@ -601,7 +591,6 @@ def process_block(block: daisy.Block):
601591
edge_attrs=list(self.weights.keys()),
602592
)
603593

604-
t1 = time.time()
605594
edges = []
606595
for u, v, edge_attrs in graph.edges(data=True):
607596
if (
@@ -630,15 +619,13 @@ def process_block(block: daisy.Block):
630619
else:
631620
inputs, outputs = [], []
632621

633-
t1 = time.time()
634622
# save the lut to a temporary file for this block
635623
block_lut = LUT(
636624
path=f"{tmp_path}/{'-'.join([str(o) for o in block.write_roi.offset])}-lut"
637625
)
638626
lut = np.array([inputs, outputs])
639627
block_lut.save(lut, edges=edges)
640628

641-
t1 = time.time()
642629
# read luts and existing super fragments in neighboring blocks
643630
existing_luts = [
644631
LUT(
@@ -659,7 +646,6 @@ def process_block(block: daisy.Block):
659646
+ [self.lut],
660647
).load()
661648

662-
t1 = time.time()
663649
frag_seg_mapping: dict[int, int] = {
664650
int(k): int(v) for k, v in total_lut.T
665651
}
@@ -668,11 +654,9 @@ def process_block(block: daisy.Block):
668654
in_group = seg_frag_mapping.setdefault(out_frag, set())
669655
in_group.add(in_frag)
670656

671-
t1 = time.time()
672657
out_graph = out_rag_provider.read_graph(block.read_roi)
673658
assert out_graph.number_of_nodes() == 0, out_graph.number_of_nodes
674659

675-
t1 = time.time()
676660
for out_frag in np.unique(outputs):
677661
if out_frag is not None and out_frag not in out_graph.nodes:
678662
in_group = seg_frag_mapping[out_frag]
@@ -699,7 +683,6 @@ def process_block(block: daisy.Block):
699683

700684
out_graph.add_node(int(out_frag), **agglomerated_attrs)
701685

702-
t1 = time.time()
703686
edges_to_agglomerate = {}
704687
for u, v in graph.edges():
705688
if u in frag_seg_mapping and v in frag_seg_mapping:

volara/datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ def stack(data, other_data):
242242
arr.lazy_op(lambda data: stack(data, self.stack.array("r").data)) # type: ignore[possibly-missing-attribute]
243243

244244

245-
246245
class Affs(Dataset):
247246
"""
248247
Represents a dataset containing affinities.

volara/lut.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from volara.tmp import replace_values
2-
3-
from pathlib import Path
41
from collections.abc import Sequence
2+
from pathlib import Path
53

64
import numpy as np
75

6+
from volara.tmp import replace_values
7+
88
from .utils import StrictBaseModel
99

1010

@@ -38,7 +38,9 @@ def drop(self):
3838
self.file.unlink()
3939

4040
def save(self, lut: np.ndarray, edges=None):
41-
np.savez_compressed(self.file, fragment_segment_lut=lut.astype(int), edges=edges)
41+
np.savez_compressed(
42+
self.file, fragment_segment_lut=lut.astype(int), edges=edges
43+
)
4244

4345
def load(self) -> np.ndarray | None:
4446
if not self.file.exists():
@@ -76,5 +78,7 @@ def load_iterated(self):
7678
for lut in self.luts[1:]:
7779
next_map = lut.load()
7880
if next_map is not None:
79-
starting_map[1] = replace_values(starting_map[1], next_map[0], next_map[1])
81+
starting_map[1] = replace_values(
82+
starting_map[1], next_map[0], next_map[1]
83+
)
8084
return starting_map

Commit comments: 0