diff --git a/kauldron/data/tf/random_transforms.py b/kauldron/data/tf/random_transforms.py index d4ac5152..9fb2a8e0 100644 --- a/kauldron/data/tf/random_transforms.py +++ b/kauldron/data/tf/random_transforms.py @@ -36,6 +36,14 @@ class ElementWiseRandomTransform( ) +@dataclasses.dataclass(frozen=True, kw_only=True) +class RandomMapTransform(grain.RandomMapTransform): + """Wraps RandomMapTransform with an adapter to remove Grain meta features.""" + + # Wrap `random_map` to remove the `grain.META_FEATURES` + random_map = transform_utils.wrap_map(grain.RandomMapTransform.random_map) + + @dataclasses.dataclass(kw_only=True, frozen=True, eq=True) class InceptionCrop(ElementWiseRandomTransform): """Makes inception-style image crop and optionally resizes afterwards.