Skip to content

Commit 49e4187

Browse files
committed
Ensure correct encoding for non-contiguous WF
1 parent 99d21db commit 49e4187

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@ torch::Tensor validateWf(torch::Tensor wf) {
1313
wf.dtype() == torch::kFloat32,
1414
"waveform must have float32 dtype, got ",
1515
wf.dtype());
16-
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
17-
// planar (fltp).
1816
TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());
19-
return wf;
17+
return wf.contiguous();
2018
}
2119

2220
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {

test/test_ops.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,6 +1284,35 @@ def test_encode_to_tensor_long_output(self):
12841284

12851285
torch.testing.assert_close(self.decode(encoded_tensor), samples)
12861286

1287+
def test_contiguity(self):
1288+
num_samples = 10_000 # per channel
1289+
contiguous_samples = torch.rand(2, num_samples).contiguous()
1290+
assert contiguous_samples.stride() == (num_samples, 1)
1291+
1292+
encoded_from_contiguous = encode_audio_to_tensor(
1293+
wf=contiguous_samples,
1294+
sample_rate=16_000,
1295+
format="flac",
1296+
bit_rate=44_000,
1297+
)
1298+
non_contiguous_samples = contiguous_samples.T.contiguous().T
1299+
assert non_contiguous_samples.stride() == (1, 2)
1300+
1301+
torch.testing.assert_close(
1302+
contiguous_samples, non_contiguous_samples, rtol=0, atol=0
1303+
)
1304+
1305+
encoded_from_non_contiguous = encode_audio_to_tensor(
1306+
wf=non_contiguous_samples,
1307+
sample_rate=16_000,
1308+
format="flac",
1309+
bit_rate=44_000,
1310+
)
1311+
1312+
torch.testing.assert_close(
1313+
encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
1314+
)
1315+
12871316

12881317
if __name__ == "__main__":
12891318
pytest.main()

0 commit comments

Comments
 (0)