|
| 1 | +import os |
| 2 | +import shutil |
| 3 | +from pathlib import Path |
| 4 | +import numpy as np |
| 5 | +import xarray as xr |
| 6 | +import spatialdata as sd |
| 7 | +from csbdeep.utils import normalize |
| 8 | +from stardist.models import StarDist2D |
| 9 | + |
| 10 | + |
| 11 | + |
| 12 | +def convert_to_lower_dtype(arr): |
| 13 | + max_val = arr.max() |
| 14 | + if max_val <= np.iinfo(np.uint8).max: |
| 15 | + new_dtype = np.uint8 |
| 16 | + elif max_val <= np.iinfo(np.uint16).max: |
| 17 | + new_dtype = np.uint16 |
| 18 | + elif max_val <= np.iinfo(np.uint32).max: |
| 19 | + new_dtype = np.uint32 |
| 20 | + else: |
| 21 | + new_dtype = np.uint64 |
| 22 | + |
| 23 | + return arr.astype(new_dtype) |
| 24 | + |
| 25 | +## VIASH START |
| 26 | +par = { |
| 27 | + "input": "./resources_test/common/2023_10x_mouse_brain_xenium_rep1/dataset.zarr", |
| 28 | + "output": "./temp/stardist/segmentation.zarr", |
| 29 | + "model": "2D_versatile_fluo" |
| 30 | +} |
| 31 | + |
| 32 | +## VIASH END |
| 33 | + |
| 34 | + |
| 35 | +# Read image and its transformation |
| 36 | +sdata = sd.read_zarr(par["input"]) |
| 37 | +image = sdata['morphology_mip']['scale0'].image.compute().to_numpy() |
| 38 | +transformation = sdata['morphology_mip']['scale0'].image.transform.copy() |
| 39 | + |
| 40 | +# Segment image |
| 41 | +# Load pretrained model |
| 42 | +model = StarDist2D.from_pretrained(par['model']) |
| 43 | +# Segment on normalized image |
| 44 | +labels, _ = model.predict_instances(normalize(image)[0,:,:]) # scale = None, **hyperparams) |
| 45 | + |
| 46 | + |
| 47 | +# Create output |
| 48 | +sd_output = sd.SpatialData() |
| 49 | +labels = convert_to_lower_dtype(labels) |
| 50 | +labels_array = xr.DataArray(labels, name=f'segmentation', dims=('y', 'x')) |
| 51 | +parsed_labels = sd.models.Labels2DModel.parse(labels_array, transformations=transformation) |
| 52 | +sd_output.labels['segmentation'] = parsed_labels |
| 53 | + |
| 54 | +print("Writing output", flush=True) |
| 55 | +Path(par["output"]).parent.mkdir(parents=True, exist_ok=True) |
| 56 | +if os.path.exists(par["output"]): |
| 57 | + shutil.rmtree(par["output"]) |
| 58 | +sd_output.write(par["output"]) |
| 59 | + |
0 commit comments