Skip to content

Commit 915631d

Browse files
committed
Let core ops return 3D tensors
1 parent 0b17d99 commit 915631d

File tree

3 files changed

+8
-62
lines changed

3 files changed

+8
-62
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,10 +1402,6 @@ VideoDecoder::FrameBatchOutput ::FrameBatchOutput(
14021402
.dtype(torch::kFloat32)
14031403
.layout(torch::kStrided)
14041404
.device(torch::kCPU);
1405-
// Note that we allocate a 3D shape. We'll eventually return a 2D shape
1406-
// (numChannels, numSamples * numFrames) where each frame is concatenated
1407-
// along the 2nd dimension. Allocating tensors this way makes it much easier
1408-
// to use the same code paths for audio and video for batch APIs.
14091405
data = torch::empty({numFrames, numChannels, numSamples}, tensorOptions);
14101406
}
14111407

@@ -1438,15 +1434,7 @@ torch::Tensor allocateEmptyHWCTensor(
14381434
torch::Tensor VideoDecoder::maybePermuteHWC2CHW(torch::Tensor& hwcTensor) {
14391435
if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_AUDIO) {
14401436
// TODO_CODE_QUALITY: Do something cleaner for handling audio
1441-
if (hwcTensor.dim() == 2) {
1442-
return hwcTensor;
1443-
}
1444-
auto shape = hwcTensor.sizes();
1445-
auto numFrames = shape[0];
1446-
auto numChannels = shape[1];
1447-
auto numSamples = shape[2];
1448-
return hwcTensor.permute({1, 0, 2}).reshape(
1449-
{numChannels, numSamples * numFrames});
1437+
return hwcTensor;
14501438
}
14511439
if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder ==
14521440
"NHWC") {

test/decoders/test_video_decoder_ops.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838

3939
from ..utils import (
4040
assert_frames_equal,
41-
contiguous_to_stacked_audio_frames,
4241
cpu_and_cuda,
4342
NASA_AUDIO,
4443
NASA_VIDEO,
@@ -237,10 +236,6 @@ def test_get_frames_at_indices(self, test_ref, device):
237236
frames0and180, *_ = get_frames_at_indices(decoder, frame_indices=[0, 180])
238237
reference_frame0 = test_ref.get_frame_data_by_index(0)
239238
reference_frame180 = test_ref.get_frame_data_by_index(180)
240-
if test_ref is NASA_AUDIO:
241-
frames0and180 = contiguous_to_stacked_audio_frames(
242-
frames0and180, num_frames=2
243-
)
244239

245240
assert_frames_equal(frames0and180[0], reference_frame0.to(device))
246241
assert_frames_equal(frames0and180[1], reference_frame180.to(device))
@@ -265,10 +260,6 @@ def test_get_frames_at_indices_unsorted_indices(self, test_ref, device):
265260
decoder,
266261
frame_indices=frame_indices,
267262
)
268-
if test_ref is NASA_AUDIO:
269-
frames = contiguous_to_stacked_audio_frames(
270-
frames, num_frames=len(frame_indices)
271-
)
272263
for frame, expected_frame in zip(frames, expected_frames):
273264
assert_frames_equal(frame, expected_frame)
274265

test/utils.py

Lines changed: 7 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,6 @@ def cpu_and_cuda():
2323
return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
2424

2525

26-
def contiguous_to_stacked_audio_frames(frames, *, num_frames):
27-
# (num_channels, num_samples * num_frames) --> (num_frames, num_channels, num_samples)
28-
# Shape conversion util for audio frame. This makes it easier to index
29-
# individual frames so we can use the same code paths when checking equality
30-
# of video frames and audio frames.
31-
num_channels = frames.shape[0]
32-
return frames.reshape(num_channels, num_frames, -1).permute(1, 0, 2)
33-
34-
3526
# For use with decoded data frames. On CPU Linux, we expect exact, bit-for-bit
3627
# equality. On CUDA Linux, we expect a small tolerance.
3728
# On other platforms (e.g. MacOS), we also allow a small tolerance. FFmpeg does
@@ -128,7 +119,11 @@ def get_frame_data_by_range(
128119
*,
129120
stream_index: Optional[int] = None,
130121
) -> torch.Tensor:
131-
raise NotImplementedError("Override in child classes")
122+
tensors = [
123+
self.get_frame_data_by_index(i, stream_index=stream_index)
124+
for i in range(start, stop, step)
125+
]
126+
return torch.stack(tensors)
132127

133128
def get_pts_seconds_by_range(
134129
self,
@@ -202,20 +197,6 @@ def get_frame_data_by_index(
202197
)
203198
return torch.load(file_path, weights_only=True).permute(2, 0, 1)
204199

205-
def get_frame_data_by_range(
206-
self,
207-
start: int,
208-
stop: int,
209-
step: int = 1,
210-
*,
211-
stream_index: Optional[int] = None,
212-
) -> torch.Tensor:
213-
tensors = [
214-
self.get_frame_data_by_index(i, stream_index=stream_index)
215-
for i in range(start, stop, step)
216-
]
217-
return torch.stack(tensors)
218-
219200
@property
220201
def width(self) -> int:
221202
return self.stream_infos[self.default_stream_index].width
@@ -356,24 +337,10 @@ def get_frame_data_by_index(
356337

357338
return self._reference_frames[idx]
358339

359-
def get_frame_data_by_range(
360-
self,
361-
start: int,
362-
stop: int,
363-
step: int = 1,
364-
*,
365-
stream_index: Optional[int] = None,
366-
) -> torch.Tensor:
367-
tensors = [
368-
self.get_frame_data_by_index(i, stream_index=stream_index)
369-
for i in range(start, stop, step)
370-
]
371-
return torch.cat(tensors, dim=1)
372-
373-
# TODO: this shouldn't be named chw
340+
# TODO: this shouldn't be named chw. Also values are hard-coded
374341
@property
375342
def empty_chw_tensor(self) -> torch.Tensor:
376-
return torch.empty([2, 0], dtype=torch.float32)
343+
return torch.empty([0, 2, 1024], dtype=torch.float32)
377344

378345

379346
NASA_AUDIO = TestAudio(

0 commit comments

Comments
 (0)