
Commit 2fed9c3

More refactor please

1 parent f8844f4

File tree: 3 files changed (+27, -34 lines)

src/torchcodec/decoders/_video_decoder.py

5 additions, 10 deletions

@@ -518,31 +518,26 @@ def _make_transform_specs(
     ] = []
     curr_input_dims = input_dims
     for transform in transforms:
-        if isinstance(transform, DecoderTransform):
-            output_dims = transform._get_output_dims()
-            converted_transforms.append((transform, curr_input_dims))
-        else:
+        if not isinstance(transform, DecoderTransform):
             if not tv_available:
                 raise ValueError(
                     f"The supplied transform, {transform}, is not a TorchCodec "
                     " DecoderTransform. TorchCodec also accepts TorchVision "
                     "v2 transforms, but TorchVision is not installed."
                 )
             elif isinstance(transform, v2.Resize):
-                tc_transform = Resize._from_torchvision(transform)
-                output_dims = tc_transform._get_output_dims()
-                converted_transforms.append((tc_transform, curr_input_dims))
+                transform = Resize._from_torchvision(transform)
             elif isinstance(transform, v2.RandomCrop):
-                tc_transform = RandomCrop._from_torchvision(transform)
-                output_dims = tc_transform._get_output_dims()
-                converted_transforms.append((tc_transform, curr_input_dims))
+                transform = RandomCrop._from_torchvision(transform)
             else:
                 raise ValueError(
                     f"Unsupported transform: {transform}. Transforms must be "
                     "either a TorchCodec DecoderTransform or a TorchVision "
                     "v2 transform."
                 )

+        converted_transforms.append((transform, curr_input_dims))
+        output_dims = transform._get_output_dims()
         curr_input_dims = output_dims if output_dims is not None else curr_input_dims

     return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms])

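The net effect of this hunk: every transform, whether a native DecoderTransform or one converted from TorchVision, now funnels through a single append and dims update at the bottom of the loop body, instead of each branch duplicating those lines. A minimal runnable sketch of the consolidated control flow, using hypothetical stand-in classes rather than the real torchcodec internals:

    # Hypothetical stand-ins for illustration only; the real torchcodec
    # classes carry more fields and validation.
    from typing import List, Optional, Sequence, Tuple


    class DecoderTransform:
        def _get_output_dims(self) -> Optional[Tuple[int, int]]:
            return None  # None means "dims unchanged"

        def _make_transform_spec(self, dims: Tuple[int, int]) -> str:
            return "noop"


    class Resize(DecoderTransform):
        def __init__(self, size: Sequence[int]):
            self.size = (int(size[0]), int(size[1]))

        def _get_output_dims(self) -> Tuple[int, int]:
            return self.size

        def _make_transform_spec(self, dims: Tuple[int, int]) -> str:
            return f"resize, {self.size[0]}, {self.size[1]}"


    def make_transform_specs(transforms, input_dims):
        converted: List[Tuple[DecoderTransform, Tuple[int, int]]] = []
        curr_input_dims = input_dims
        for transform in transforms:
            if not isinstance(transform, DecoderTransform):
                # Conversion from TorchVision transforms would happen here.
                raise ValueError(f"Unsupported transform: {transform}")
            # The refactor's payoff: append once and compute output dims once,
            # regardless of which branch produced (or converted) the transform.
            converted.append((transform, curr_input_dims))
            output_dims = transform._get_output_dims()
            curr_input_dims = output_dims if output_dims is not None else curr_input_dims
        return ";".join(t._make_transform_spec(dims) for t, dims in converted)


    print(make_transform_specs([Resize((100, 200))], (1080, 1920)))  # resize, 100, 200
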
src/torchcodec/transforms/_decoder_transforms.py

0 additions, 8 deletions

@@ -129,11 +129,6 @@ class RandomCrop(DecoderTransform):

     size: Sequence[int]

-    # Note that these values are never read by this object or the decoder. We
-    # record them for testing purposes only.
-    _top: Optional[int] = None
-    _left: Optional[int] = None
-
     def _make_transform_spec(
         self, input_dims: Tuple[Optional[int], Optional[int]]
     ) -> str:
@@ -165,10 +160,7 @@ def _make_transform_spec(
         )

         top = int(torch.randint(0, height - self.size[0] + 1, size=()).item())
-        self._top = top
-
         left = int(torch.randint(0, width - self.size[1] + 1, size=()).item())
-        self._left = left

         return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}"

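With _top and _left gone, the tests' reproducibility rests entirely on seeding: re-seeding PyTorch's RNG before each draw yields the same crop corner. A small sketch of that property, using assumed frame and crop sizes (plain torch, no torchcodec required):

    import torch

    height, width = 270, 480  # assumed frame dims, for illustration
    crop_h, crop_w = 100, 100

    torch.manual_seed(0)
    top_a = int(torch.randint(0, height - crop_h + 1, size=()).item())
    left_a = int(torch.randint(0, width - crop_w + 1, size=()).item())

    torch.manual_seed(0)  # re-seed: the next draws repeat exactly
    top_b = int(torch.randint(0, height - crop_h + 1, size=()).item())
    left_b = int(torch.randint(0, width - crop_w + 1, size=()).item())

    assert (top_a, left_a) == (top_b, left_b)
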
test/test_transform_ops.py

22 additions, 16 deletions

@@ -163,11 +163,15 @@ def test_random_crop_torchvision(

         # We want both kinds of RandomCrop objects to arrive at the same
         # locations to crop, so we need to make sure they get the same random
-        # seed.
+        # seed. It's used in RandomCrop's _make_transform_spec() method, called
+        # by the VideoDecoder.
         torch.manual_seed(seed)
         tc_random_crop = torchcodec.transforms.RandomCrop(size=(height, width))
         decoder_random_crop = VideoDecoder(video.path, transforms=[tc_random_crop])

+        # Resetting manual seed for when TorchCodec's RandomCrop, created from
+        # the TorchVision RandomCrop, is used inside of the VideoDecoder. It
+        # needs to match the call above.
         torch.manual_seed(seed)
         decoder_random_crop_tv = VideoDecoder(
             video.path,
@@ -193,14 +197,11 @@
         expected_shape = (video.get_num_color_channels(), height, width)
         assert frame_random_crop_tv.shape == expected_shape

+        # Resetting manual seed to make sure the invocation of the
+        # TorchVision RandomCrop matches the two calls above.
+        torch.manual_seed(seed)
         frame_full = decoder_full[frame_index]
-        frame_tv = v2.functional.crop(
-            frame_full,
-            top=tc_random_crop._top,
-            left=tc_random_crop._left,
-            height=tc_random_crop.size[0],
-            width=tc_random_crop.size[1],
-        )
+        frame_tv = v2.RandomCrop(size=(height, width))(frame_full)
         assert_frames_equal(frame_random_crop, frame_tv)

     @pytest.mark.parametrize(
@@ -260,18 +261,23 @@ def test_crop_fails(self, error_message, params):
     @pytest.mark.parametrize("seed", [0, 314])
     def test_random_crop_reusable_objects(self, seed):
         torch.manual_seed(seed)
-        random_crop = torchcodec.transforms.RandomCrop(size=(100, 100))
+        random_crop = torchcodec.transforms.RandomCrop(size=(99, 99))

         # Create a spec which causes us to calculate the random crop location.
-        _ = random_crop._make_transform_spec((1000, 1000))
-        first_top = random_crop._top
-        first_left = random_crop._left
+        first_spec = random_crop._make_transform_spec((888, 888))

         # Create a spec again, which should calculate a different random crop
-        # location.
-        _ = random_crop._make_transform_spec((1000, 1000))
-        assert first_top != random_crop._top
-        assert first_left != random_crop._left
+        # location. Despite having the same image size, the specs should be
+        # different because the crop should be at a different location.
+        second_spec = random_crop._make_transform_spec((888, 888))
+        assert first_spec != second_spec
+
+        # Create a spec again, but with a different image size. The specs should
+        # obviously be different, but the original image size should not be in
+        # the spec at all.
+        third_spec = random_crop._make_transform_spec((777, 777))
+        assert third_spec != first_spec
+        assert "888" not in third_spec

     @pytest.mark.parametrize(
         "resize, random_crop",

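Because the crop location now surfaces only in the spec string, the updated test compares whole specs instead of the removed _top/_left attributes. A self-contained sketch mirroring that logic; the RandomCrop below is a stand-in that copies the spec format from the diff above, not the real torchcodec class:

    from dataclasses import dataclass
    from typing import Tuple

    import torch


    @dataclass
    class RandomCrop:
        size: Tuple[int, int]

        def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str:
            height, width = input_dims
            top = int(torch.randint(0, height - self.size[0] + 1, size=()).item())
            left = int(torch.randint(0, width - self.size[1] + 1, size=()).item())
            return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}"


    torch.manual_seed(0)
    random_crop = RandomCrop(size=(99, 99))
    first_spec = random_crop._make_transform_spec((888, 888))
    second_spec = random_crop._make_transform_spec((888, 888))
    # Fresh RNG draws on each call: the locations, and thus the specs,
    # differ with overwhelming probability.
    assert first_spec != second_spec
    # The input dims feed the randint bounds but never appear in the spec.
    assert "888" not in random_crop._make_transform_spec((777, 777))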