55# LICENSE file in the root directory of this source tree.
66
77from abc import ABC , abstractmethod
8- from dataclasses import dataclass
98from types import ModuleType
109from typing import Optional , Sequence , Tuple
1110
1211import torch
1312from torch import nn
1413
1514
16- @dataclass
1715class DecoderTransform (ABC ):
1816 """Base class for all decoder transforms.
1917
@@ -59,7 +57,6 @@ def import_torchvision_transforms_v2() -> ModuleType:
5957 return v2
6058
6159
62- @dataclass
6360class Resize (DecoderTransform ):
6461 """Resize the decoded frame to a given size.
6562
@@ -71,20 +68,22 @@ class Resize(DecoderTransform):
7168 the form (height, width).
7269 """
7370
74- size : Sequence [int ]
71+ def __init__ (self , size : Sequence [int ]):
72+ if len (size ) != 2 :
73+ raise ValueError (
74+ "Resize transform must have a (height, width) "
75+ f"pair for the size, got { size } ."
76+ )
77+ self .size = size
7578
7679 def _make_transform_spec (
7780 self , input_dims : Tuple [Optional [int ], Optional [int ]]
7881 ) -> str :
79- # TODO: establish this invariant in the constructor during refactor
80- assert len (self .size ) == 2
8182 return f"resize, { self .size [0 ]} , { self .size [1 ]} "
8283
8384 def _calculate_output_dims (
8485 self , input_dims : Tuple [Optional [int ], Optional [int ]]
8586 ) -> Tuple [Optional [int ], Optional [int ]]:
86- # TODO: establish this invariant in the constructor during refactor
87- assert len (self .size ) == 2
8887 return (self .size [0 ], self .size [1 ])
8988
9089 @classmethod
@@ -111,7 +110,6 @@ def _from_torchvision(cls, tv_resize: nn.Module):
111110 return cls (size = tv_resize .size )
112111
113112
114- @dataclass
115113class RandomCrop (DecoderTransform ):
116114 """Crop the decoded frame to a given size at a random location in the frame.
117115
@@ -128,22 +126,17 @@ class RandomCrop(DecoderTransform):
128126 the form (height, width).
129127 """
130128
131- size : Sequence [int ]
132-
133- # Note that these values are never read by this object or the decoder. We
134- # record them for testing purposes only.
135- _top : Optional [int ] = None
136- _left : Optional [int ] = None
129+ def __init__ (self , size : Sequence [int ]):
130+ if len (size ) != 2 :
131+ raise ValueError (
132+ "RandomCrop transform must have a (height, width) "
133+ f"pair for the size, got { size } ."
134+ )
135+ self .size = size
137136
138137 def _make_transform_spec (
139138 self , input_dims : Tuple [Optional [int ], Optional [int ]]
140139 ) -> str :
141- if len (self .size ) != 2 :
142- raise ValueError (
143- f"RandomCrop's size must be a sequence of length 2, got { self .size } . "
144- "This should never happen, please report a bug."
145- )
146-
147140 height , width = input_dims
148141 if height is None :
149142 raise ValueError (
@@ -176,9 +169,6 @@ def _make_transform_spec(
176169 def _calculate_output_dims (
177170 self , input_dims : Tuple [Optional [int ], Optional [int ]]
178171 ) -> Tuple [Optional [int ], Optional [int ]]:
179- # TODO: establish this invariant in the constructor during refactor
180- assert len (self .size ) == 2
181-
182172 height , width = input_dims
183173 if height is None :
184174 raise ValueError (
0 commit comments