33import pathlib
44import sys
55
6- from dataclasses import dataclass
7- from typing import Dict , Optional
6+ from dataclasses import dataclass , field
7+ from typing import Dict , List , Optional
88
99import numpy as np
1010import pytest
@@ -203,6 +203,8 @@ class TestVideoStreamInfo:
203203
204204@dataclass
205205class TestVideo (TestContainerFile ):
206+ """Base class for the *video* streams of a video container"""
207+
206208 stream_infos : Dict [int , TestVideoStreamInfo ]
207209
208210 def get_frame_data_by_index (
@@ -318,13 +320,16 @@ class TestAudioStreamInfo:
318320 sample_rate : int
319321 num_channels : int
320322 duration_seconds : float
323+ num_frames : int
321324
322325
323326@dataclass
324327class TestAudio (TestContainerFile ):
328+ """Base class for the *audio* streams of a container (potentially a video),
329+ or a pure audio file"""
325330
326331 stream_infos : Dict [int , TestAudioStreamInfo ]
327- _reference_frames : tuple [ torch .Tensor ] = tuple ( )
332+ _reference_frames : Dict [ int , List [ torch .Tensor ]] = field ( default_factory = dict )
328333
329334 # Storing each individual frame is too expensive for audio, because there's
330335 # a massive overhead in the binary format saved by pytorch. Saving all the
@@ -333,32 +338,22 @@ class TestAudio(TestContainerFile):
333338 # So we store the reference frames in a single file, and load/cache those
334339 # when the TestAudio instance is created.
335340 def __post_init__ (self ):
336- # We hard-code the default stream index, see TODO below.
337- file_path = _get_file_path (
338- f"{ self .filename } .stream{ self .default_stream_index } .all_frames.pt"
339- )
340- if not file_path .exists ():
341- return # TODO-audio
342- t = torch .load (file_path , weights_only = True )
341+ for stream_index in self .stream_infos :
342+ file_path = _get_file_path (
343+ f"{ self .filename } .stream{ stream_index } .all_frames.pt"
344+ )
343345
344- # These are hard-coded value assuming stream 4 of nasa_13013.mp4. Each
345- # of the 204 frames contains 1024 samples.
346- # TODO make this more generic
347- assert t .shape == (2 , 204 * 1024 )
348- self ._reference_frames = torch .chunk (t , chunks = 204 , dim = 1 )
346+ self ._reference_frames [stream_index ] = torch .load (
347+ file_path , weights_only = True
348+ )
349349
350350 def get_frame_data_by_index (
351351 self , idx : int , * , stream_index : Optional [int ] = None
352352 ) -> torch .Tensor :
353- if stream_index is not None and stream_index != self .default_stream_index :
354- # TODO address this, the fix should be to let _reference_frames be a
355- # dict[tuple[torch.Tensor]] where keys are stream indices, and load
356- # all of those indices in __post_init__.
357- raise ValueError (
358- "Can only use default stream index with TestAudio for now."
359- )
353+ if stream_index is None :
354+ stream_index = self .default_stream_index
360355
361- return self ._reference_frames [idx ]
356+ return self ._reference_frames [stream_index ][ idx ]
362357
363358 def pts_to_frame_index (self , pts_seconds : float ) -> int :
364359 # These are hard-coded value assuming stream 4 of nasa_13013.mp4. Each
@@ -379,10 +374,9 @@ def num_channels(self) -> int:
379374 def duration_seconds (self ) -> float :
380375 return self .stream_infos [self .default_stream_index ].duration_seconds
381376
382- # TODO: this shouldn't be named chw. Also values are hard-coded
383377 @property
384- def empty_chw_tensor (self ) -> torch . Tensor :
385- return torch . empty ([ 0 , 2 , 1024 ], dtype = torch . float32 )
378+ def num_frames (self ) -> int :
379+ return self . stream_infos [ self . default_stream_index ]. num_frames
386380
387381
388382NASA_AUDIO_MP3 = TestAudio (
@@ -391,7 +385,7 @@ def empty_chw_tensor(self) -> torch.Tensor:
391385 frames = {}, # TODO
392386 stream_infos = {
393387 0 : TestAudioStreamInfo (
394- sample_rate = 8_000 , num_channels = 2 , duration_seconds = 13.248
388+ sample_rate = 8_000 , num_channels = 2 , duration_seconds = 13.248 , num_frames = 183
395389 )
396390 },
397391)
@@ -402,7 +396,7 @@ def empty_chw_tensor(self) -> torch.Tensor:
402396 frames = {}, # TODO
403397 stream_infos = {
404398 4 : TestAudioStreamInfo (
405- sample_rate = 16_000 , num_channels = 2 , duration_seconds = 13.056
399+ sample_rate = 16_000 , num_channels = 2 , duration_seconds = 13.056 , num_frames = 204
406400 )
407401 },
408402)
0 commit comments