55# LICENSE file in the root directory of this source tree.
66
77from abc import ABC , abstractmethod
8- from dataclasses import dataclass
98from types import ModuleType
109from typing import Optional , Sequence , Tuple
1110
1211import torch
1312from torch import nn
1413
1514
16- @dataclass
1715class DecoderTransform (ABC ):
1816 """Base class for all decoder transforms.
1917
@@ -91,7 +89,6 @@ def import_torchvision_transforms_v2() -> ModuleType:
9189 return v2
9290
9391
94- @dataclass
9592class Resize (DecoderTransform ):
9693 """Resize the decoded frame to a given size.
9794
@@ -103,18 +100,20 @@ class Resize(DecoderTransform):
103100 the form (height, width).
104101 """
105102
106- size : Sequence [int ]
103+ def __init__ (self , size : Sequence [int ]):
104+ if len (size ) != 2 :
105+ raise ValueError (
106+ "Resize transform must have a (height, width) "
107+ f"pair for the size, got { size } ."
108+ )
109+ self .size = size
107110
108111 def _make_transform_spec (
109112 self , input_dims : Tuple [Optional [int ], Optional [int ]]
110113 ) -> str :
111- # TODO: establish this invariant in the constructor during refactor
112- assert len (self .size ) == 2
113114 return f"resize, { self .size [0 ]} , { self .size [1 ]} "
114115
115116 def _get_output_dims (self ) -> Optional [Tuple [Optional [int ], Optional [int ]]]:
116- # TODO: establish this invariant in the constructor during refactor
117- assert len (self .size ) == 2
118117 return (self .size [0 ], self .size [1 ])
119118
120119 @classmethod
@@ -141,7 +140,6 @@ def _from_torchvision(cls, tv_resize: nn.Module):
141140 return cls (size = tv_resize .size )
142141
143142
144- @dataclass
145143class RandomCrop (DecoderTransform ):
146144 """Crop the decoded frame to a given size at a random location in the frame.
147145
@@ -158,17 +156,17 @@ class RandomCrop(DecoderTransform):
158156 the form (height, width).
159157 """
160158
161- size : Sequence [int ]
159+ def __init__ (self , size : Sequence [int ]):
160+ if len (size ) != 2 :
161+ raise ValueError (
162+ "RandomCrop transform must have a (height, width) "
163+ f"pair for the size, got { size } ."
164+ )
165+ self .size = size
162166
163167 def _make_transform_spec (
164168 self , input_dims : Tuple [Optional [int ], Optional [int ]]
165169 ) -> str :
166- if len (self .size ) != 2 :
167- raise ValueError (
168- f"RandomCrop's size must be a sequence of length 2, got { self .size } . "
169- "This should never happen, please report a bug."
170- )
171-
172170 height , width = input_dims
173171 if height is None :
174172 raise ValueError (
@@ -196,8 +194,6 @@ def _make_transform_spec(
196194 return f"crop, { self .size [0 ]} , { self .size [1 ]} , { left } , { top } "
197195
198196 def _get_output_dims (self ) -> Optional [Tuple [Optional [int ], Optional [int ]]]:
199- # TODO: establish this invariant in the constructor during refactor
200- assert len (self .size ) == 2
201197 return (self .size [0 ], self .size [1 ])
202198
203199 @classmethod
0 commit comments