Skip to content

Commit d1c2fe4

Browse files
committed
Revert "Add utility to get computed kernel in torch.library (pytorch#158393)"
This reverts commit 1196bb1. (cherry picked from commit a975dfe)
1 parent 3004b5d commit d1c2fe4

File tree

10 files changed

+0
-413
lines changed

10 files changed

+0
-413
lines changed

aten/src/ATen/core/boxing/KernelFunction.h

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
#include <c10/core/DispatchKeySet.h>
77
#include <c10/util/TypeList.h>
88
#include <c10/util/intrusive_ptr.h>
9-
#include <atomic>
10-
#include <memory>
119
#include <type_traits>
1210

1311
namespace c10 {
@@ -19,9 +17,6 @@ class OperatorHandle;
1917
struct OperatorKernel;
2018
class KernelFunction;
2119

22-
class KernelToken;
23-
class SafeKernelFunction;
24-
2520
template <typename T>
2621
using has_symint = std::disjunction<
2722
std::is_same<c10::SymInt, T>,
@@ -95,12 +90,6 @@ class TORCH_API KernelFunction final {
9590
BoxedKernel::BoxedKernelFunction_withDispatchKeys;
9691

9792
KernelFunction();
98-
~KernelFunction();
99-
100-
KernelFunction(const KernelFunction&) = default;
101-
KernelFunction& operator=(const KernelFunction&) = default;
102-
103-
KernelFunction(KernelFunction&&) noexcept = default;
10493

10594
// Fast path for dispatch to allow not touching the boxed kernel in
10695
// the common case where unboxed is available.
@@ -273,13 +262,6 @@ class TORCH_API KernelFunction final {
273262
// For testing internal invariants only
274263
bool _equalsBoxedAndUnboxed(const KernelFunction&) const;
275264

276-
// Register a token to be invalidated when this KernelFunction is destroyed
277-
void registerToken(std::weak_ptr<KernelToken> token) const;
278-
279-
// List of tokens that need to be invalidated when this KernelFunction is
280-
// destroyed
281-
mutable std::vector<std::weak_ptr<KernelToken>> tokens_;
282-
283265
private:
284266
explicit KernelFunction(
285267
std::unique_ptr<OperatorKernel> functor,
@@ -296,47 +278,6 @@ class TORCH_API KernelFunction final {
296278
void* sym_unboxed_kernel_func_;
297279
};
298280

299-
// Token held by SafeKernelFunction that gets invalidated when KernelFunction is
300-
// destroyed
301-
class KernelToken {
302-
public:
303-
bool isValid() const;
304-
void invalidate();
305-
306-
private:
307-
std::atomic<bool> invalid_{false};
308-
};
309-
310-
class SafeKernelFunction {
311-
public:
312-
SafeKernelFunction(
313-
const KernelFunction* kernel,
314-
std::string debug,
315-
std::shared_ptr<OperatorHandle> opHandle);
316-
317-
// Safe callBoxed - checks token validity first
318-
void callBoxed(
319-
const OperatorHandle& opHandle,
320-
DispatchKeySet dispatchKeySet,
321-
Stack* stack) const;
322-
323-
// Get debug information
324-
const std::string& debug() const {
325-
return debug_;
326-
}
327-
328-
// Get the OpHandle that lives on this SafeKernelFunction
329-
const OperatorHandle& opHandle() const {
330-
return *opHandle_;
331-
}
332-
333-
private:
334-
KernelFunction kernel_;
335-
std::shared_ptr<KernelToken> token_;
336-
std::string debug_;
337-
std::shared_ptr<OperatorHandle> opHandle_;
338-
};
339-
340281
} // namespace c10
341282

342283
#include <ATen/core/boxing/KernelFunction_impl.h>

aten/src/ATen/core/boxing/KernelFunction_impl.h

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,6 @@ inline KernelFunction::KernelFunction()
2424
unboxed_kernel_func_(nullptr),
2525
sym_unboxed_kernel_func_(nullptr) {}
2626

27-
inline KernelFunction::~KernelFunction() {
28-
for (auto& weak_token : tokens_) {
29-
if (auto token = weak_token.lock()) {
30-
token->invalidate();
31-
}
32-
}
33-
}
34-
3527
inline KernelFunction::KernelFunction(
3628
std::unique_ptr<OperatorKernel> functor,
3729
InternalBoxedKernelFunction* boxed_kernel_func,
@@ -165,11 +157,6 @@ C10_ALWAYS_INLINE Return KernelFunction::call(
165157
std::forward<Args>(args)...);
166158
}
167159

168-
inline void KernelFunction::registerToken(
169-
std::weak_ptr<KernelToken> token) const {
170-
tokens_.push_back(std::move(token));
171-
}
172-
173160
inline KernelFunction KernelFunction::makeFromBoxedKernel(
174161
BoxedKernel boxed_fn) {
175162
return KernelFunction(
@@ -330,38 +317,4 @@ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
330317
std::forward<Lambda>(lambda)));
331318
}
332319

333-
inline bool KernelToken::isValid() const {
334-
return !invalid_.load(std::memory_order_acquire);
335-
}
336-
337-
inline void KernelToken::invalidate() {
338-
invalid_.store(true, std::memory_order_release);
339-
}
340-
341-
inline SafeKernelFunction::SafeKernelFunction(
342-
const KernelFunction* kernel,
343-
std::string debug,
344-
std::shared_ptr<OperatorHandle> opHandle)
345-
: kernel_(kernel ? *kernel : KernelFunction()),
346-
token_(std::make_shared<KernelToken>()),
347-
debug_(std::move(debug)),
348-
opHandle_(std::move(opHandle)) {
349-
// Register the token with the original kernel so it gets invalidated when the
350-
// kernel is destroyed
351-
if (kernel) {
352-
kernel->registerToken(token_);
353-
}
354-
}
355-
356-
inline void SafeKernelFunction::callBoxed(
357-
const OperatorHandle& opHandle,
358-
DispatchKeySet dispatchKeySet,
359-
Stack* stack) const {
360-
TORCH_CHECK(
361-
token_ && token_->isValid(),
362-
"SafeKernelFunction has been invalidated ",
363-
debug_);
364-
kernel_.callBoxed(opHandle, dispatchKeySet, stack);
365-
}
366-
367320
} // namespace c10

aten/src/ATen/core/dispatch/Dispatcher.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,10 +487,6 @@ class TORCH_API OperatorHandle {
487487
return operatorDef_->op.hasComputedKernelForDispatchKey(k);
488488
}
489489

490-
SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const {
491-
return operatorDef_->op.getComputedKernelForDispatchKey(k);
492-
}
493-
494490
std::string dumpComputedTable() const {
495491
return operatorDef_->op.dumpComputedTable();
496492
}

aten/src/ATen/core/dispatch/OperatorEntry.cpp

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -315,42 +315,6 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat
315315
return nullptr;
316316
}
317317

318-
SafeKernelFunction OperatorEntry::getComputedKernelForDispatchKey(
319-
DispatchKey k) const {
320-
TORCH_CHECK(
321-
!isAliasDispatchKey(k),
322-
"Alias keys do not have runtime kernel registrations.");
323-
const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
324-
TORCH_CHECK(
325-
dispatchTable_[dispatch_ix].isValid(),
326-
"no kernel for ",
327-
k,
328-
" for ",
329-
name_);
330-
331-
// Get the KernelFunction object from kernels_ to pass to SafeKernelFunction
332-
333-
// The KernelFunction object in dispatchTable_ is a copy of the KernelFunction
334-
// in the AnnotatedKernel in kernels_. A KernelFunction is only truly
335-
// deregistered when the kernel is removed from kernels_. However, the
336-
// KernelFunction in dispatchTable_ might be removed before it is deregistered
337-
// (when a newer kernel is registered). Therefore, here we want to return a
338-
// SafeKernelFunction that is backed by the original KernelFunction in
339-
// kernels_, so that we only invalidate it when the kernel is deregistered.
340-
auto [annotatedKernel, _] =
341-
computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);
342-
343-
// Use findSchemaOrThrow to get OpHandle for the OperatorEntry
344-
auto& dispatcher = c10::Dispatcher::singleton();
345-
auto opHandle = dispatcher.findSchemaOrThrow(
346-
name_.name.c_str(), name_.overload_name.c_str());
347-
348-
return SafeKernelFunction(
349-
&annotatedKernel.kernel,
350-
annotatedKernel.debug,
351-
std::make_shared<OperatorHandle>(opHandle));
352-
}
353-
354318
const std::vector<at::Tag>& OperatorEntry::getTags() const {
355319
#if defined C10_MOBILE
356320
TORCH_CHECK(false, "tags are not saved for Mobile");

aten/src/ATen/core/dispatch/OperatorEntry.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,6 @@ class TORCH_API OperatorEntry final {
217217
const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
218218
// Returns true if the "computed table" has an entry for a particular key.
219219
bool hasComputedKernelForDispatchKey(DispatchKey k) const;
220-
// Returns a KernelFunction corresponding to the kernel in dispatchTable
221-
SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const;
222220
// Returns all the operator tags added at the time of registration
223221
const std::vector<at::Tag>& getTags() const;
224222
void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);

docs/source/library.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ via PyTorch's C++ operator registration APIs).
5656
.. autofunction:: infer_schema
5757
.. autoclass:: torch._library.custom_ops.CustomOpDef
5858
:members: set_kernel_enabled
59-
.. autofunction:: get_kernel
6059
```
6160

6261
## Low-level APIs

test/test_custom_ops.py

Lines changed: 0 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import tempfile
1212
import typing
1313
import unittest
14-
from functools import partial
1514
from pathlib import Path
1615
from typing import * # noqa: F403
1716

@@ -4157,148 +4156,6 @@ def test_any_output_is_alias_to_input_or_output(self):
41574156
)
41584157
)
41594158

4160-
def test_library_get_kernel(self):
4161-
"""Test registering a custom kernel, using it, then deregistering and verifying error."""
4162-
4163-
# Register a dummy kernel for arange to the CPU key that returns a tensor of ones
4164-
def dummy_arange_cpu(
4165-
dispatch_keys,
4166-
start,
4167-
end,
4168-
dtype=None,
4169-
layout=torch.strided,
4170-
device=None,
4171-
pin_memory=False,
4172-
):
4173-
size = max(0, int(end - start))
4174-
return torch.ones(size, dtype=dtype, device=device)
4175-
4176-
with torch.library._scoped_library("aten", "IMPL") as lib:
4177-
lib.impl("arange.start", dummy_arange_cpu, "CPU", with_keyset=True)
4178-
4179-
kernel = torch.library.get_kernel("aten::arange.start", "CPU")
4180-
dispatch_keys = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
4181-
result = kernel.call_boxed(dispatch_keys, 0, 5)
4182-
4183-
self.assertEqual(result, torch.ones(5))
4184-
4185-
# The kernel should now be invalidated after exiting the scoped_library context
4186-
with self.assertRaisesRegex(RuntimeError, "has been invalidated"):
4187-
kernel.call_boxed(dispatch_keys, 0, 5)
4188-
4189-
def test_library_get_kernel_with_conditional_dispatch(self):
4190-
"""Test registering a custom kernel with conditional dispatch logic."""
4191-
4192-
def conditional_arange_cpu1(
4193-
original_kernel,
4194-
dispatch_keys,
4195-
start,
4196-
end,
4197-
dtype=None,
4198-
layout=torch.strided,
4199-
device=None,
4200-
pin_memory=False,
4201-
):
4202-
# If end is even, use the original kernel, otherwise return ones tensor
4203-
if end % 2 == 0:
4204-
op_handle = torch.ops.aten.arange.start._handle
4205-
return original_kernel.call_boxed(
4206-
dispatch_keys,
4207-
start,
4208-
end,
4209-
dtype=dtype,
4210-
layout=layout,
4211-
device=device,
4212-
pin_memory=pin_memory,
4213-
)
4214-
else:
4215-
size = max(0, int(end - start))
4216-
return torch.ones(size, dtype=dtype, device=device)
4217-
4218-
def conditional_arange_cpu2(
4219-
original_kernel,
4220-
dispatch_keys,
4221-
start,
4222-
end,
4223-
dtype=None,
4224-
layout=torch.strided,
4225-
device=None,
4226-
pin_memory=False,
4227-
):
4228-
# If start is even, use the original kernel, otherwise return twos tensor
4229-
if start % 2 == 0:
4230-
op_handle = torch.ops.aten.arange.start._handle
4231-
return original_kernel.call_boxed(
4232-
dispatch_keys,
4233-
start,
4234-
end,
4235-
dtype=dtype,
4236-
layout=layout,
4237-
device=device,
4238-
pin_memory=pin_memory,
4239-
)
4240-
else:
4241-
size = max(0, int(end - start))
4242-
return torch.empty(size, dtype=dtype, device=device).fill_(2)
4243-
4244-
original_kernel = torch.library.get_kernel("aten::arange.start", "CPU")
4245-
expected_result1, expected_result2 = torch.ones(5), torch.arange(0, 6)
4246-
expected_result3, expected_result4, expected_result5 = (
4247-
torch.ones(5),
4248-
torch.arange(0, 6),
4249-
torch.ones(5).fill_(2),
4250-
)
4251-
4252-
with torch.library._scoped_library("aten", "IMPL") as lib2:
4253-
with torch.library._scoped_library("aten", "IMPL") as lib1:
4254-
lib1.impl(
4255-
"arange.start",
4256-
partial(conditional_arange_cpu1, original_kernel),
4257-
"CPU",
4258-
with_keyset=True,
4259-
)
4260-
4261-
self.assertEqual(torch.arange(0, 5), expected_result1)
4262-
self.assertEqual(torch.arange(0, 6), expected_result2)
4263-
new_original_kernel = torch.library.get_kernel(
4264-
"aten::arange.start", "CPU"
4265-
)
4266-
lib2.impl(
4267-
"arange.start",
4268-
partial(conditional_arange_cpu2, new_original_kernel),
4269-
"CPU",
4270-
allow_override=True,
4271-
with_keyset=True,
4272-
)
4273-
4274-
self.assertEqual(torch.arange(0, 5), expected_result3)
4275-
self.assertEqual(torch.arange(0, 6), expected_result4)
4276-
self.assertEqual(torch.arange(1, 6), expected_result5)
4277-
4278-
# The kernel should now be invalidated after destroying lib1
4279-
with self.assertRaisesRegex(RuntimeError, "has been invalidated"):
4280-
torch.arange(0, 5)
4281-
4282-
# Should still work after destroying lib1
4283-
self.assertEqual(torch.arange(1, 6), expected_result5)
4284-
4285-
def test_library_get_kernel_invalid(self):
4286-
"""Test that get_kernel raises an error when no kernel is available."""
4287-
with torch.library._scoped_library("test_invalid_kernel", "DEF") as lib:
4288-
lib.define("cpu_only_op(Tensor x) -> Tensor")
4289-
lib.impl("cpu_only_op", lambda x: x * 2, "CPU")
4290-
4291-
cpu_kernel = torch.library.get_kernel(
4292-
"test_invalid_kernel::cpu_only_op", "CPU"
4293-
)
4294-
self.assertIsNotNone(cpu_kernel)
4295-
4296-
# CUDA should fail at the isValid() check since no CUDA kernel exists
4297-
with self.assertRaisesRegex(
4298-
RuntimeError, "no kernel for CUDA for test_invalid_kernel::cpu_only_op"
4299-
):
4300-
torch.library.get_kernel("test_invalid_kernel::cpu_only_op", "CUDA")
4301-
43024159

43034160
class MiniOpTestOther(CustomOpTestCaseBase):
43044161
test_ns = "mini_op_test"

0 commit comments

Comments
 (0)