Skip to content

Commit f6853d5

Browse files
Implement inference test
1 parent fba6486 commit f6853d5

File tree

2 files changed

+56
-11
lines changed

2 files changed

+56
-11
lines changed

test/test_cli.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from synapse_net.sample_data import get_sample_data
1313

1414

15+
@unittest.skipIf(platform.system() == "Windows", "CLI does not work on Windows")
1516
class TestCLI(unittest.TestCase):
1617
tmp_dir = "tmp"
1718

@@ -72,17 +73,10 @@ def test_segmentation_cli_with_scale(self):
7273
def test_segmentation_cli_with_checkpoint(self):
7374
cache_dir = os.path.expanduser(pooch.os_cache("synapse-net"))
7475
model_path = os.path.join(cache_dir, "models", "vesicles_2d")
75-
if platform.system() == "Windows":
76-
cmd = [
77-
sys.executable, "-m", "synapse_net.run_segmentation",
78-
"-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d",
79-
"-c", model_path,
80-
]
81-
else:
82-
cmd = [
83-
"synapse_net.run_segmentation", "-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d",
84-
"-c", model_path,
85-
]
76+
cmd = [
77+
"synapse_net.run_segmentation", "-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d",
78+
"-c", model_path,
79+
]
8680
run(cmd)
8781
self.check_segmentation_result()
8882

test/test_inference.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import os
2+
import unittest
3+
from functools import partial
4+
from shutil import rmtree
5+
6+
import imageio.v3 as imageio
7+
from synapse_net.file_utils import read_mrc
8+
from synapse_net.sample_data import get_sample_data
9+
10+
11+
class TestInference(unittest.TestCase):
12+
tmp_dir = "tmp"
13+
model_type = "vesicles_2d"
14+
tiling = {"tile": {"z": 1, "y": 512, "x": 512}, "halo": {"z": 0, "y": 32, "x": 32}}
15+
16+
def setUp(self):
17+
self.data_path = get_sample_data("tem_2d")
18+
os.makedirs(self.tmp_dir, exist_ok=True)
19+
20+
def tearDown(self):
21+
try:
22+
rmtree(self.tmp_dir)
23+
except OSError:
24+
pass
25+
26+
def test_run_segmentation(self):
27+
from synapse_net.inference import run_segmentation, get_model
28+
29+
image, _ = read_mrc(self.data_path)
30+
model = get_model(self.model_type)
31+
seg = run_segmentation(image, model, model_type=self.model_type, tiling=self.tiling)
32+
self.assertEqual(image.shape, seg.shape)
33+
34+
def test_segmentation_with_inference_helper(self):
35+
from synapse_net.inference import run_segmentation, get_model
36+
from synapse_net.inference.util import inference_helper
37+
38+
model = get_model(self.model_type)
39+
segmentation_function = partial(
40+
run_segmentation, model=model, model_type=self.model_type, verbose=False, tiling=self.tiling,
41+
)
42+
inference_helper(self.data_path, self.tmp_dir, segmentation_function, data_ext=".mrc")
43+
expected_output_path = os.path.join(self.tmp_dir, "tem_2d_prediction.tif")
44+
self.assertTrue(os.path.exists(expected_output_path))
45+
seg = imageio.imread(expected_output_path)
46+
image, _ = read_mrc(self.data_path)
47+
self.assertEqual(image.shape, seg.shape)
48+
49+
50+
if __name__ == "__main__":
51+
unittest.main()

0 commit comments

Comments
 (0)