Commit 4e412b7

Refactor _make_transform_specs() location (#1097)
1 parent c3356e5 commit 4e412b7

File tree

2 files changed (+101, -100 lines)


src/torchcodec/decoders/_video_decoder.py

Lines changed: 2 additions & 99 deletions
@@ -19,7 +19,8 @@
     create_decoder,
     ERROR_REPORTING_INSTRUCTIONS,
 )
-from torchcodec.transforms import CenterCrop, DecoderTransform, RandomCrop, Resize
+from torchcodec.transforms import DecoderTransform
+from torchcodec.transforms._decoder_transforms import _make_transform_specs
 
 
 class VideoDecoder:
@@ -451,104 +452,6 @@ def _get_and_validate_stream_metadata(
     )
 
 
-def _make_transform_specs(
-    transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
-    input_dims: Tuple[Optional[int], Optional[int]],
-) -> str:
-    """Given a sequence of transforms, turn those into the specification string
-    the core API expects.
-
-    Args:
-        transforms: Optional sequence of transform objects. The objects can be
-            one of two types:
-                1. torchcodec.transforms.DecoderTransform
-                2. torchvision.transforms.v2.Transform, but our type annotation
-                   only mentions its base, nn.Module. We don't want to take a
-                   hard dependency on TorchVision.
-        input_dims: Optional (height, width) pair. Note that only some
-            transforms need to know the dimensions. If the user provides
-            transforms that don't need to know the dimensions, and that metadata
-            is missing, everything should still work. That means we assert their
-            existence as late as possible.
-
-    Returns:
-        String of transforms in the format the core API expects: transform
-        specifications separated by semicolons.
-    """
-    if transforms is None:
-        return ""
-
-    try:
-        from torchvision.transforms import v2
-
-        tv_available = True
-    except ImportError:
-        tv_available = False
-
-    # The following loop accomplishes two tasks:
-    #
-    #   1. Converts the transform to a DecoderTransform, if necessary. We
-    #      accept TorchVision transform objects and they must be converted
-    #      to their matching DecoderTransform.
-    #   2. Calculates what the input dimensions are to each transform.
-    #
-    # The order in our transforms list is semantically meaningful, as we
-    # actually have a pipeline where the output of one transform is the input to
-    # the next. For example, if we have the transforms list [A, B, C, D], then
-    # we should understand that as:
-    #
-    #   A -> B -> C -> D
-    #
-    # Where the frame produced by A is the input to B, the frame produced by B
-    # is the input to C, etc. This particularly matters for frame dimensions.
-    # Transforms can both:
-    #
-    #   1. Produce frames with arbitrary dimensions.
-    #   2. Rely on their input frame's dimensions to calculate ahead-of-time
-    #      what their runtime behavior will be.
-    #
-    # The consequence of the above facts is that we need to statically track
-    # frame dimensions in the pipeline while we pre-process it. The input
-    # frame's dimensions to A, our first transform, is always what we know from
-    # our metadata. For each transform, we always calculate its output
-    # dimensions from its input dimensions. We store these with the converted
-    # transform, to be all used together when we generate the specs.
-    converted_transforms: list[
-        Tuple[
-            DecoderTransform,
-            # A (height, width) pair where the values may be missing.
-            Tuple[Optional[int], Optional[int]],
-        ]
-    ] = []
-    curr_input_dims = input_dims
-    for transform in transforms:
-        if not isinstance(transform, DecoderTransform):
-            if not tv_available:
-                raise ValueError(
-                    f"The supplied transform, {transform}, is not a TorchCodec "
-                    "DecoderTransform. TorchCodec also accepts TorchVision "
-                    "v2 transforms, but TorchVision is not installed."
-                )
-            elif isinstance(transform, v2.Resize):
-                transform = Resize._from_torchvision(transform)
-            elif isinstance(transform, v2.CenterCrop):
-                transform = CenterCrop._from_torchvision(transform)
-            elif isinstance(transform, v2.RandomCrop):
-                transform = RandomCrop._from_torchvision(transform)
-            else:
-                raise ValueError(
-                    f"Unsupported transform: {transform}. Transforms must be "
-                    "either a TorchCodec DecoderTransform or a TorchVision "
-                    "v2 transform."
-                )
-
-        converted_transforms.append((transform, curr_input_dims))
-        output_dims = transform._get_output_dims()
-        curr_input_dims = output_dims if output_dims is not None else curr_input_dims
-
-    return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms])
-
-
 def _read_custom_frame_mappings(
     custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
 ) -> tuple[Tensor, Tensor, Tensor]:
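With the helper gone from this module, the decoder only keeps the DecoderTransform import and pulls _make_transform_specs from its new home. A minimal sketch of what a call site inside the decoder could look like after this change; the surrounding constructor, the transforms argument, and the metadata.height / metadata.width fields are assumptions for illustration, not part of this diff:

from torchcodec.transforms._decoder_transforms import _make_transform_specs

# Hypothetical call site: turn the user-supplied transforms into the
# semicolon-separated spec string that the core decoder API consumes.
transform_specs = _make_transform_specs(
    transforms,  # e.g. a list of DecoderTransform / TorchVision v2 objects, or None
    input_dims=(metadata.height, metadata.width),  # assumed metadata fields
)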

src/torchcodec/transforms/_decoder_transforms.py

Lines changed: 99 additions & 1 deletion
@@ -6,7 +6,7 @@
 
 from abc import ABC, abstractmethod
 from types import ModuleType
-from typing import Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn
@@ -282,3 +282,101 @@ def _from_torchvision(
             )
 
         return cls(size=tv_random_crop.size)
+
+
+def _make_transform_specs(
+    transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
+    input_dims: Tuple[Optional[int], Optional[int]],
+) -> str:
+    """Given a sequence of transforms, turn those into the specification string
+    the core API expects.
+
+    Args:
+        transforms: Optional sequence of transform objects. The objects can be
+            one of two types:
+                1. torchcodec.transforms.DecoderTransform
+                2. torchvision.transforms.v2.Transform, but our type annotation
+                   only mentions its base, nn.Module. We don't want to take a
+                   hard dependency on TorchVision.
+        input_dims: Optional (height, width) pair. Note that only some
+            transforms need to know the dimensions. If the user provides
+            transforms that don't need to know the dimensions, and that metadata
+            is missing, everything should still work. That means we assert their
+            existence as late as possible.
+
+    Returns:
+        String of transforms in the format the core API expects: transform
+        specifications separated by semicolons.
+    """
+    if transforms is None:
+        return ""
+
+    try:
+        from torchvision.transforms import v2
+
+        tv_available = True
+    except ImportError:
+        tv_available = False
+
+    # The following loop accomplishes two tasks:
+    #
+    #   1. Converts the transform to a DecoderTransform, if necessary. We
+    #      accept TorchVision transform objects and they must be converted
+    #      to their matching DecoderTransform.
+    #   2. Calculates what the input dimensions are to each transform.
+    #
+    # The order in our transforms list is semantically meaningful, as we
+    # actually have a pipeline where the output of one transform is the input to
+    # the next. For example, if we have the transforms list [A, B, C, D], then
+    # we should understand that as:
+    #
+    #   A -> B -> C -> D
+    #
+    # Where the frame produced by A is the input to B, the frame produced by B
+    # is the input to C, etc. This particularly matters for frame dimensions.
+    # Transforms can both:
+    #
+    #   1. Produce frames with arbitrary dimensions.
+    #   2. Rely on their input frame's dimensions to calculate ahead-of-time
+    #      what their runtime behavior will be.
+    #
+    # The consequence of the above facts is that we need to statically track
+    # frame dimensions in the pipeline while we pre-process it. The input
+    # frame's dimensions to A, our first transform, is always what we know from
+    # our metadata. For each transform, we always calculate its output
+    # dimensions from its input dimensions. We store these with the converted
+    # transform, to be all used together when we generate the specs.
+    converted_transforms: list[
+        Tuple[
+            DecoderTransform,
+            # A (height, width) pair where the values may be missing.
+            Tuple[Optional[int], Optional[int]],
+        ]
+    ] = []
+    curr_input_dims = input_dims
+    for transform in transforms:
+        if not isinstance(transform, DecoderTransform):
+            if not tv_available:
+                raise ValueError(
+                    f"The supplied transform, {transform}, is not a TorchCodec "
+                    "DecoderTransform. TorchCodec also accepts TorchVision "
+                    "v2 transforms, but TorchVision is not installed."
+                )
+            elif isinstance(transform, v2.Resize):
+                transform = Resize._from_torchvision(transform)
+            elif isinstance(transform, v2.CenterCrop):
+                transform = CenterCrop._from_torchvision(transform)
+            elif isinstance(transform, v2.RandomCrop):
+                transform = RandomCrop._from_torchvision(transform)
+            else:
+                raise ValueError(
+                    f"Unsupported transform: {transform}. Transforms must be "
+                    "either a TorchCodec DecoderTransform or a TorchVision "
+                    "v2 transform."
+                )
+
+        converted_transforms.append((transform, curr_input_dims))
+        output_dims = transform._get_output_dims()
+        curr_input_dims = output_dims if output_dims is not None else curr_input_dims
+
+    return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms])
