Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 116 additions & 22 deletions flamingo_tools/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _measure_volume_and_surface(mask, resolution):
return volume, surface


def _get_bounding_box(table, seg_id, resolution, shape):
def _get_bounding_box_and_center(table, seg_id, resolution, shape):
row = table[table.label_id == seg_id]

bb_min = np.array([
Expand All @@ -46,38 +46,113 @@ def _get_bounding_box(table, seg_id, resolution, shape):
slice(max(bmin - 1, 0), min(bmax + 1, sh))
for bmin, bmax, sh in zip(bb_min, bb_max, shape)
)
return bb

center = (
int(row.anchor_z.item() / resolution),
int(row.anchor_y.item() / resolution),
int(row.anchor_x.item() / resolution),
)

return bb, center


def _spherical_mask(shape, radius, center=None):
if center is None:
center = tuple(s // 2 for s in shape)
if len(shape) != len(center):
raise ValueError("`shape` and `center` must have same length")

# Build a 1-D open grid for every axis
grids = np.ogrid[tuple(slice(0, s) for s in shape)]
dist2 = sum((g - c) ** 2 for g, c in zip(grids, center))
return (dist2 <= radius ** 2).astype(bool)

def _default_object_features(seg_id, table, image, segmentation, resolution):
bb = _get_bounding_box(table, seg_id, resolution, image.shape)

def _normalize_background(measures, image, mask, center, radius, norm, median_only):
# Compute the bounding box and get the local image data.
bb = tuple(
slice(max(0, int(ce - radius)), min(int(ce + radius), sh)) for ce, sh in zip(center, image.shape)
)
local_image = image[bb]

# Create a mask with radius around the center.
radius_mask = _spherical_mask(local_image.shape, radius)

# Intersect the radius mask with the foreground mask (if given).
if mask is not None:
assert mask.shape == image.shape, f"{mask.shape}, {image.shape}"
local_mask = mask[bb]
radius_mask = np.logical_and(radius_mask, local_mask)

# For debugging.
# import napari
# v = napari.Viewer()
# v.add_image(local_image)
# v.add_labels(local_mask)
# v.add_labels(radius_mask)
# napari.run()

# Compute the features over the mask.
masked_intensity = local_image[radius_mask]

# Standardize the measures.
bg_measures = {"median": np.median(masked_intensity)}
if not median_only:
bg_measures = {
"mean": np.mean(masked_intensity),
"stdev": np.std(masked_intensity),
"min": np.min(masked_intensity),
"max": np.max(masked_intensity),
}
for percentile in (5, 10, 25, 75, 90, 95):
bg_measures[f"percentile-{percentile}"] = np.percentile(masked_intensity, percentile)

for measure, val in bg_measures.items():
measures[measure] = norm(measures[measure], val)

return measures


def _default_object_features(
seg_id, table, image, segmentation, resolution,
foreground_mask=None, background_radius=None, norm=np.divide, median_only=False,
):
bb, center = _get_bounding_box_and_center(table, seg_id, resolution, image.shape)

local_image = image[bb]
mask = segmentation[bb] == seg_id
assert mask.sum() > 0, f"Segmentation ID {seg_id} is empty."
masked_intensity = local_image[mask]

# Do the base intensity measurements.
measures = {
"label_id": seg_id,
"mean": np.mean(masked_intensity),
"stdev": np.std(masked_intensity),
"min": np.min(masked_intensity),
"max": np.max(masked_intensity),
"median": np.median(masked_intensity),
}
for percentile in (5, 10, 25, 75, 90, 95):
measures[f"percentile-{percentile}"] = np.percentile(masked_intensity, percentile)
measures = {"label_id": seg_id, "median": np.median(masked_intensity)}
if not median_only:
measures.update({
"mean": np.mean(masked_intensity),
"stdev": np.std(masked_intensity),
"min": np.min(masked_intensity),
"max": np.max(masked_intensity),
})
for percentile in (5, 10, 25, 75, 90, 95):
measures[f"percentile-{percentile}"] = np.percentile(masked_intensity, percentile)

if background_radius is not None:
# The radius passed is given in micrometer.
# The resolution is given in micrometer per pixel.
# So we have to divide by the resolution to obtain the radius in pixel.
radius_in_pixel = background_radius / resolution
measures = _normalize_background(measures, image, foreground_mask, center, radius_in_pixel, norm, median_only)

# Do the volume and surface measurement.
volume, surface = _measure_volume_and_surface(mask, resolution)
measures["volume"] = volume
measures["surface"] = surface
if not median_only:
volume, surface = _measure_volume_and_surface(mask, resolution)
measures["volume"] = volume
measures["surface"] = surface
return measures


def _regionprops_features(seg_id, table, image, segmentation, resolution):
bb = _get_bounding_box(table, seg_id, resolution, image.shape)
def _regionprops_features(seg_id, table, image, segmentation, resolution, foreground_mask=None):
bb, _ = _get_bounding_box_and_center(table, seg_id, resolution, image.shape)

local_image = image[bb]
local_segmentation = segmentation[bb]
Expand Down Expand Up @@ -106,21 +181,31 @@ def _regionprops_features(seg_id, table, image, segmentation, resolution):
FEATURE_FUNCTIONS = {
"default": _default_object_features,
"skimage": _regionprops_features,
"default_background_norm": partial(_default_object_features, background_radius=75, norm=np.divide),
"default_background_subtract": partial(_default_object_features, background_radius=75, norm=np.subtract),
}
"""The different feature functions that are supported in `compute_object_measures` and
that can be selected via the feature_set argument. Currently this supports:
- 'default': The default features which compute standard intensity statistics and volume + surface.
- 'skimage': The scikit image regionprops features.
- 'default_background_norm': The default features with background normalization.
- 'default_background_subtract': The default features with background subtraction.
For the background normalized measures, we compute the background intensity in a sphere with radius of 75 micrometer
around each object.
"""


# TODO integrate segmentation post-processing, see `_extend_sgns_simple` in `gfp_annotation.py`
def compute_object_measures_impl(
image: np.typing.ArrayLike,
segmentation: np.typing.ArrayLike,
n_threads: Optional[int] = None,
resolution: float = 0.38,
table: Optional[pd.DataFrame] = None,
feature_set: str = "default",
foreground_mask: Optional[np.typing.ArrayLike] = None,
median_only: bool = False,
) -> pd.DataFrame:
"""Compute simple intensity and morphology measures for each segmented cell in a segmentation.
Expand All @@ -133,6 +218,8 @@ def compute_object_measures_impl(
resolution: The resolution / voxel size of the data.
table: The segmentation table. Will be computed on the fly if it is not given.
feature_set: The features to compute for each object. Refer to `FEATURE_FUNCTIONS` for details.
foreground_mask: An optional mask indicating the area to use for computing background correction values.
median_only: Whether to only compute the median intensity.
Returns:
The table with per object measurements.
Expand All @@ -147,12 +234,19 @@ def compute_object_measures_impl(
table=table,
image=image,
segmentation=segmentation,
resolution=resolution
resolution=resolution,
foreground_mask=foreground_mask,
median_only=median_only,
)

seg_ids = table.label_id.values
assert len(seg_ids) > 0, "The segmentation table is empty."
measure_function(seg_ids[0])
n_threads = mp.cpu_count() if n_threads is None else n_threads

# For debugging.
# measure_function(seg_ids[0])

with futures.ThreadPoolExecutor(n_threads) as pool:
measures = list(tqdm(
pool.map(measure_function, seg_ids), total=len(seg_ids), desc="Compute intensity measures"
Expand Down Expand Up @@ -206,14 +300,14 @@ def compute_object_measures(
table = None
elif s3_flag:
seg_table, fs = s3_utils.get_s3_path(segmentation_table_path)
with fs.open(seg_table, 'r') as f:
with fs.open(seg_table, "r") as f:
table = pd.read_csv(f, sep="\t")
else:
table = pd.read_csv(segmentation_table_path, sep="\t")

# filter table with largest component
if len(component_list) != 0 and "component_labels" in table.columns:
table = table[table['component_labels'].isin(component_list)]
table = table[table["component_labels"].isin(component_list)]

# Then, open the volumes.
image = read_image_data(image_path, image_key)
Expand Down
43 changes: 39 additions & 4 deletions scripts/intensity_annotation/gfp_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, statistics, default_stat, bins: int = 32, parent=None):
self.canvas = FigureCanvasQTAgg(self.fig)

# We exclude the label id and the volume / surface measurements.
self.stat_names = statistics.columns[1:-2]
self.stat_names = statistics.columns[1:-2] if len(statistics.columns) > 2 else statistics.columns[1:]
self.param_choices = self.stat_names

self.param_box = QComboBox()
Expand Down Expand Up @@ -110,7 +110,30 @@ def _extend_sgns_simple(gfp, sgns, dilation):
return sgns_extended


def gfp_annotation(prefix, default_stat="median"):
def _create_mask(sgns_extended, gfp):
from skimage.transform import downscale_local_mean, resize

gfp_averaged = downscale_local_mean(gfp, factors=(16, 16, 16))
# The 35th percentile seems to be a decent approximation for the background subtraction.
threshold = np.percentile(gfp_averaged, 35)
mask = gfp_averaged > threshold
mask = resize(mask, sgns_extended.shape, order=0, anti_aliasing=False, preserve_range=True).astype(bool)
mask[sgns_extended != 0] = 0

# v = napari.Viewer()
# v.add_image(gfp)
# v.add_image(gfp_averaged, scale=(16, 16, 16))
# v.add_labels(mask)
# # v.add_labels(mask, scale=(16, 16, 16))
# v.add_labels(sgns_extended)
# napari.run()

return mask


def gfp_annotation(prefix, default_stat="median", background_norm=None):
assert background_norm in (None, "division", "subtraction")

gfp = imageio.imread(f"{prefix}_GFP_resized.tif")
sgns = imageio.imread(f"{prefix}_SGN_resized_v2.tif")
pv = imageio.imread(f"{prefix}_PV_resized.tif")
Expand All @@ -125,7 +148,16 @@ def gfp_annotation(prefix, default_stat="median"):
sgns_extended = _extend_sgns_simple(gfp, sgns, dilation=4)

# Compute the intensity statistics.
statistics = compute_object_measures_impl(gfp, sgns_extended)
if background_norm is None:
mask = None
feature_set = "default"
else:
mask = _create_mask(sgns_extended, gfp)
assert mask.shape == sgns_extended.shape
feature_set = "default_background_norm" if background_norm == "division" else "default_background_subtract"
statistics = compute_object_measures_impl(
gfp, sgns_extended, feature_set=feature_set, foreground_mask=mask, median_only=True
)

# Open the napari viewer.
v = napari.Viewer()
Expand All @@ -135,6 +167,8 @@ def gfp_annotation(prefix, default_stat="median"):
v.add_image(pv, visible=False, name="PV")
v.add_labels(sgns, visible=False, name="SGNs")
v.add_labels(sgns_extended, name="SGNs-extended")
if mask is not None:
v.add_labels(mask, name="mask-for-background", visible=False)

# Add additional layers for intensity coloring and classification
# data_numerical = np.zeros(gfp.shape, dtype="float32")
Expand Down Expand Up @@ -212,9 +246,10 @@ def threshold_widget(viewer: napari.Viewer, threshold: float = (max_val - min_va
def main():
parser = argparse.ArgumentParser()
parser.add_argument("prefix")
parser.add_argument("-b", "--background_norm")
args = parser.parse_args()

gfp_annotation(args.prefix)
gfp_annotation(args.prefix, background_norm=args.background_norm)


if __name__ == "__main__":
Expand Down
Loading