Skip to content

Commit a92d5f0

Browse files
committed
Refactor DecoderTransforms to normal classes
1 parent 817b1f8 commit a92d5f0

File tree

2 files changed

+23
-24
lines changed

2 files changed

+23
-24
lines changed

src/torchcodec/transforms/_decoder_transforms.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,13 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from abc import ABC, abstractmethod
8-
from dataclasses import dataclass
98
from types import ModuleType
109
from typing import Optional, Sequence, Tuple
1110

1211
import torch
1312
from torch import nn
1413

1514

16-
@dataclass
1715
class DecoderTransform(ABC):
1816
"""Base class for all decoder transforms.
1917
@@ -59,7 +57,6 @@ def import_torchvision_transforms_v2() -> ModuleType:
5957
return v2
6058

6159

62-
@dataclass
6360
class Resize(DecoderTransform):
6461
"""Resize the decoded frame to a given size.
6562
@@ -71,20 +68,22 @@ class Resize(DecoderTransform):
7168
the form (height, width).
7269
"""
7370

74-
size: Sequence[int]
71+
def __init__(self, size: Sequence[int]):
72+
if len(size) != 2:
73+
raise ValueError(
74+
"Resize transform must have a (height, width) "
75+
f"pair for the size, got {size}."
76+
)
77+
self.size = size
7578

7679
def _make_transform_spec(
7780
self, input_dims: Tuple[Optional[int], Optional[int]]
7881
) -> str:
79-
# TODO: establish this invariant in the constructor during refactor
80-
assert len(self.size) == 2
8182
return f"resize, {self.size[0]}, {self.size[1]}"
8283

8384
def _calculate_output_dims(
8485
self, input_dims: Tuple[Optional[int], Optional[int]]
8586
) -> Tuple[Optional[int], Optional[int]]:
86-
# TODO: establish this invariant in the constructor during refactor
87-
assert len(self.size) == 2
8887
return (self.size[0], self.size[1])
8988

9089
@classmethod
@@ -111,7 +110,6 @@ def _from_torchvision(cls, tv_resize: nn.Module):
111110
return cls(size=tv_resize.size)
112111

113112

114-
@dataclass
115113
class RandomCrop(DecoderTransform):
116114
"""Crop the decoded frame to a given size at a random location in the frame.
117115
@@ -128,22 +126,17 @@ class RandomCrop(DecoderTransform):
128126
the form (height, width).
129127
"""
130128

131-
size: Sequence[int]
132-
133-
# Note that these values are never read by this object or the decoder. We
134-
# record them for testing purposes only.
135-
_top: Optional[int] = None
136-
_left: Optional[int] = None
129+
def __init__(self, size: Sequence[int]):
130+
if len(size) != 2:
131+
raise ValueError(
132+
"RandomCrop transform must have a (height, width) "
133+
f"pair for the size, got {size}."
134+
)
135+
self.size = size
137136

138137
def _make_transform_spec(
139138
self, input_dims: Tuple[Optional[int], Optional[int]]
140139
) -> str:
141-
if len(self.size) != 2:
142-
raise ValueError(
143-
f"RandomCrop's size must be a sequence of length 2, got {self.size}. "
144-
"This should never happen, please report a bug."
145-
)
146-
147140
height, width = input_dims
148141
if height is None:
149142
raise ValueError(
@@ -176,9 +169,6 @@ def _make_transform_spec(
176169
def _calculate_output_dims(
177170
self, input_dims: Tuple[Optional[int], Optional[int]]
178171
) -> Tuple[Optional[int], Optional[int]]:
179-
# TODO: establish this invariant in the constructor during refactor
180-
assert len(self.size) == 2
181-
182172
height, width = input_dims
183173
if height is None:
184174
raise ValueError(

test/test_transform_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,15 @@ def test_resize_fails(self):
145145
):
146146
VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))])
147147

148+
with pytest.raises(
149+
ValueError,
150+
match=r"must have a \(height, width\) pair for the size",
151+
):
152+
VideoDecoder(
153+
NASA_VIDEO.path,
154+
transforms=[torchcodec.transforms.Resize(size=(100, 100, 100))],
155+
)
156+
148157
@pytest.mark.parametrize(
149158
"height_scaling_factor, width_scaling_factor",
150159
((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.15, 0.75)),

0 commit comments

Comments
 (0)