diff --git a/sample-apps/radiology/lib/infers/deepedit.py b/sample-apps/radiology/lib/infers/deepedit.py
index afc755c98..679cd9982 100644
--- a/sample-apps/radiology/lib/infers/deepedit.py
+++ b/sample-apps/radiology/lib/infers/deepedit.py
@@ -11,7 +11,8 @@
 import logging
 from typing import Callable, Sequence, Union
 
-from lib.transforms.transforms import GetCentroidsd
+from lib.transforms.transforms import GetCentroidsd, OrientationGuidanceMultipleLabelDeepEditd
+
 from monai.apps.deepedit.transforms import (
     AddGuidanceFromPointsDeepEditd,
     AddGuidanceSignalDeepEditd,
@@ -88,6 +89,7 @@ def pre_transforms(self, data=None):
         if self.type == InferType.DEEPEDIT:
             t.extend(
                 [
+                    OrientationGuidanceMultipleLabelDeepEditd(ref_image="image", label_names=self.labels),
                     AddGuidanceFromPointsDeepEditd(ref_image="image", guidance="guidance", label_names=self.labels),
                     Resized(keys="image", spatial_size=self.spatial_size, mode="area"),
                     ResizeGuidanceMultipleLabelDeepEditd(guidance="guidance", ref_image="image"),
diff --git a/sample-apps/radiology/lib/transforms/transforms.py b/sample-apps/radiology/lib/transforms/transforms.py
index ac528dec7..edd44b96a 100644
--- a/sample-apps/radiology/lib/transforms/transforms.py
+++ b/sample-apps/radiology/lib/transforms/transforms.py
@@ -14,6 +14,7 @@
 
 import numpy as np
 import torch
+from einops import rearrange
 from monai.config import KeysCollection, NdarrayOrTensor
 from monai.data import MetaTensor
 from monai.networks.layers import GaussianFilter
@@ -511,6 +512,39 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
         return d
 
 
+class OrientationGuidanceMultipleLabelDeepEditd(Transform):
+    def __init__(self, ref_image: str, label_names=None):
+        """
+        Convert the guidance to the RAS orientation
+        """
+        self.ref_image = ref_image
+        self.label_names = label_names
+
+    def transform_points(self, point, affine):
+        """transform point to the coordinates of the transformed image
+        point: numpy array [bs, N, 3]
+        """
+        bs, N = point.shape[:2]
+        point = np.concatenate((point, np.ones((bs, N, 1))), axis=-1)
+        point = rearrange(point, "b n d -> d (b n)")
+        point = affine @ point
+        point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3]
+        return point
+
+    def __call__(self, data):
+        d: Dict = dict(data)
+        for key_label in self.label_names.keys():
+            points = d.get(key_label, [])
+            if len(points) < 1:
+                continue
+            reoriented_points = self.transform_points(
+                np.array(points)[None],
+                np.linalg.inv(d[self.ref_image].meta["affine"].numpy()) @ d[self.ref_image].meta["original_affine"],
+            )
+            d[key_label] = reoriented_points[0]
+        return d
+
+
 def get_guidance_tensor_for_key_label(data, key_label, device) -> torch.Tensor:
     """Makes sure the guidance is in a tensor format."""
     tmp_gui = data.get(key_label, torch.tensor([], dtype=torch.int32, device=device))
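
For context, here is a standalone sketch (not part of the diff) of what `transform_points` computes. It lifts each click into homogeneous coordinates and applies `inv(affine) @ original_affine`, which is how `__call__` maps guidance points from the original volume's orientation into the reoriented (RAS) image space. The affine pair below is a toy example, not taken from any real dataset:

```python
import numpy as np
from einops import rearrange

def transform_points(point, affine):
    # point: [bs, N, 3] voxel coordinates; affine: 4x4 homogeneous matrix.
    bs, n = point.shape[:2]
    point = np.concatenate((point, np.ones((bs, n, 1))), axis=-1)  # -> [bs, N, 4]
    point = rearrange(point, "b n d -> d (b n)")                   # -> [4, bs*N]
    point = affine @ point                                         # apply affine in homogeneous coords
    return rearrange(point, "d (b n) -> b n d", b=bs)[:, :, :3]    # back to [bs, N, 3]

# Toy affines: the original volume is stored LPS-like (x and y flipped),
# while the loaded image has been reoriented to an identity (RAS) affine.
original_affine = np.diag([-1.0, -1.0, 1.0, 1.0])
affine = np.eye(4)

clicks = np.array([[[10.0, 20.0, 30.0]]])  # one click, shape [1, 1, 3]
print(transform_points(clicks, np.linalg.inv(affine) @ original_affine))
# [[[-10. -20.  30.]]] -- the click flipped into the RAS frame
```

Batching the points through a single matrix multiply via `rearrange` is what motivates the new `einops` import in the second hunk.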