Skip to content

Commit b3fc714

Browse files
committed
Add contiguity check
1 parent 9b6d9ee commit b3fc714

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

src/torchcodec/_core/ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import ctypes
87
import io
98
import json
109
import warnings
@@ -174,7 +173,6 @@ def encode_audio_to_file_like(
174173
"""
175174
assert _pybind_ops is not None
176175

177-
# Enforce float32 dtype requirement
178176
if samples.dtype != torch.float32:
179177
raise ValueError(f"samples must have dtype torch.float32, got {samples.dtype}")
180178

test/test_encoders.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -343,31 +343,45 @@ def test_encode_to_tensor_long_output(self):
343343

344344
torch.testing.assert_close(self.decode(encoded_tensor).data, samples)
345345

346-
def test_contiguity(self):
346+
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
347+
def test_contiguity(self, method, tmp_path):
347348
# Ensure that 2 waveforms with the same values are encoded in the same
348349
# way, regardless of their memory layout. Here we encode 2 equal
349350
# waveforms, one is row-aligned while the other is column-aligned.
350-
# TODO: Ideally we'd be testing all encoding methods here
351351

352352
num_samples = 10_000 # per channel
353353
contiguous_samples = torch.rand(2, num_samples).contiguous()
354354
assert contiguous_samples.stride() == (num_samples, 1)
355355

356-
params = dict(format="flac", bit_rate=44_000)
357-
encoded_from_contiguous = AudioEncoder(
358-
contiguous_samples, sample_rate=16_000
359-
).to_tensor(**params)
360-
361356
non_contiguous_samples = contiguous_samples.T.contiguous().T
362357
assert non_contiguous_samples.stride() == (1, 2)
363358

364359
torch.testing.assert_close(
365360
contiguous_samples, non_contiguous_samples, rtol=0, atol=0
366361
)
367362

368-
encoded_from_non_contiguous = AudioEncoder(
369-
non_contiguous_samples, sample_rate=16_000
370-
).to_tensor(**params)
363+
def encode_to_tensor(samples):
364+
params = dict(bit_rate=44_000)
365+
if method == "to_file":
366+
dest = str(tmp_path / "output.flac")
367+
AudioEncoder(samples, sample_rate=16_000).to_file(dest=dest, **params)
368+
with open(dest, "rb") as f:
369+
return torch.frombuffer(f.read(), dtype=torch.uint8)
370+
elif method == "to_tensor":
371+
return AudioEncoder(samples, sample_rate=16_000).to_tensor(
372+
format="flac", **params
373+
)
374+
elif method == "to_file_like":
375+
file_like = io.BytesIO()
376+
AudioEncoder(samples, sample_rate=16_000).to_file_like(
377+
file_like, format="flac", **params
378+
)
379+
return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8)
380+
else:
381+
raise ValueError(f"Unknown method: {method}")
382+
383+
encoded_from_contiguous = encode_to_tensor(contiguous_samples)
384+
encoded_from_non_contiguous = encode_to_tensor(non_contiguous_samples)
371385

372386
torch.testing.assert_close(
373387
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0

0 commit comments

Comments
 (0)