Skip to content

Commit 1b13e58

Browse files
authored
Refactor decoder transforms (#1081)
1 parent 38fa96c commit 1b13e58

File tree

3 files changed

+24
-18
lines changed

3 files changed

+24
-18
lines changed

docs/source/api_ref_transforms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ For a tutorial, see: TODO_DECODER_TRANSFORMS_TUTORIAL.
1414
:template: dataclass.rst
1515

1616
DecoderTransform
17+
RandomCrop
1718
Resize

src/torchcodec/transforms/_decoder_transforms.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,13 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from abc import ABC, abstractmethod
8-
from dataclasses import dataclass
98
from types import ModuleType
109
from typing import Optional, Sequence, Tuple
1110

1211
import torch
1312
from torch import nn
1413

1514

16-
@dataclass
1715
class DecoderTransform(ABC):
1816
"""Base class for all decoder transforms.
1917
@@ -91,7 +89,6 @@ def import_torchvision_transforms_v2() -> ModuleType:
9189
return v2
9290

9391

94-
@dataclass
9592
class Resize(DecoderTransform):
9693
"""Resize the decoded frame to a given size.
9794
@@ -103,18 +100,20 @@ class Resize(DecoderTransform):
103100
the form (height, width).
104101
"""
105102

106-
size: Sequence[int]
103+
def __init__(self, size: Sequence[int]):
104+
if len(size) != 2:
105+
raise ValueError(
106+
"Resize transform must have a (height, width) "
107+
f"pair for the size, got {size}."
108+
)
109+
self.size = size
107110

108111
def _make_transform_spec(
109112
self, input_dims: Tuple[Optional[int], Optional[int]]
110113
) -> str:
111-
# TODO: establish this invariant in the constructor during refactor
112-
assert len(self.size) == 2
113114
return f"resize, {self.size[0]}, {self.size[1]}"
114115

115116
def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]:
116-
# TODO: establish this invariant in the constructor during refactor
117-
assert len(self.size) == 2
118117
return (self.size[0], self.size[1])
119118

120119
@classmethod
@@ -141,7 +140,6 @@ def _from_torchvision(cls, tv_resize: nn.Module):
141140
return cls(size=tv_resize.size)
142141

143142

144-
@dataclass
145143
class RandomCrop(DecoderTransform):
146144
"""Crop the decoded frame to a given size at a random location in the frame.
147145
@@ -158,17 +156,17 @@ class RandomCrop(DecoderTransform):
158156
the form (height, width).
159157
"""
160158

161-
size: Sequence[int]
159+
def __init__(self, size: Sequence[int]):
160+
if len(size) != 2:
161+
raise ValueError(
162+
"RandomCrop transform must have a (height, width) "
163+
f"pair for the size, got {size}."
164+
)
165+
self.size = size
162166

163167
def _make_transform_spec(
164168
self, input_dims: Tuple[Optional[int], Optional[int]]
165169
) -> str:
166-
if len(self.size) != 2:
167-
raise ValueError(
168-
f"RandomCrop's size must be a sequence of length 2, got {self.size}. "
169-
"This should never happen, please report a bug."
170-
)
171-
172170
height, width = input_dims
173171
if height is None:
174172
raise ValueError(
@@ -196,8 +194,6 @@ def _make_transform_spec(
196194
return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}"
197195

198196
def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]:
199-
# TODO: establish this invariant in the constructor during refactor
200-
assert len(self.size) == 2
201197
return (self.size[0], self.size[1])
202198

203199
@classmethod

test/test_transform_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,15 @@ def test_resize_fails(self):
145145
):
146146
VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))])
147147

148+
with pytest.raises(
149+
ValueError,
150+
match=r"must have a \(height, width\) pair for the size",
151+
):
152+
VideoDecoder(
153+
NASA_VIDEO.path,
154+
transforms=[torchcodec.transforms.Resize(size=(100, 100, 100))],
155+
)
156+
148157
@pytest.mark.parametrize(
149158
"height_scaling_factor, width_scaling_factor",
150159
((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.15, 0.75)),

0 commit comments

Comments
 (0)