
Commit 83cd626

angelayi authored and pytorchmergebot committed
[opaque_obj_v2] make_fx support (pytorch#165005)
By wrapping the Python objects with FakeScriptObject(FakeOpaqueQueue), we prevent users from doing anything with the object during tracing. torch.compile support can then be enabled easily by the rest of [this stack](pytorch#163936) together with the existing support for ScriptObjects.

One thing to note: by default, functionalization marks all ops that take FakeScriptObjects as effectful. Should this also be the case for custom ops that take Python objects?

Pull Request resolved: pytorch#165005
Approved by: https://github.com/zou3519
1 parent 5125872 commit 83cd626
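
For context, here is a minimal usage sketch distilled from the new test below. The library name _MyLib, the Counter class, and the bump op are hypothetical stand-ins, and the register_opaque_type / register_fake calls assume the APIs added by this stack:

    import torch
    from torch._library.opaque_object import register_opaque_type
    from torch.fx.experimental.proxy_tensor import make_fx

    class Counter:
        # Ordinary Python state; the compiler should treat it as a black box.
        def __init__(self):
            self.n = 0

    # Register the Python class as an opaque type so it can appear in op schemas.
    register_opaque_type(Counter, "_MyLib_Counter")

    lib = torch.library.Library("_MyLib", "FRAGMENT")  # noqa: TOR901
    torch.library.define(
        "_MyLib::bump", "(Tensor x, _MyLib_Counter c) -> Tensor", lib=lib
    )

    @torch.library.impl("_MyLib::bump", "CompositeExplicitAutograd", lib=lib)
    def bump(x, c):
        c.n += 1  # the real kernel may freely touch the object
        return x + c.n

    @torch.library.register_fake("_MyLib::bump", lib=lib)
    def bump_fake(x, c):
        # The fake kernel must not look inside the opaque object.
        return torch.empty_like(x)

    def f(c, x):
        return torch.ops._MyLib.bump(x, c)

    gm = make_fx(f, tracing_mode="fake")(Counter(), torch.ones(2))
    print(gm.code)  # the opaque object appears as a plain placeholder input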

File tree
4 files changed: +206 -16 lines


test/test_opaque_obj_v2.py

Lines changed: 159 additions & 2 deletions
@@ -1,8 +1,16 @@
 # Owner(s): ["module: custom-operators"]
 
+import random
+
 import torch
 from torch._dynamo.test_case import run_tests, TestCase
+from torch._library.fake_class_registry import FakeScriptObject
 from torch._library.opaque_object import register_opaque_type
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
 
 
 class OpaqueQueue:
@@ -11,24 +19,39 @@ def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
         self.queue = queue
         self.init_tensor_ = init_tensor_
 
+        # For testing purposes
+        self._push_counter = 0
+        self._pop_counter = 0
+        self._size_counter = 0
+
     def push(self, tensor: torch.Tensor) -> None:
+        self._push_counter += 1
         self.queue.append(tensor)
 
     def pop(self) -> torch.Tensor:
+        self._pop_counter += 1
        if len(self.queue) > 0:
             return self.queue.pop(0)
         return self.init_tensor_
 
     def size(self) -> int:
+        self._size_counter += 1
         return len(self.queue)
 
 
+class RNGState:
+    def __init__(self, seed):
+        self.rng = random.Random(seed)
+
+
+register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
+register_opaque_type(RNGState, "_TestOpaqueObject_RNGState")
+
+
 class TestOpaqueObject(TestCase):
     def setUp(self):
         self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")  # noqa: TOR901
 
-        register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
-
         torch.library.define(
             "_TestOpaqueObject::queue_push",
             "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
@@ -43,6 +66,10 @@ def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
             assert isinstance(queue, OpaqueQueue)
             queue.push(b)
 
+        @torch.library.register_fake("_TestOpaqueObject::queue_push", lib=self.lib)
+        def push_impl_fake(q: OpaqueQueue, b: torch.Tensor) -> None:
+            pass
+
         self.lib.define(
             "queue_pop(_TestOpaqueObject_OpaqueQueue a) -> Tensor",
         )
@@ -53,6 +80,15 @@ def pop_impl(queue: OpaqueQueue) -> torch.Tensor:
 
         self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
 
+        def pop_impl_fake(q: OpaqueQueue) -> torch.Tensor:
+            # This is not accurate since the queue could have tensors that are
+            # not rank 1
+            ctx = torch.library.get_ctx()
+            u0 = ctx.new_dynamic_size()
+            return torch.empty(u0)
+
+        self.lib._register_fake("queue_pop", pop_impl_fake)
+
         @torch.library.custom_op(
             "_TestOpaqueObject::queue_size",
             mutates_args=[],
@@ -61,6 +97,34 @@ def size_impl(queue: OpaqueQueue) -> int:
             assert isinstance(queue, OpaqueQueue)
             return queue.size()
 
+        @size_impl.register_fake
+        def size_impl_fake(q: OpaqueQueue) -> int:
+            ctx = torch._custom_op.impl.get_ctx()
+            u0 = ctx.new_dynamic_size()
+            torch._check_is_size(u0)
+            return u0
+
+        torch.library.define(
+            "_TestOpaqueObject::noisy_inject",
+            "(Tensor x, _TestOpaqueObject_RNGState obj) -> Tensor",
+            tags=torch.Tag.pt2_compliant_tag,
+            lib=self.lib,
+        )
+
+        @torch.library.impl(
+            "_TestOpaqueObject::noisy_inject", "CompositeExplicitAutograd", lib=self.lib
+        )
+        def noisy_inject(x: torch.Tensor, rng_state: RNGState) -> torch.Tensor:
+            assert isinstance(rng_state, RNGState)
+            out = x.clone()
+            for i in range(out.numel()):
+                out.view(-1)[i] += rng_state.rng.random()
+            return out
+
+        @torch.library.register_fake("_TestOpaqueObject::noisy_inject", lib=self.lib)
+        def noisy_inject_fake(x: torch.Tensor, obj: RNGState) -> torch.Tensor:
+            return torch.empty_like(x)
+
         super().setUp()
 
     def tearDown(self):
@@ -79,6 +143,99 @@ def test_ops(self):
         size = torch.ops._TestOpaqueObject.queue_size(queue)
         self.assertEqual(size, 0)
 
+    @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
+    def test_make_fx(self, make_fx_tracing_mode):
+        class M(torch.nn.Module):
+            def forward(self, queue, x):
+                torch.ops._TestOpaqueObject.queue_push(queue, x.tan())
+                torch.ops._TestOpaqueObject.queue_push(queue, x.cos())
+                torch.ops._TestOpaqueObject.queue_push(queue, x.sin())
+                pop1 = torch.ops._TestOpaqueObject.queue_pop(queue)
+                size1 = torch.ops._TestOpaqueObject.queue_size(queue)
+                pop2 = torch.ops._TestOpaqueObject.queue_pop(queue)
+                size2 = torch.ops._TestOpaqueObject.queue_size(queue)
+                x_cos = pop1 + size1
+                x_sin = pop2 - size2
+                return x_sin + x_cos
+
+        q1 = OpaqueQueue([], torch.empty(0).fill_(-1))
+        q2 = OpaqueQueue([], torch.empty(0).fill_(-1))
+
+        x = torch.ones(2, 3)
+        gm = make_fx(M(), tracing_mode=make_fx_tracing_mode)(q1, x)
+        self.assertTrue(torch.allclose(gm(q1, x), M()(q2, x)))
+        self.assertEqual(q1._push_counter, 3)
+        self.assertEqual(q1._pop_counter, 2)
+        self.assertEqual(q1._size_counter, 2)
+        self.assertEqual(q1.size(), 1)
+        self.assertExpectedInline(
+            gm.code.strip("\n"),
+            """\
+def forward(self, arg0_1, arg1_1):
+    tan = torch.ops.aten.tan.default(arg1_1)
+    queue_push = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, tan); tan = queue_push = None
+    cos = torch.ops.aten.cos.default(arg1_1)
+    queue_push_1 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, cos); cos = queue_push_1 = None
+    sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
+    queue_push_2 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, sin); sin = queue_push_2 = None
+    queue_pop = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
+    queue_size = torch.ops._TestOpaqueObject.queue_size.default(arg0_1)
+    queue_pop_1 = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
+    queue_size_1 = torch.ops._TestOpaqueObject.queue_size.default(arg0_1); arg0_1 = None
+    add = torch.ops.aten.add.Tensor(queue_pop, queue_size); queue_pop = queue_size = None
+    sub = torch.ops.aten.sub.Tensor(queue_pop_1, queue_size_1); queue_pop_1 = queue_size_1 = None
+    add_1 = torch.ops.aten.add.Tensor(sub, add); sub = add = None
+    return add_1
+""",
+        )
+
+    @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
+    def test_bad_fake(self, make_fx_tracing_mode):
+        torch.library.define(
+            "_TestOpaqueObject::bad_fake",
+            "(Tensor x, _TestOpaqueObject_RNGState obj) -> Tensor",
+            tags=torch.Tag.pt2_compliant_tag,
+            lib=self.lib,
+        )
+
+        def f(q, x):
+            torch.ops._TestOpaqueObject.bad_fake(x, q)
+            return x.cos()
+
+        def bad_fake1(x, rng_state) -> torch.Tensor:
+            self.assertTrue(isinstance(rng_state, FakeScriptObject))
+            out = x.clone()
+            for i in range(out.numel()):
+                out.view(-1)[i] += rng_state.rng.random()  # bad: accessing attributes
+            return out
+
+        torch.library.register_fake(
+            "_TestOpaqueObject::bad_fake", bad_fake1, lib=self.lib, allow_override=True
+        )
+
+        with self.assertRaisesRegex(
+            AttributeError,
+            "Tried to call __getattr__ with attr",
+        ):
+            make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3))
+
+        def bad_fake2(x, rng_state) -> torch.Tensor:
+            rng_state.rng = "foo"
+            return torch.empty_like(x)
+
+        torch.library.register_fake(
+            "_TestOpaqueObject::bad_fake", bad_fake2, lib=self.lib, allow_override=True
+        )
+
+        with self.assertRaisesRegex(
+            AttributeError,
+            "Tried to call __setattr__ with attr",
+        ):
+            make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3))
+
+
+instantiate_parametrized_tests(TestOpaqueObject)
+
 
 if __name__ == "__main__":
     run_tests()

torch/_library/fake_class_registry.py

Lines changed: 44 additions & 13 deletions
@@ -12,14 +12,15 @@
 
 
 class FakeScriptObject:
-    def __init__(self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObject):
-        self.wrapped_obj = wrapped_obj
-
-        # The fully qualified name of the class of original script object
-        self.script_class_name = script_class_name
+    def __init__(
+        self, wrapped_obj: Any, script_class_name: str, x: Optional[torch.ScriptObject]
+    ):
+        # Use object.__setattr__ to bypass our custom __setattr__ during initialization
+        object.__setattr__(self, "wrapped_obj", wrapped_obj)
+        object.__setattr__(self, "script_class_name", script_class_name)
         try:
             with _disable_current_modes():
-                self.real_obj = copy.deepcopy(x)
+                real_obj = copy.deepcopy(x)
         except RuntimeError as e:
             log.warning(  # noqa: G200
                 "Unable to deepcopy the custom object %s due to %s. "
@@ -29,7 +30,31 @@ def __init__(self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObject):
                 script_class_name,
                 str(e),
             )
-            self.real_obj = x
+            real_obj = x
+        object.__setattr__(self, "real_obj", real_obj)
+
+    def __getattribute__(self, name):
+        try:
+            return super().__getattribute__(name)
+        except AttributeError as e:
+            raise AttributeError(
+                f"Tried to call __getattr__ with attr '{name}' on a FakeScriptObject, "
+                "implying that you are calling this inside of a fake kernel. "
+                "The fake kernel should not depend on the contents of the "
+                "OpaqueObject at all, so we're erroring out. If you need this"
+                "functionality, consider creating a custom TorchBind Object instead"
+                "(but note that this is more difficult)."
+            ) from e
+
+    def __setattr__(self, name, value):
+        raise AttributeError(
+            f"Tried to call __setattr__ with attr '{name}' on a FakeScriptObject, "
+            "implying that you are calling this inside of a fake kernel. "
+            "The fake kernel should not depend on the contents of the "
+            "OpaqueObject at all, so we're erroring out. If you need this"
+            "functionality, consider creating a custom TorchBind Object instead"
+            "(but note that this is more difficult)."
+        )
 
 
 class FakeScriptMethod:
@@ -125,7 +150,8 @@ def tracing_with_real(x: torch.ScriptObject) -> bool:
 
 
 def maybe_to_fake_obj(
-    fake_mode, x: torch.ScriptObject
+    fake_mode,
+    x: Any,
 ) -> Union[FakeScriptObject, torch.ScriptObject]:
     import torch.utils._pytree as pytree
     from torch.utils._python_dispatch import _disable_current_modes
@@ -135,13 +161,17 @@ def maybe_to_fake_obj(
     if tracing_with_real(x):
         return x
 
-    from torch._library.opaque_object import FakeOpaqueObject, OpaqueTypeStr
+    from torch._library.opaque_object import (
+        FakeOpaqueObject,
+        is_opaque_type,
+        OpaqueTypeStr,
+    )
 
-    if str(x._type()) == OpaqueTypeStr:
+    if x is None or is_opaque_type(type(x)) or str(x._type()) == OpaqueTypeStr:
         # In order to make OpaqueObjects truly opaque, the fake kernel should
         # not depend on the contents of the OpaqueObject at all.
-        fake_x = FakeOpaqueObject()
-
+        fake_x_wrapped = FakeScriptObject(FakeOpaqueObject(), OpaqueTypeStr, None)
+        return fake_x_wrapped
     else:
         # x.__obj_flatten__() could be calling some tensor operations inside but we don't
        # want to call these ops in surrounding dispatch modes when executing it.
@@ -209,7 +239,8 @@ def maybe_to_fake_obj(
             if isinstance(real_attr, torch.ScriptMethod):
                 method_schema = real_attr.schema  # type: ignore[attr-defined]
 
-            setattr(
+            # Bypasses our custom setattr function
+            object.__setattr__(
                 fake_x_wrapped,
                 name,
                 FakeScriptMethod(fake_x_wrapped, name, method_schema),

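The __getattribute__ and __setattr__ overrides above are what the bad-fake tests in test/test_opaque_obj_v2.py exercise: during fake tracing the kernel receives a FakeScriptObject, so any attempt to read or write state on it fails loudly. Continuing the hypothetical sketch from the top of this page:

    def bad_fake(x, c):
        # `c` is a FakeScriptObject here, not the real Counter, so touching its
        # state raises AttributeError("Tried to call __getattr__ with attr 'n' ...").
        return x + c.n

    torch.library.register_fake("_MyLib::bump", bad_fake, lib=lib, allow_override=True)
    make_fx(f, tracing_mode="fake")(Counter(), torch.ones(2))  # raises AttributeError
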
torch/fx/experimental/proxy_tensor.py

Lines changed: 2 additions & 1 deletion
@@ -40,6 +40,7 @@
 from torch import SymBool, SymInt, Tensor
 from torch._dispatch.python import enable_python_dispatcher
 from torch._library.fake_class_registry import FakeScriptObject
+from torch._library.opaque_object import is_opaque_type
 from torch._logging import trace_structured
 from torch._subclasses.fake_impls import fast_detach
 from torch._subclasses.fake_tensor import (
@@ -2435,7 +2436,7 @@ def inner_wrap_fake(x: object) -> object:
                 hint=x,
                 source=source,
             )
-        elif isinstance(x, torch.ScriptObject):
+        elif isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)):
             return torch._library.fake_class_registry.maybe_to_fake_obj(
                 self.fake_tensor_mode, x
             )

torch/fx/operator_schemas.py

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ def __getattr__(self, name):
     "NoneType": type(None),
     "Storage": torch.UntypedStorage,
     "t": typing.TypeVar("t"),
+    "PyObject": Any,
 }
 for k in dir(typing):
     _type_eval_globals[k] = getattr(typing, k)

0 commit comments
