Skip to content

Commit 2b4ef6b

Browse files
angelayipytorchmergebot
authored andcommitted
[opaque_obj_v2] PyObject custom op schema type (pytorch#165004)
This is a cleaner implementation of opaque objects (pytorch#162660). Instead now we just need to do: Call `register_opaque_type` to register the type as being "opaque" and allowed by custom ops. You also need to pass a unique name that maps to the type. ```python class OpaqueQueue: def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None: super().__init__() self.queue = queue self.init_tensor_ = init_tensor_ def push(self, tensor: torch.Tensor) -> None: self.queue.append(tensor) def pop(self) -> torch.Tensor: if len(self.queue) > 0: return self.queue.pop(0) return self.init_tensor_ def size(self) -> int: return len(self.queue) register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue") ``` When creating the custom op, the schema will then use the unique name: ```python self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") torch.library.define( "_TestOpaqueObject::queue_push", "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()", tags=torch.Tag.pt2_compliant_tag, lib=self.lib, ) @torch.library.impl( "_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib ) def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None: assert isinstance(queue, OpaqueQueue) queue.push(b) ``` Using the custom op: ```python queue = OpaqueQueue([], torch.zeros(3)) torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3)) self.assertTrue(queue.size(), 1) ``` Pull Request resolved: pytorch#165004 Approved by: https://github.com/albanD
1 parent 3f83e89 commit 2b4ef6b

File tree

7 files changed

+170
-4
lines changed

7 files changed

+170
-4
lines changed

test/test_opaque_obj_v2.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Owner(s): ["module: custom-operators"]
2+
3+
import torch
4+
from torch._dynamo.test_case import run_tests, TestCase
5+
from torch._library.opaque_object import register_opaque_type
6+
7+
8+
class OpaqueQueue:
9+
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
10+
super().__init__()
11+
self.queue = queue
12+
self.init_tensor_ = init_tensor_
13+
14+
def push(self, tensor: torch.Tensor) -> None:
15+
self.queue.append(tensor)
16+
17+
def pop(self) -> torch.Tensor:
18+
if len(self.queue) > 0:
19+
return self.queue.pop(0)
20+
return self.init_tensor_
21+
22+
def size(self) -> int:
23+
return len(self.queue)
24+
25+
26+
class TestOpaqueObject(TestCase):
27+
def setUp(self):
28+
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") # noqa: TOR901
29+
30+
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
31+
32+
torch.library.define(
33+
"_TestOpaqueObject::queue_push",
34+
"(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
35+
tags=torch.Tag.pt2_compliant_tag,
36+
lib=self.lib,
37+
)
38+
39+
@torch.library.impl(
40+
"_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
41+
)
42+
def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
43+
assert isinstance(queue, OpaqueQueue)
44+
queue.push(b)
45+
46+
self.lib.define(
47+
"queue_pop(_TestOpaqueObject_OpaqueQueue a) -> Tensor",
48+
)
49+
50+
def pop_impl(queue: OpaqueQueue) -> torch.Tensor:
51+
assert isinstance(queue, OpaqueQueue)
52+
return queue.pop()
53+
54+
self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
55+
56+
@torch.library.custom_op(
57+
"_TestOpaqueObject::queue_size",
58+
mutates_args=[],
59+
)
60+
def size_impl(queue: OpaqueQueue) -> int:
61+
assert isinstance(queue, OpaqueQueue)
62+
return queue.size()
63+
64+
super().setUp()
65+
66+
def tearDown(self):
67+
self.lib._destroy()
68+
69+
super().tearDown()
70+
71+
def test_ops(self):
72+
queue = OpaqueQueue([], torch.zeros(3))
73+
74+
torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3) + 1)
75+
size = torch.ops._TestOpaqueObject.queue_size(queue)
76+
self.assertEqual(size, 1)
77+
popped = torch.ops._TestOpaqueObject.queue_pop(queue)
78+
self.assertEqual(popped, torch.ones(3) + 1)
79+
size = torch.ops._TestOpaqueObject.queue_size(queue)
80+
self.assertEqual(size, 0)
81+
82+
83+
if __name__ == "__main__":
84+
run_tests()

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,6 +1627,8 @@ def _jit_pass_lint(Graph) -> None: ...
16271627
def _make_opaque_object(payload: Any) -> ScriptObject: ...
16281628
def _get_opaque_object_payload(obj: ScriptObject) -> Any: ...
16291629
def _set_opaque_object_payload(obj: ScriptObject, payload: Any) -> None: ...
1630+
def _register_opaque_type(type_name: str) -> None: ...
1631+
def _is_opaque_type_registered(type_name: str) -> _bool: ...
16301632

16311633
# Defined in torch/csrc/jit/python/python_custom_class.cpp
16321634
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...

torch/_library/infer_schema.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch import device, dtype, Tensor, types
1010
from torch.utils._exposed_in import exposed_in
1111

12-
from .opaque_object import OpaqueType, OpaqueTypeStr
12+
from .opaque_object import _OPAQUE_TYPES, is_opaque_type, OpaqueType, OpaqueTypeStr
1313

1414

1515
# This is used as a negative test for
@@ -125,8 +125,11 @@ def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]:
125125
# we convert it to the actual type.
126126
annotation_type, _ = unstringify_type(param.annotation)
127127

128+
schema_type = None
128129
if annotation_type not in SUPPORTED_PARAM_TYPES:
129-
if annotation_type == torch._C.ScriptObject:
130+
if is_opaque_type(annotation_type):
131+
schema_type = _OPAQUE_TYPES[annotation_type]
132+
elif annotation_type == torch._C.ScriptObject:
130133
error_fn(
131134
f"Parameter {name}'s type cannot be inferred from the schema "
132135
"as it is a ScriptObject. Please manually specify the schema "
@@ -152,8 +155,11 @@ def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]:
152155
f"Parameter {name} has unsupported type {param.annotation}. "
153156
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
154157
)
158+
else:
159+
schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
160+
161+
assert schema_type is not None
155162

156-
schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
157163
if type(mutates_args) is str:
158164
if mutates_args != UNKNOWN_MUTATES:
159165
raise ValueError(

torch/_library/opaque_object.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, NewType
1+
from typing import Any, NewType, Optional
22

33
import torch
44

@@ -150,3 +150,36 @@ def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
150150
f"Tried to get the payload from a non-OpaqueObject of type `{type_}`"
151151
)
152152
torch._C._set_opaque_object_payload(opaque_object, payload)
153+
154+
155+
_OPAQUE_TYPES: dict[Any, str] = {}
156+
157+
158+
def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
159+
"""
160+
Registers the given type as an opaque type which allows this to be consumed
161+
by a custom operator.
162+
163+
Args:
164+
cls (type): The class to register as an opaque type.
165+
name (str): A unique qualified name of the type.
166+
"""
167+
if name is None:
168+
name = cls.__name__
169+
170+
if "." in name:
171+
# The schema_type_parser will break up types with periods
172+
raise ValueError(
173+
f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
174+
)
175+
_OPAQUE_TYPES[cls] = name
176+
torch._C._register_opaque_type(name)
177+
178+
179+
def is_opaque_type(cls: Any) -> bool:
180+
"""
181+
Checks if the given type is an opaque type.
182+
"""
183+
if cls not in _OPAQUE_TYPES:
184+
return False
185+
return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])

torch/csrc/jit/frontend/schema_type_parser.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <torch/csrc/jit/frontend/parse_string_literal.h>
99
#include <torch/custom_class.h>
1010
#include <string>
11+
#include <unordered_set>
1112

1213
using c10::AliasInfo;
1314
using c10::AwaitType;
@@ -42,6 +43,25 @@ using c10::VarType;
4243

4344
namespace torch::jit {
4445

46+
static std::unordered_set<std::string>& getOpaqueTypes() {
47+
static std::unordered_set<std::string> global_opaque_types;
48+
return global_opaque_types;
49+
}
50+
51+
void registerOpaqueType(const std::string& type_name) {
52+
auto& global_opaque_types = getOpaqueTypes();
53+
auto [_, inserted] = global_opaque_types.insert(type_name);
54+
if (!inserted) {
55+
throw std::runtime_error(
56+
"Type '" + type_name + "' is already registered as an opaque type");
57+
}
58+
}
59+
60+
bool isRegisteredOpaqueType(const std::string& type_name) {
61+
auto& global_opaque_types = getOpaqueTypes();
62+
return global_opaque_types.find(type_name) != global_opaque_types.end();
63+
}
64+
4565
TypePtr SchemaTypeParser::parseBaseType() {
4666
static std::unordered_map<std::string, TypePtr> type_map = {
4767
{"Generator", c10::TypeFactory::get<GeneratorType>()},
@@ -81,6 +101,11 @@ TypePtr SchemaTypeParser::parseBaseType() {
81101
}
82102
std::string text = tok.text();
83103

104+
// Check if this type is registered as an opaque type first
105+
if (isRegisteredOpaqueType(text)) {
106+
return c10::PyObjectType::get();
107+
}
108+
84109
auto it = type_map.find(text);
85110
if (it == type_map.end()) {
86111
if (allow_typevars_ && !text.empty() && islower(text[0])) {

torch/csrc/jit/frontend/schema_type_parser.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ namespace torch::jit {
1010

1111
using TypePtr = c10::TypePtr;
1212

13+
TORCH_API void registerOpaqueType(const std::string& type_name);
14+
TORCH_API bool isRegisteredOpaqueType(const std::string& type_name);
15+
1316
struct TORCH_API SchemaTypeParser {
1417
TypePtr parseBaseType();
1518
std::optional<c10::AliasInfo> parseAliasAnnotation();

torch/csrc/jit/python/init.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#endif
1616
#include <c10/core/SymNodeImpl.h>
1717
#include <torch/csrc/jit/frontend/ir_emitter.h>
18+
#include <torch/csrc/jit/frontend/schema_type_parser.h>
1819
#include <torch/csrc/jit/frontend/tracer.h>
1920
#include <torch/csrc/jit/ir/irparser.h>
2021
#include <torch/csrc/jit/jit_log.h>
@@ -1890,6 +1891,18 @@ void initJITBindings(PyObject* module) {
18901891
customObj->setPayload(std::move(payload));
18911892
},
18921893
R"doc(Sets the payload of the given opaque object with the given Python object.)doc");
1894+
m.def(
1895+
"_register_opaque_type",
1896+
[](const std::string& type_name) {
1897+
torch::jit::registerOpaqueType(type_name);
1898+
},
1899+
R"doc(Registers a type name to be treated as an opaque type (PyObject) in schema parsing.)doc");
1900+
m.def(
1901+
"_is_opaque_type_registered",
1902+
[](const std::string& type_name) -> bool {
1903+
return torch::jit::isRegisteredOpaqueType(type_name);
1904+
},
1905+
R"doc(Checks if a type name is registered as an opaque type.)doc");
18931906
m.def("unify_type_list", [](const std::vector<TypePtr>& types) {
18941907
std::ostringstream s;
18951908
auto type = unifyTypeList(types, s);

0 commit comments

Comments
 (0)