diff --git a/batchgeneratorsv2/transforms/base/basic_transform.py b/batchgeneratorsv2/transforms/base/basic_transform.py index 1db856b..7cd1064 100644 --- a/batchgeneratorsv2/transforms/base/basic_transform.py +++ b/batchgeneratorsv2/transforms/base/basic_transform.py @@ -26,6 +26,9 @@ def apply(self, data_dict, **params): if data_dict.get('segmentation') is not None: data_dict['segmentation'] = self._apply_to_segmentation(data_dict['segmentation'], **params) + + if data_dict.get('dist_map') is not None: + data_dict['dist_map'] = self._apply_to_dist_map(data_dict['dist_map'], **params) if data_dict.get('keypoints') is not None: data_dict['keypoints'] = self._apply_to_keypoints(data_dict['keypoints'], **params) @@ -44,6 +47,9 @@ def _apply_to_regr_target(self, regression_target, **params) -> torch.Tensor: def _apply_to_segmentation(self, segmentation: torch.Tensor, **params) -> torch.Tensor: pass + def _apply_to_dist_map(self, dist_map: torch.Tensor, **params) -> torch.Tensor: + pass + def _apply_to_keypoints(self, keypoints, **params): pass @@ -74,4 +80,4 @@ def apply(self, data_dict: dict, **params) -> dict: if __name__ == '__main__': - pass \ No newline at end of file + pass