Skip to content

Commit 1bd0c85

Browse files
authored
Fix implicit conversion of int64 to int (#856)
1 parent ef21ee8 commit 1bd0c85

File tree

5 files changed

+80
-29
lines changed

5 files changed

+80
-29
lines changed

src/torchcodec/_core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ function(make_torchcodec_libraries
9393
CpuDeviceInterface.cpp
9494
SingleStreamDecoder.cpp
9595
Encoder.cpp
96+
ValidationUtils.cpp
9697
)
9798

9899
if(ENABLE_CUDA)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "src/torchcodec/_core/ValidationUtils.h"
8+
#include <limits>
9+
#include "c10/util/Exception.h"
10+
11+
namespace facebook::torchcodec {
12+
13+
int validateInt64ToInt(int64_t value, const std::string& parameterName) {
14+
TORCH_CHECK(
15+
value >= std::numeric_limits<int>::min() &&
16+
value <= std::numeric_limits<int>::max(),
17+
parameterName,
18+
"=",
19+
value,
20+
" is out of range for int type.");
21+
22+
return static_cast<int>(value);
23+
}
24+
25+
std::optional<int> validateOptionalInt64ToInt(
26+
const std::optional<int64_t>& value,
27+
const std::string& parameterName) {
28+
if (value.has_value()) {
29+
return validateInt64ToInt(value.value(), parameterName);
30+
} else {
31+
return std::nullopt;
32+
}
33+
}
34+
35+
} // namespace facebook::torchcodec
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <cstdint>
10+
#include <optional>
11+
#include <string>
12+
13+
namespace facebook::torchcodec {
14+
15+
int validateInt64ToInt(int64_t value, const std::string& parameterName);
16+
17+
std::optional<int> validateOptionalInt64ToInt(
18+
const std::optional<int64_t>& value,
19+
const std::string& parameterName);
20+
21+
} // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "src/torchcodec/_core/AVIOTensorContext.h"
1414
#include "src/torchcodec/_core/Encoder.h"
1515
#include "src/torchcodec/_core/SingleStreamDecoder.h"
16+
#include "src/torchcodec/_core/ValidationUtils.h"
1617

1718
namespace facebook::torchcodec {
1819

@@ -164,16 +165,6 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
164165
return ss.str();
165166
}
166167

167-
int validateSampleRate(int64_t sampleRate) {
168-
TORCH_CHECK(
169-
sampleRate <= std::numeric_limits<int>::max(),
170-
"sample_rate=",
171-
sampleRate,
172-
" is too large to be cast to an int.");
173-
174-
return static_cast<int>(sampleRate);
175-
}
176-
177168
} // namespace
178169

179170
// ==============================
@@ -413,14 +404,17 @@ void encode_audio_to_file(
413404
std::optional<int64_t> bit_rate = std::nullopt,
414405
std::optional<int64_t> num_channels = std::nullopt,
415406
std::optional<int64_t> desired_sample_rate = std::nullopt) {
416-
// TODO Fix implicit int conversion:
417-
// https://github.com/pytorch/torchcodec/issues/679
418407
AudioStreamOptions audioStreamOptions;
419-
audioStreamOptions.bitRate = bit_rate;
420-
audioStreamOptions.numChannels = num_channels;
421-
audioStreamOptions.sampleRate = desired_sample_rate;
408+
audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate");
409+
audioStreamOptions.numChannels =
410+
validateOptionalInt64ToInt(num_channels, "num_channels");
411+
audioStreamOptions.sampleRate =
412+
validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate");
422413
AudioEncoder(
423-
samples, validateSampleRate(sample_rate), file_name, audioStreamOptions)
414+
samples,
415+
validateInt64ToInt(sample_rate, "sample_rate"),
416+
file_name,
417+
audioStreamOptions)
424418
.encode();
425419
}
426420

@@ -432,15 +426,15 @@ at::Tensor encode_audio_to_tensor(
432426
std::optional<int64_t> num_channels = std::nullopt,
433427
std::optional<int64_t> desired_sample_rate = std::nullopt) {
434428
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
435-
// TODO Fix implicit int conversion:
436-
// https://github.com/pytorch/torchcodec/issues/679
437429
AudioStreamOptions audioStreamOptions;
438-
audioStreamOptions.bitRate = bit_rate;
439-
audioStreamOptions.numChannels = num_channels;
440-
audioStreamOptions.sampleRate = desired_sample_rate;
430+
audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate");
431+
audioStreamOptions.numChannels =
432+
validateOptionalInt64ToInt(num_channels, "num_channels");
433+
audioStreamOptions.sampleRate =
434+
validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate");
441435
return AudioEncoder(
442436
samples,
443-
validateSampleRate(sample_rate),
437+
validateInt64ToInt(sample_rate, "sample_rate"),
444438
format,
445439
std::move(avioContextHolder),
446440
audioStreamOptions)

src/torchcodec/_core/pybind_ops.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "src/torchcodec/_core/Encoder.h"
1414
#include "src/torchcodec/_core/SingleStreamDecoder.h"
1515
#include "src/torchcodec/_core/StreamOptions.h"
16+
#include "src/torchcodec/_core/ValidationUtils.h"
1617

1718
namespace py = pybind11;
1819

@@ -55,20 +56,19 @@ void encode_audio_to_file_like(
5556
auto samples = torch::from_blob(
5657
reinterpret_cast<void*>(data_ptr), shape, tensor_options);
5758

58-
// TODO Fix implicit int conversion:
59-
// https://github.com/pytorch/torchcodec/issues/679
60-
// same for sample_rate parameter below
6159
AudioStreamOptions audioStreamOptions;
62-
audioStreamOptions.bitRate = bit_rate;
63-
audioStreamOptions.numChannels = num_channels;
64-
audioStreamOptions.sampleRate = desired_sample_rate;
60+
audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate");
61+
audioStreamOptions.numChannels =
62+
validateOptionalInt64ToInt(num_channels, "num_channels");
63+
audioStreamOptions.sampleRate =
64+
validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate");
6565

6666
auto avioContextHolder =
6767
std::make_unique<AVIOFileLikeContext>(file_like, /*isForWriting=*/true);
6868

6969
AudioEncoder encoder(
7070
samples,
71-
static_cast<int>(sample_rate),
71+
validateInt64ToInt(sample_rate, "sample_rate"),
7272
format,
7373
std::move(avioContextHolder),
7474
audioStreamOptions);

0 commit comments

Comments
 (0)