Skip to content

Commit 9c66830

Browse files
committed
fix more ty errors
1 parent dbf22c9 commit 9c66830

File tree

9 files changed

+23
-16
lines changed

9 files changed

+23
-16
lines changed

tests/process_block/argmax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_argmax(tmpdir):
3636
# create labels array
3737
prepare_ds(
3838
tmpdir / "test_data.zarr" / "labels",
39-
shape=probs_data[1:],
39+
shape=probs_data.shape[1:],
4040
voxel_size=Coordinate(1, 1),
4141
dtype=np.uint32,
4242
mode="w",

tests/process_block/graph_mws.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def test_graph_mws(tmpdir, y_bias: float):
5656
with config.process_block_func() as process_block:
5757
process_block(block)
5858

59-
fragments, segments = config.lut.load()
59+
lut = config.lut.load()
60+
assert lut is not None
61+
fragments, segments = lut
6062
assert len(np.unique(fragments)) == 2, fragments
6163
assert len(np.unique(segments)) == 1 + (y_bias < 0), segments

tests/test_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def test_attrs_affs_validation(tmp_path):
217217
# Case 3: Mismatch -> Error
218218
with pytest.raises(ValidationError):
219219
# Disk has [[0,1], [1,0]], we provide [[5,5]]
220-
Affs(store=path, neighborhood=[[5, 5]])
220+
Affs(store=path, neighborhood=[Coordinate(5, 5)])
221221

222222

223223
def test_attrs_labels_lsd():

volara/blockwise/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .seeded_extract_frags import SeededExtractFrags as SeededExtractFrags
1919
from .threshold import Threshold as Threshold
2020

21-
BLOCKWISE_TASKS = []
21+
BLOCKWISE_TASKS: list[BlockwiseTask] = []
2222

2323

2424
def register_task(task: BlockwiseTask):
@@ -57,7 +57,7 @@ def get_blockwise_tasks_type():
5757
TASKS_DISCOVERED = True
5858
return TypeAdapter(
5959
Annotated[
60-
Union[tuple(BLOCKWISE_TASKS)],
60+
Union[tuple[BLOCKWISE_TASKS]],
6161
Field(discriminator="task_type"),
6262
]
6363
)

volara/blockwise/graph_mws.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def process_block_func(self):
117117
out_rag_provider = self.out_db.open("w")
118118

119119
if self.starting_lut is not None:
120-
starting_frags, starting_segs = self.starting_lut.load()
120+
starting_lut = self.starting_lut.load()
121+
assert starting_lut is not None, "Unable to load starting LUT"
122+
starting_frags, starting_segs = starting_lut
121123
starting_map = {
122124
in_frag: out_frag
123125
for in_frag, out_frag in zip(starting_frags, starting_segs)

volara/datasets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ class Raw(Dataset):
196196
def bounds(self) -> list[tuple[float, float]] | None:
197197
if self.ome_norm is not None:
198198
array = open_ds(self.store, mode="r", **self.zarr_kwargs)
199-
metadata_group = zarr.open(self.ome_norm)
199+
metadata_group = zarr.open(str(self.ome_norm))
200200
channels_meta = metadata_group.attrs["omero"]["channels"]
201201
bounds = [
202202
(channels_meta[c]["window"]["min"], channels_meta[c]["window"]["max"])
@@ -339,8 +339,8 @@ def array(self, mode: str = "r") -> Array:
339339

340340
if hasattr(vol, "to_dask") and callable(vol.to_dask):
341341
return Array(
342-
vol.to_dask(),
343-
**{k: v for k, v in metadata.items() if v is not None}, # type: ignore[invalid-argument]
342+
vol.to_dask(), # type: ignore
343+
**{k: v for k, v in metadata.items() if v is not None}, # type: ignore
344344
)
345345
else:
346346
raise Exception(

volara/dbs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def open(self, mode="r") -> PgSQLGraphDatabase:
247247
mode=mode,
248248
)
249249

250-
def spoof(self):
250+
def spoof(self, nodes: bool = True):
251251
raise NotImplementedError(
252252
"Spoofing PostgreSQL databases is not implemented yet."
253253
)

volara/logging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# default log dir
66
LOG_BASEDIR = Path("./volara_logs")
7-
daisy.logging.set_log_basedir(LOG_BASEDIR) # type: ignore[unresolved-attribute]
7+
daisy.logging.set_log_basedir(LOG_BASEDIR)
88

99

1010
def set_log_basedir(path: Path | str):
@@ -24,7 +24,7 @@ def set_log_basedir(path: Path | str):
2424
raise NotImplementedError("None is not a valid log directory")
2525
LOG_BASEDIR = None
2626

27-
daisy.logging.set_log_basedir(LOG_BASEDIR) # type: ignore[unresolved-attribute]
27+
daisy.logging.set_log_basedir(LOG_BASEDIR)
2828

2929

3030
def get_log_basedir():

volara/lut.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,12 @@ def drop(self):
4040
self.file.unlink()
4141

4242
def save(self, lut: np.ndarray, edges=None):
43-
np.savez_compressed(
44-
self.file, fragment_segment_lut=lut.astype(int), edges=edges
45-
)
43+
if edges is not None:
44+
np.savez_compressed(
45+
self.file, fragment_segment_lut=lut.astype(int), edges=edges
46+
)
47+
else:
48+
np.savez_compressed(self.file, fragment_segment_lut=lut.astype(int))
4649

4750
def load(self) -> np.ndarray | None:
4851
if not self.file.exists():
@@ -73,7 +76,7 @@ def __add__(self, other):
7376
def load(self):
7477
return np.concatenate(
7578
[lut.load() for lut in self.luts if lut.load() is not None], axis=1
76-
)
79+
) # type: ignore
7780

7881
def load_iterated(self):
7982
starting_map = self.luts[0].load()

0 commit comments

Comments
 (0)