Skip to content

Commit 3beb378

Browse files
mikaylagawarecki authored and pytorchmergebot committed
Fix TORCH_FEATURE_VERSION guards (pytorch#167802)
This is tested by pytorch#167962, which ensures we get compilation errors when using functions that convert Device/HeaderOnlyArrayRef to StableIValue and target 2.9. Pull Request resolved: pytorch#167802. Approved by: https://github.com/janeyx99. ghstack dependencies: pytorch#168025
1 parent d2ccb5b commit 3beb378

File tree

2 files changed

+90
-15
lines changed

2 files changed

+90
-15
lines changed

torch/csrc/stable/ops.h

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,26 @@ inline torch::stable::Tensor clone(const torch::stable::Tensor& self) {
269269
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
270270
}
271271

272+
// We expect this to be the stable version of the flatten.using_ints op.
273+
inline torch::stable::Tensor flatten(
274+
const torch::stable::Tensor& self,
275+
int64_t start_dim = 0,
276+
int64_t end_dim = -1) {
277+
const auto num_args = 3;
278+
std::array<StableIValue, num_args> stack{
279+
torch::stable::detail::from(self),
280+
torch::stable::detail::from(start_dim),
281+
torch::stable::detail::from(end_dim)};
282+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
283+
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
284+
"aten::flatten", "using_ints", stack.data(), TORCH_ABI_VERSION));
285+
#else
286+
TORCH_ERROR_CODE_CHECK(
287+
aoti_torch_call_dispatcher("aten::flatten", "using_ints", stack.data()));
288+
#endif
289+
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
290+
}
291+
272292
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
273293

274294
// New ops should be added here if they use a brand new shim API
@@ -309,6 +329,8 @@ inline uint32_t get_num_threads() {
309329
// We expect this to be the stable version of the empty op that takes in
310330
// device and dtype parameters. The empty op creates a tensor with uninitialized
311331
// values of the specified size, dtype, and device.
332+
// This function is only available in 2.10 because it uses the stableivalue
333+
// conversion for HeaderOnlyArrayRef<T>, which is only available in 2.10.
312334
inline torch::stable::Tensor empty(
313335
torch::headeronly::IntHeaderOnlyArrayRef size,
314336
std::optional<torch::headeronly::ScalarType> dtype = std::nullopt,
@@ -327,22 +349,9 @@ inline torch::stable::Tensor empty(
327349
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
328350
}
329351

330-
// We expect this to be the stable version of the flatten.using_ints op.
331-
inline torch::stable::Tensor flatten(
332-
const torch::stable::Tensor& self,
333-
int64_t start_dim = 0,
334-
int64_t end_dim = -1) {
335-
const auto num_args = 3;
336-
std::array<StableIValue, num_args> stack{
337-
torch::stable::detail::from(self),
338-
torch::stable::detail::from(start_dim),
339-
torch::stable::detail::from(end_dim)};
340-
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
341-
"aten::flatten", "using_ints", stack.data(), TORCH_ABI_VERSION));
342-
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
343-
}
344-
345352
// We expect this to be the stable version of the reshape op.
353+
// This function is only available in 2.10 because it uses the stableivalue
354+
// conversion for HeaderOnlyArrayRef<T>, which is only available in 2.10.
346355
inline torch::stable::Tensor reshape(
347356
const torch::stable::Tensor& self,
348357
torch::headeronly::IntHeaderOnlyArrayRef shape) {
@@ -355,6 +364,8 @@ inline torch::stable::Tensor reshape(
355364
}
356365

357366
// We expect this to be the stable version of the view op.
367+
// This function is only available in 2.10 because it uses the stableivalue
368+
// conversion for HeaderOnlyArrayRef<T>, which is only available in 2.10.
358369
inline torch::stable::Tensor view(
359370
const torch::stable::Tensor& self,
360371
torch::headeronly::IntHeaderOnlyArrayRef size) {

torch/csrc/stable/stableivalue_conversions.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,21 @@
1414

1515
HIDDEN_NAMESPACE_BEGIN(torch, stable, detail)
1616

17+
// Helper variable templates to detect 2.10+ types for better compile-time error
18+
// messages
19+
template <typename T>
20+
inline constexpr bool is_header_only_array_ref_v = false;
21+
22+
template <typename T>
23+
inline constexpr bool
24+
is_header_only_array_ref_v<torch::headeronly::HeaderOnlyArrayRef<T>> = true;
25+
26+
template <typename T>
27+
inline constexpr bool is_std_vector_v = false;
28+
29+
template <typename T>
30+
inline constexpr bool is_std_vector_v<std::vector<T>> = true;
31+
1732
// forward declare so that the from/to() implementations in the detail
1833
// namespace of library.h where the real work is done can compile.
1934
template <typename T>
@@ -35,6 +50,17 @@ struct FromImpl {
3550
T val,
3651
[[maybe_unused]] uint64_t extension_build_version,
3752
[[maybe_unused]] bool is_internal) {
53+
// Ensure 2.10+ types don't accidentally use the base case - provide clear
54+
// compile-time errors.
55+
static_assert(
56+
!std::is_same_v<T, torch::stable::Device>,
57+
"torch::stable::Device requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0");
58+
static_assert(
59+
!is_header_only_array_ref_v<T>,
60+
"HeaderOnlyArrayRef<T> requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0");
61+
static_assert(
62+
!is_std_vector_v<T>,
63+
"std::vector<T> requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0");
3864
static_assert(
3965
sizeof(T) <= sizeof(StableIValue),
4066
"StableLibrary stack does not support parameter types larger than 64 bits.");
@@ -126,6 +152,18 @@ struct FromImpl<ScalarType> {
126152
}
127153
};
128154

155+
// [Note DeviceType version guard]
156+
// This conversion was introduced in 2.10. However, we do not gate it
157+
// with TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 because this
158+
// conversion is not actually used to pass DeviceType between user
159+
// extensions and libtorch (i.e. there is no c10::TypeKind::DeviceType).
160+
// The purpose of gating other conversions is to ensure that user
161+
// extensions do not try to pass a StableIValue that libtorch is
162+
// unable to interpret.
163+
// This conversion is only used
164+
// (1) In the conversion for torch::stable::Device (already gated)
165+
// (2) Within the user extension to translate between libtorch/extension's
166+
// DeviceType (no gating needed)
129167
// Specialization for torch::headeronly::DeviceType => StableIValue
130168
// Note that we call into the shim to translate between the user's
131169
// DeviceType and libtorch's DeviceType, which can be different!
@@ -225,6 +263,11 @@ struct FromImpl<torch::stable::Tensor> {
225263
}
226264
};
227265

266+
// =============================================================================
267+
// FROM CONVERSIONS requiring TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
268+
// =============================================================================
269+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
270+
228271
// Specialization for torch::headeronly::HeaderOnlyArrayRef<T> => StableIValue
229272
// Returns a new owning reference of the underlying list.
230273
template <typename T>
@@ -287,6 +330,8 @@ struct FromImpl<torch::stable::Device> {
287330
}
288331
};
289332

333+
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
334+
290335
// =============================================================================
291336
// TO CONVERSIONS (StableIValue -> T)
292337
// =============================================================================
@@ -299,6 +344,17 @@ struct ToImpl {
299344
[[maybe_unused]] uint64_t extension_build_version,
300345
[[maybe_unused]] bool is_internal) {
301346
static_assert(std::is_trivially_copyable_v<T>);
347+
// Ensure 2.10+ types don't accidentally use the base case - provide clear
348+
// compile-time errors.
349+
static_assert(
350+
!std::is_same_v<T, torch::stable::Device>,
351+
"torch::stable::Device requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0");
352+
static_assert(
353+
!is_header_only_array_ref_v<T>,
354+
"HeaderOnlyArrayRef<T> requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0");
355+
static_assert(
356+
!is_std_vector_v<T>,
357+
"std::vector<T> requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0");
302358
// T may not have a default constructor. (For example, it might be
303359
// c10::Device.) However, std::memcpy implicitly creates a T at the
304360
// destination. So, we can use a union to work around this lack of
@@ -387,6 +443,7 @@ struct ToImpl<ScalarType> {
387443
}
388444
};
389445

446+
// See [Note DeviceType version guard]
390447
// Specialization for StableIValue => torch::headeronly::DeviceType
391448
template <>
392449
struct ToImpl<DeviceType> {
@@ -467,6 +524,11 @@ struct ToImpl<torch::stable::Tensor> {
467524
}
468525
};
469526

527+
// =============================================================================
528+
// TO CONVERSIONS requiring TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
529+
// =============================================================================
530+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
531+
470532
// Specialization for StableIValue => std::vector<T>
471533
// std::vector<T> should be represented as a StableListHandle
472534
// filled with StableIValues
@@ -517,6 +579,8 @@ struct ToImpl<torch::stable::Device> {
517579
}
518580
};
519581

582+
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
583+
520584
// =============================================================================
521585
// end to helpers for converting between StableIValue and T
522586
// =============================================================================

0 commit comments

Comments (0)