Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 61 additions & 16 deletions monai/apps/deepedit/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from monai.data import MetaTensor
from monai.networks.layers import GaussianFilter
from monai.transforms.transform import MapTransform, Randomizable, Transform
from monai.utils import min_version, optional_import
from monai.utils import deprecated, min_version, optional_import

measure, _ = optional_import("skimage.measure", "0.14.2", min_version)

Expand Down Expand Up @@ -84,18 +84,44 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda
return d


class NormalizeLabelsInDatasetd(MapTransform):
class RemapLabelsToSequentiald(MapTransform):
"""
Remap label values from a dataset-specific schema to sequential indices (0, 1, 2, 3, ...).

This transform takes labels with arbitrary values defined in a label dictionary and remaps them
to a sequential range starting from 1 (with background always set to 0). This is useful for
standardizing labels across different datasets or ensuring labels are in a contiguous range.

The output label indices are assigned in alphabetical order by label name to ensure
deterministic behavior regardless of input dictionary ordering.

Args:
keys: The ``keys`` parameter will be used to get and set the actual data item to transform
label_names: Dictionary mapping label names to their current values in the dataset.
For example: {"spleen": 1, "liver": 6, "background": 0}
Will be remapped to: {"background": 0, "liver": 1, "spleen": 2}
(alphabetically sorted, excluding background)
allow_missing_keys: If True, missing keys in the data dictionary will not raise an error

Example:
>>> transform = RemapLabelsToSequentiald(
... keys="label",
... label_names={"liver": 6, "spleen": 1, "background": 0}
... )
>>> # Input label has values [0, 1, 6]
>>> # Output label will have values [0, 1, 2] (background=0, liver=1, spleen=2)
>>> # And updates d["label_names"] to {"background": 0, "liver": 1, "spleen": 2}

Note:
- Background label (if present) is always mapped to 0
- Non-background labels are mapped to sequential indices 1, 2, 3, ... in alphabetical order
- Undefined labels (not in label_names) will be set to 0 (background)
- The transform updates the data dictionary with a new "label_names" key containing the remapped values
"""

def __init__(
self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False
):
"""
Normalize label values according to label names dictionary

Args:
keys: The ``keys`` parameter will be used to get and set the actual data item to transform
label_names: all label names
"""
super().__init__(keys, allow_missing_keys)

self.label_names = label_names or {}
Expand All @@ -106,13 +132,18 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda
# Dictionary containing new label numbers
new_label_names = {}
label = np.zeros(d[key].shape)
# Making sure the range values and number of labels are the same
for idx, (key_label, val_label) in enumerate(self.label_names.items(), start=1):
if key_label != "background":
new_label_names[key_label] = idx
label[d[key] == val_label] = idx
if key_label == "background":
new_label_names["background"] = 0

# Sort label names to ensure deterministic ordering (exclude background)
sorted_labels = sorted([(k, v) for k, v in self.label_names.items() if k != "background"])

# Always set background to 0 first
if "background" in self.label_names:
new_label_names["background"] = 0

# Assign sequential indices to sorted non-background labels
for idx, (key_label, val_label) in enumerate(sorted_labels, start=1):
new_label_names[key_label] = idx
label[d[key] == val_label] = idx

d["label_names"] = new_label_names
if isinstance(d[key], MetaTensor):
Expand All @@ -122,6 +153,20 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.nda
return d


@deprecated(since="1.6", removed="1.8", msg_suffix="Use `RemapLabelsToSequentiald` instead.")
class NormalizeLabelsInDatasetd(RemapLabelsToSequentiald):
"""
.. deprecated:: 1.6.0
`NormalizeLabelsInDatasetd` is deprecated and will be removed in version 1.8.0.
Use :class:`RemapLabelsToSequentiald` instead.

This class is maintained for backward compatibility. Please use RemapLabelsToSequentiald
which better describes the transform's functionality.
"""

pass


class SingleLabelSelectiond(MapTransform):

def __init__(
Expand Down
95 changes: 95 additions & 0 deletions tests/apps/deepedit/test_deepedit_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
FindAllValidSlicesMissingLabelsd,
FindDiscrepancyRegionsDeepEditd,
NormalizeLabelsInDatasetd,
RemapLabelsToSequentiald,
ResizeGuidanceMultipleLabelDeepEditd,
SingleLabelSelectiond,
SplitPredsLabeld,
Expand Down Expand Up @@ -282,6 +283,100 @@ def test_correct_results(self, arguments, input_data, expected_result):
result = add_fn(input_data)
self.assertEqual(len(np.unique(result["label"])), expected_result)

def test_ordering_determinism(self):
"""Test that different input ordering produces the same output (alphabetical)"""
# Create a label array with different label values
label = np.array([[[0, 1, 6, 3]]]) # background=0, spleen=1, liver=6, kidney=3

# Test case 1: liver first, then kidney, then spleen
data1 = {"label": label.copy()}
transform1 = RemapLabelsToSequentiald(
keys="label", label_names={"liver": 6, "kidney": 3, "spleen": 1, "background": 0}
)
result1 = transform1(data1)

# Test case 2: spleen first, then kidney, then liver (different order)
data2 = {"label": label.copy()}
transform2 = RemapLabelsToSequentiald(
keys="label", label_names={"spleen": 1, "kidney": 3, "liver": 6, "background": 0}
)
result2 = transform2(data2)

# Both should produce the same output (alphabetically sorted)
# Expected mapping: background=0, kidney=1, liver=2, spleen=3
np.testing.assert_array_equal(result1["label"], result2["label"])

# Verify the actual mapping is alphabetical
expected_output = np.array([[[0, 3, 2, 1]]]) # kidney=1, liver=2, spleen=3, background=0
np.testing.assert_array_equal(result1["label"], expected_output)

# Verify label_names is correct
self.assertEqual(result1["label_names"], {"background": 0, "kidney": 1, "liver": 2, "spleen": 3})
self.assertEqual(result2["label_names"], {"background": 0, "kidney": 1, "liver": 2, "spleen": 3})

def test_multiple_labels(self):
"""Test with multiple non-background labels"""
label = np.array([[[0, 1, 2, 5]]]) # background, spleen, kidney, liver
data = {"label": label.copy()}
transform = RemapLabelsToSequentiald(
keys="label", label_names={"spleen": 1, "kidney": 2, "liver": 5, "background": 0}
)
result = transform(data)

# Expected: background=0, kidney=1, liver=2, spleen=3 (alphabetical)
expected = np.array([[[0, 3, 1, 2]]])
np.testing.assert_array_equal(result["label"], expected)
self.assertEqual(result["label_names"], {"background": 0, "kidney": 1, "liver": 2, "spleen": 3})

def test_deprecated_name_warning(self):
"""Test that NormalizeLabelsInDatasetd is properly deprecated.

The deprecation warning only triggers when MONAI version >= 1.6 (since="1.6").
This test verifies:
1. The actual NormalizeLabelsInDatasetd class is marked as deprecated in docstring
2. The class is a subclass of RemapLabelsToSequentiald
3. The deprecation mechanism works correctly (tested via version_val simulation)
4. The actual class functions correctly
"""
import warnings

from monai.utils import deprecated

# Verify NormalizeLabelsInDatasetd docstring indicates deprecation
self.assertIn("deprecated", NormalizeLabelsInDatasetd.__doc__.lower())
self.assertIn("RemapLabelsToSequentiald", NormalizeLabelsInDatasetd.__doc__)

# Verify NormalizeLabelsInDatasetd is a subclass of RemapLabelsToSequentiald
self.assertTrue(issubclass(NormalizeLabelsInDatasetd, RemapLabelsToSequentiald))

# Test the deprecation mechanism using version_val to simulate version 1.6
# This verifies the @deprecated decorator behavior that NormalizeLabelsInDatasetd uses
@deprecated(
since="1.6",
removed="1.8",
msg_suffix="Use `RemapLabelsToSequentiald` instead.",
version_val="1.6", # Simulate version 1.6 to trigger warning
)
class DeprecatedNormalizeLabels(RemapLabelsToSequentiald):
pass

data = {"label": np.array([[[0, 1]]])}

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
transform = DeprecatedNormalizeLabels(keys="label", label_names={"spleen": 1, "background": 0})
_ = transform(data)

# Check that a deprecation warning was raised
self.assertEqual(len(w), 1)
self.assertTrue(issubclass(w[0].category, FutureWarning))
self.assertIn("RemapLabelsToSequentiald", str(w[0].message))

# Verify the actual NormalizeLabelsInDatasetd class works correctly
transform_actual = NormalizeLabelsInDatasetd(keys="label", label_names={"spleen": 1, "background": 0})
result = transform_actual({"label": np.array([[[0, 1]]])})
self.assertIn("label", result)


class TestResizeGuidanceMultipleLabelCustomd(unittest.TestCase):

Expand Down
Loading