@@ -91,3 +91,42 @@ def _from_torchvision(cls, resize_tv: nn.Module):
9191 f"pair for the size, got { resize_tv .size } ."
9292 )
9393 return cls (size = resize_tv .size )
94+
95+ @dataclass
96+ class RandomCrop (DecoderTransform ):
97+
98+ size : Sequence [int ]
99+ _top : Optional [int ] = None
100+ _left : Optional [int ] = None
101+
102+ def _make_transform_spec (self ) -> str :
103+ assert len (self .size ) == 2
104+ return f"crop, { self .size [0 ]} , { self .size [1 ]} , { _left } , { _top } "
105+
106+ @classmethod
107+ def _from_torchvision (cls , random_crop_tv : nn .Module ):
108+ v2 = import_torchvision_transforms_v2 ()
109+
110+ assert isinstance (random_crop_tv , v2 .RandomCrop )
111+
112+ if random_crop_tv .padding is not None :
113+ raise ValueError (
114+ "TorchVision RandomCrop transform must not specify padding."
115+ )
116+ if random_crop_tv .pad_if_needed is True :
117+ raise ValueError (
118+ "TorchVision RandomCrop transform must not specify pad_if_needed."
119+ )
120+ if random_crop_tv .fill != 0 :
121+ raise ValueError ("TorchVision RandomCrop must specify fill of 0." )
122+ if random_crop_tv .padding_mode != "constant" :
123+ raise ValueError (
124+ "TorchVision RandomCrop must specify padding_mode of constant."
125+ )
126+ if len (random_crop_tv .size ) != 2 :
127+ raise ValueError (
128+ "TorchVision RandcomCrop transform must have a (height, width) "
129+ f"pair for the size, got { random_crop_tv .size } ."
130+ )
131+ params = random_crop_tv .make_params ([])
132+ return cls (size = random_crop_tv .size )
0 commit comments