Skip to content

Commit 27a08e6

Browse files
committed
Adjust stardist for large image inputs
1 parent 6ea85bf commit 27a08e6

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

src/methods_segmentation/stardist/config.vsh.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ resources:
2121
path: script.py
2222

2323
engines:
24+
# NOTE: On mac the tensorflow install leads to an error. Develop the method in a conda env instead (and test docker via gh-actions).
25+
# Installations can be done with pip (except tensorflow: use conda install -c conda-forge tensorflow)
2426
- type: docker
2527
image: openproblems/base_python:1
2628
setup:

src/methods_segmentation/stardist/script.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import numpy as np
55
import xarray as xr
66
import spatialdata as sd
7-
from csbdeep.utils import normalize
7+
#from csbdeep.utils import normalize
8+
from csbdeep.data import Normalizer, normalize_mi_ma
89
from stardist.models import StarDist2D
910

1011

@@ -24,8 +25,8 @@ def convert_to_lower_dtype(arr):
2425

2526
## VIASH START
2627
par = {
27-
"input": "./resources_test/common/2023_10x_mouse_brain_xenium_rep1/dataset.zarr",
28-
"output": "./temp/stardist/segmentation.zarr",
28+
"input": "resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr",
29+
"output": "temp/stardist/segmentation.zarr",
2930
"model": "2D_versatile_fluo"
3031
}
3132

@@ -38,10 +39,34 @@ def convert_to_lower_dtype(arr):
3839
transformation = sdata['morphology_mip']['scale0'].image.transform.copy()
3940

4041
# Segment image
42+
4143
# Load pretrained model
4244
model = StarDist2D.from_pretrained(par['model'])
45+
4346
# Segment on normalized image
44-
labels, _ = model.predict_instances(normalize(image)[0,:,:]) # scale = None, **hyperparams)
47+
#labels, _ = model.predict_instances(normalize(image)[0,:,:]) # scale = None, **hyperparams)
48+
49+
# from https://github.com/stardist/stardist/blob/main/examples/other2D/predict_big_data.ipynb
50+
class MyNormalizer(Normalizer):
51+
def __init__(self, mi, ma):
52+
self.mi, self.ma = mi, ma
53+
def before(self, x, axes):
54+
return normalize_mi_ma(x, self.mi, self.ma, dtype=np.float32)
55+
def after(*args, **kwargs):
56+
assert False
57+
@property
58+
def do_after(self):
59+
return False
60+
61+
mi, ma = np.percentile(image, [1,99.8])
62+
normalizer = MyNormalizer(mi, ma)
63+
block_size = min(image.shape[1] // 3, 4096)
64+
offset = min(block_size // 5.5, 128)
65+
66+
labels, _ = model.predict_instances_big(
67+
image[0,:,:], axes='YX', block_size=block_size, min_overlap=offset, context=offset, normalizer=normalizer#, n_tiles=(4,4)
68+
)
69+
4570

4671

4772
# Create output

0 commit comments

Comments
 (0)