Skip to content

Commit 58e5277

Browse files
committed
Add sample_format to audio metadata
1 parent 1fd20b2 commit 58e5277

File tree

7 files changed

+30
-6
lines changed

7 files changed

+30
-6
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -170,6 +170,9 @@ void VideoDecoder::initializeDecoder() {
170170
}
171171
containerMetadata_.numVideoStreams++;
172172
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
173+
AVSampleFormat format =
174+
static_cast<AVSampleFormat>(avStream->codecpar->format);
175+
streamMetadata.sampleFormat = av_get_sample_fmt_name(format);
173176
containerMetadata_.numAudioStreams++;
174177
}
175178

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@ class VideoDecoder {
8181
// Audio-only fields
8282
std::optional<int64_t> sampleRate;
8383
std::optional<int64_t> numChannels;
84+
std::optional<std::string> sampleFormat;
8485
};
8586

8687
struct ContainerMetadata {

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -495,12 +495,15 @@ std::string get_stream_json_metadata(
495495
if (streamMetadata.numChannels.has_value()) {
496496
map["numChannels"] = std::to_string(*streamMetadata.numChannels);
497497
}
498+
if (streamMetadata.sampleFormat.has_value()) {
499+
map["sampleFormat"] = quoteValue(streamMetadata.sampleFormat.value());
500+
}
498501
if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) {
499-
map["mediaType"] = "\"video\"";
502+
map["mediaType"] = quoteValue("video");
500503
} else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) {
501-
map["mediaType"] = "\"audio\"";
504+
map["mediaType"] = quoteValue("audio");
502505
} else {
503-
map["mediaType"] = "\"other\"";
506+
map["mediaType"] = quoteValue("other");
504507
}
505508
return mapToJson(map);
506509
}

src/torchcodec/decoders/_core/_metadata.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -161,9 +161,9 @@ def __repr__(self):
161161
class AudioStreamMetadata(StreamMetadata):
162162
"""Metadata of a single audio stream."""
163163

164-
# TODO-AUDIO Add sample format field
165164
sample_rate: Optional[int]
166165
num_channels: Optional[int]
166+
sample_format: Optional[str]
167167

168168
def __repr__(self):
169169
return super().__repr__()
@@ -240,6 +240,7 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
240240
AudioStreamMetadata(
241241
sample_rate=stream_dict.get("sampleRate"),
242242
num_channels=stream_dict.get("numChannels"),
243+
sample_format=stream_dict.get("sampleFormat"),
243244
**common_meta,
244245
)
245246
)

test/decoders/test_decoders.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -955,6 +955,7 @@ def test_metadata(self, asset):
955955
)
956956
assert decoder.metadata.sample_rate == asset.sample_rate
957957
assert decoder.metadata.num_channels == asset.num_channels
958+
assert decoder.metadata.sample_format == asset.sample_format
958959

959960
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
960961
def test_error(self, asset):

test/decoders/test_metadata.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,7 @@ def test_get_metadata(metadata_getter):
9090
)
9191
assert best_audio_stream_metadata.bit_rate == 128837
9292
assert best_audio_stream_metadata.codec == "aac"
93+
assert best_audio_stream_metadata.sample_format == "fltp"
9394

9495

9596
@pytest.mark.parametrize(
@@ -109,6 +110,7 @@ def test_get_metadata_audio_file(metadata_getter):
109110
)
110111
assert best_audio_stream_metadata.bit_rate == 64000
111112
assert best_audio_stream_metadata.codec == "mp3"
113+
assert best_audio_stream_metadata.sample_format == "fltp"
112114

113115

114116
@pytest.mark.parametrize(

test/utils.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -109,6 +109,7 @@ class TestAudioStreamInfo:
109109
num_channels: int
110110
duration_seconds: float
111111
num_frames: int
112+
sample_format: str
112113

113114

114115
@dataclass
@@ -404,14 +405,22 @@ def duration_seconds(self) -> float:
404405
def num_frames(self) -> int:
405406
return self.stream_infos[self.default_stream_index].num_frames
406407

408+
@property
409+
def sample_format(self) -> str:
410+
return self.stream_infos[self.default_stream_index].sample_format
411+
407412

408413
NASA_AUDIO_MP3 = TestAudio(
409414
filename="nasa_13013.mp4.audio.mp3",
410415
default_stream_index=0,
411416
frames={}, # Automatically loaded from json file
412417
stream_infos={
413418
0: TestAudioStreamInfo(
414-
sample_rate=8_000, num_channels=2, duration_seconds=13.248, num_frames=183
419+
sample_rate=8_000,
420+
num_channels=2,
421+
duration_seconds=13.248,
422+
num_frames=183,
423+
sample_format="fltp",
415424
)
416425
},
417426
)
@@ -422,7 +431,11 @@ def num_frames(self) -> int:
422431
frames={}, # Automatically loaded from json file
423432
stream_infos={
424433
4: TestAudioStreamInfo(
425-
sample_rate=16_000, num_channels=2, duration_seconds=13.056, num_frames=204
434+
sample_rate=16_000,
435+
num_channels=2,
436+
duration_seconds=13.056,
437+
num_frames=204,
438+
sample_format="fltp",
426439
)
427440
},
428441
)

0 commit comments

Comments
 (0)