@@ -41,7 +41,9 @@ class DecoderTransform(ABC):
4141 def _make_transform_spec (self ) -> str :
4242 pass
4343
44- def _get_output_dims (self , input_dims : Tuple [int , int ]) -> Tuple [int , int ]:
44+ def _get_output_dims (
45+ self , input_dims : Tuple [Optional [int ], Optional [int ]]
46+ ) -> Tuple [Optional [int ], Optional [int ]]:
4547 return input_dims
4648
4749
@@ -70,34 +72,39 @@ class Resize(DecoderTransform):
7072 size : Sequence [int ]
7173
7274 def _make_transform_spec (self ) -> str :
75+ # TODO: establish this invariant in the constructor during refactor
7376 assert len (self .size ) == 2
7477 return f"resize, { self .size [0 ]} , { self .size [1 ]} "
7578
76- def _get_output_dims (self , input_dims : Tuple [int , int ]) -> Tuple [int , int ]:
77- return (* self .size ,)
79+ def _get_output_dims (
80+ self , input_dims : Tuple [Optional [int ], Optional [int ]]
81+ ) -> Tuple [Optional [int ], Optional [int ]]:
82+ # TODO: establish this invariant in the constructor during refactor
83+ assert len (self .size ) == 2
84+ return (self .size [0 ], self .size [1 ])
7885
7986 @classmethod
80- def _from_torchvision (cls , resize_tv : nn .Module ):
87+ def _from_torchvision (cls , tv_resize : nn .Module ):
8188 v2 = import_torchvision_transforms_v2 ()
8289
83- assert isinstance (resize_tv , v2 .Resize )
90+ assert isinstance (tv_resize , v2 .Resize )
8491
85- if resize_tv .interpolation is not v2 .InterpolationMode .BILINEAR :
92+ if tv_resize .interpolation is not v2 .InterpolationMode .BILINEAR :
8693 raise ValueError (
8794 "TorchVision Resize transform must use bilinear interpolation."
8895 )
89- if resize_tv .antialias is False :
96+ if tv_resize .antialias is False :
9097 raise ValueError (
9198 "TorchVision Resize transform must have antialias enabled."
9299 )
93- if resize_tv .size is None :
100+ if tv_resize .size is None :
94101 raise ValueError ("TorchVision Resize transform must have a size specified." )
95- if len (resize_tv .size ) != 2 :
102+ if len (tv_resize .size ) != 2 :
96103 raise ValueError (
97104 "TorchVision Resize transform must have a (height, width) "
98- f"pair for the size, got { resize_tv .size } ."
105+ f"pair for the size, got { tv_resize .size } ."
99106 )
100- return cls (size = resize_tv .size )
107+ return cls (size = tv_resize .size )
101108
102109
103110@dataclass
@@ -140,52 +147,92 @@ def _make_transform_spec(self) -> str:
140147 )
141148 if self ._input_dims [0 ] < self .size [0 ] or self ._input_dims [1 ] < self .size [1 ]:
142149 raise ValueError (
143- f"Input dimensions { input_dims } are smaller than the crop size { self .size } ."
150+ f"Input dimensions { self . _input_dims } are smaller than the crop size { self .size } ."
144151 )
145152
146153 # Note: This logic must match the logic in
147154 # torchvision.transforms.v2.RandomCrop.make_params(). Given
148155 # the same seed, they should get the same result. This is an
149156 # API guarantee with our users.
150- self ._top = torch . randint (
151- 0 , self ._input_dims [0 ] - self .size [0 ] + 1 , size = ()
157+ self ._top = int (
158+ torch . randint ( 0 , self ._input_dims [0 ] - self .size [0 ] + 1 , size = ()). item ()
152159 )
153- self ._left = torch . randint (
154- 0 , self ._input_dims [1 ] - self .size [1 ] + 1 , size = ()
160+ self ._left = int (
161+ torch . randint ( 0 , self ._input_dims [1 ] - self .size [1 ] + 1 , size = ()). item ()
155162 )
156163
157164 return f"crop, { self .size [0 ]} , { self .size [1 ]} , { self ._left } , { self ._top } "
158165
159- def _get_output_dims (self , input_dims : Tuple [int , int ]) -> Tuple [int , int ]:
160- self ._input_dims = input_dims
161- return self .size
166+ def _get_output_dims (
167+ self , input_dims : Tuple [Optional [int ], Optional [int ]]
168+ ) -> Tuple [Optional [int ], Optional [int ]]:
169+ # TODO: establish this invariant in the constructor during refactor
170+ assert len (self .size ) == 2
171+
172+ height , width = input_dims
173+ if height is None :
174+ raise ValueError (
175+ "Video metadata has no height. RandomCrop can only be used when input frame dimensions are known."
176+ )
177+ if width is None :
178+ raise ValueError (
179+ "Video metadata has no width. RandomCrop can only be used when input frame dimensions are known."
180+ )
181+
182+ self ._input_dims = (height , width )
183+ return (self .size [0 ], self .size [1 ])
162184
163185 @classmethod
164- def _from_torchvision (cls , random_crop_tv : nn .Module , input_dims : Tuple [int , int ]):
186+ def _from_torchvision (
187+ cls ,
188+ tv_random_crop : nn .Module ,
189+ input_dims : Tuple [Optional [int ], Optional [int ]],
190+ ):
165191 v2 = import_torchvision_transforms_v2 ()
166192
167- assert isinstance (random_crop_tv , v2 .RandomCrop )
193+ assert isinstance (tv_random_crop , v2 .RandomCrop )
168194
169- if random_crop_tv .padding is not None :
195+ if tv_random_crop .padding is not None :
170196 raise ValueError (
171197 "TorchVision RandomCrop transform must not specify padding."
172198 )
173- if random_crop_tv .pad_if_needed is True :
199+
200+ if tv_random_crop .pad_if_needed is True :
174201 raise ValueError (
175202 "TorchVision RandomCrop transform must not specify pad_if_needed."
176203 )
177- if random_crop_tv .fill != 0 :
204+
205+ if tv_random_crop .fill != 0 :
178206 raise ValueError ("TorchVision RandomCrop fill must be 0." )
179- if random_crop_tv .padding_mode != "constant" :
207+
208+ if tv_random_crop .padding_mode != "constant" :
180209 raise ValueError ("TorchVision RandomCrop padding_mode must be constant." )
181- if len (random_crop_tv .size ) != 2 :
210+
211+ if len (tv_random_crop .size ) != 2 :
182212 raise ValueError (
183213 "TorchVision RandomCrop transform must have a (height, width) "
184- f"pair for the size, got { random_crop_tv .size } ."
214+ f"pair for the size, got { tv_random_crop .size } ."
215+ )
216+
217+ height , width = input_dims
218+ if height is None :
219+ raise ValueError (
220+ "Video metadata has no height. RandomCrop can only be used when input frame dimensions are known."
221+ )
222+ if width is None :
223+ raise ValueError (
224+ "Video metadata has no width. RandomCrop can only be used when input frame dimensions are known."
185225 )
186- params = random_crop_tv .make_params (
187- # TODO: deal with NCHW versus NHWC; video decoder knows
188- torch .empty (size = (3 , * input_dims ), dtype = torch .uint8 )
226+
227+ # Note that TorchVision v2 transforms only accept NCHW tensors.
228+ params = tv_random_crop .make_params (
229+ torch .empty (size = (3 , height , width ), dtype = torch .uint8 )
189230 )
190- assert random_crop_tv .size == (params ["height" ], params ["width" ])
191- return cls (size = random_crop_tv .size , _top = params ["top" ], _left = params ["left" ])
231+
232+ if tv_random_crop .size != (params ["height" ], params ["width" ]):
233+ raise ValueError (
234+ f"TorchVision RandomCrop's provided size, { tv_random_crop .size } "
235+ f"must match the computed size, { params ['height' ], params ['width' ]} ."
236+ )
237+
238+ return cls (size = tv_random_crop .size , _top = params ["top" ], _left = params ["left" ])
0 commit comments