 # Owner(s): ["module: custom-operators"]
 
+import random
+
 import torch
 from torch._dynamo.test_case import run_tests, TestCase
+from torch._library.fake_class_registry import FakeScriptObject
 from torch._library.opaque_object import register_opaque_type
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
 
 
 class OpaqueQueue:
@@ -11,24 +19,39 @@ def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
         self.queue = queue
         self.init_tensor_ = init_tensor_
 
+        # For testing purposes
+        self._push_counter = 0
+        self._pop_counter = 0
+        self._size_counter = 0
+
     def push(self, tensor: torch.Tensor) -> None:
+        self._push_counter += 1
         self.queue.append(tensor)
 
     def pop(self) -> torch.Tensor:
+        self._pop_counter += 1
         if len(self.queue) > 0:
             return self.queue.pop(0)
         return self.init_tensor_
 
     def size(self) -> int:
+        self._size_counter += 1
         return len(self.queue)
 
 
+class RNGState:
+    def __init__(self, seed):
+        self.rng = random.Random(seed)
+
+
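+# Register the Python classes as opaque types; the string names are what the
+# operator schemas defined below refer to.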
+register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
+register_opaque_type(RNGState, "_TestOpaqueObject_RNGState")
+
+
 class TestOpaqueObject(TestCase):
     def setUp(self):
         self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")  # noqa: TOR901
 
-        register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
-
         torch.library.define(
             "_TestOpaqueObject::queue_push",
             "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
@@ -43,6 +66,10 @@ def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
             assert isinstance(queue, OpaqueQueue)
             queue.push(b)
 
+        @torch.library.register_fake("_TestOpaqueObject::queue_push", lib=self.lib)
+        def push_impl_fake(q: OpaqueQueue, b: torch.Tensor) -> None:
+            pass
+
         self.lib.define(
             "queue_pop(_TestOpaqueObject_OpaqueQueue a) -> Tensor",
         )
@@ -53,6 +80,15 @@ def pop_impl(queue: OpaqueQueue) -> torch.Tensor:
 
         self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
 
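+        # The fake kernel cannot look at the real queue contents, so it returns
+        # a tensor whose length is an unbacked dynamic size.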
+        def pop_impl_fake(q: OpaqueQueue) -> torch.Tensor:
+            # This is not accurate since the queue could have tensors that are
+            # not rank 1
+            ctx = torch.library.get_ctx()
+            u0 = ctx.new_dynamic_size()
+            return torch.empty(u0)
+
+        self.lib._register_fake("queue_pop", pop_impl_fake)
+
         @torch.library.custom_op(
             "_TestOpaqueObject::queue_size",
             mutates_args=[],
@@ -61,6 +97,34 @@ def size_impl(queue: OpaqueQueue) -> int:
             assert isinstance(queue, OpaqueQueue)
             return queue.size()
 
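+        # The real queue size is unknown while tracing, so the fake kernel
+        # returns a fresh unbacked SymInt.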
+        @size_impl.register_fake
+        def size_impl_fake(q: OpaqueQueue) -> int:
+            ctx = torch._custom_op.impl.get_ctx()
+            u0 = ctx.new_dynamic_size()
+            torch._check_is_size(u0)
+            return u0
+
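+        # An op that takes an RNGState opaque object and adds a random value
+        # to every element of the input.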
+        torch.library.define(
+            "_TestOpaqueObject::noisy_inject",
+            "(Tensor x, _TestOpaqueObject_RNGState obj) -> Tensor",
+            tags=torch.Tag.pt2_compliant_tag,
+            lib=self.lib,
+        )
+
+        @torch.library.impl(
+            "_TestOpaqueObject::noisy_inject", "CompositeExplicitAutograd", lib=self.lib
+        )
+        def noisy_inject(x: torch.Tensor, rng_state: RNGState) -> torch.Tensor:
+            assert isinstance(rng_state, RNGState)
+            out = x.clone()
+            for i in range(out.numel()):
+                out.view(-1)[i] += rng_state.rng.random()
+            return out
+
+        @torch.library.register_fake("_TestOpaqueObject::noisy_inject", lib=self.lib)
+        def noisy_inject_fake(x: torch.Tensor, obj: RNGState) -> torch.Tensor:
+            return torch.empty_like(x)
+
         super().setUp()
 
     def tearDown(self):
@@ -79,6 +143,99 @@ def test_ops(self):
         size = torch.ops._TestOpaqueObject.queue_size(queue)
         self.assertEqual(size, 0)
 
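+    # Trace a module that routes tensors through the queue ops and check both
+    # the captured graph and the real side effects on the queue.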
+    @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
+    def test_make_fx(self, make_fx_tracing_mode):
+        class M(torch.nn.Module):
+            def forward(self, queue, x):
+                torch.ops._TestOpaqueObject.queue_push(queue, x.tan())
+                torch.ops._TestOpaqueObject.queue_push(queue, x.cos())
+                torch.ops._TestOpaqueObject.queue_push(queue, x.sin())
+                pop1 = torch.ops._TestOpaqueObject.queue_pop(queue)
+                size1 = torch.ops._TestOpaqueObject.queue_size(queue)
+                pop2 = torch.ops._TestOpaqueObject.queue_pop(queue)
+                size2 = torch.ops._TestOpaqueObject.queue_size(queue)
+                x_cos = pop1 + size1
+                x_sin = pop2 - size2
+                return x_sin + x_cos
+
+        q1 = OpaqueQueue([], torch.empty(0).fill_(-1))
+        q2 = OpaqueQueue([], torch.empty(0).fill_(-1))
+
+        x = torch.ones(2, 3)
+        gm = make_fx(M(), tracing_mode=make_fx_tracing_mode)(q1, x)
+        self.assertTrue(torch.allclose(gm(q1, x), M()(q2, x)))
+        self.assertEqual(q1._push_counter, 3)
+        self.assertEqual(q1._pop_counter, 2)
+        self.assertEqual(q1._size_counter, 2)
+        self.assertEqual(q1.size(), 1)
+        self.assertExpectedInline(
+            gm.code.strip("\n"),
+            """\
+def forward(self, arg0_1, arg1_1):
+    tan = torch.ops.aten.tan.default(arg1_1)
+    queue_push = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, tan); tan = queue_push = None
+    cos = torch.ops.aten.cos.default(arg1_1)
+    queue_push_1 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, cos); cos = queue_push_1 = None
+    sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
+    queue_push_2 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, sin); sin = queue_push_2 = None
+    queue_pop = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
+    queue_size = torch.ops._TestOpaqueObject.queue_size.default(arg0_1)
+    queue_pop_1 = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
+    queue_size_1 = torch.ops._TestOpaqueObject.queue_size.default(arg0_1); arg0_1 = None
+    add = torch.ops.aten.add.Tensor(queue_pop, queue_size); queue_pop = queue_size = None
+    sub = torch.ops.aten.sub.Tensor(queue_pop_1, queue_size_1); queue_pop_1 = queue_size_1 = None
+    add_1 = torch.ops.aten.add.Tensor(sub, add); sub = add = None
+    return add_1
+    """,
+        )
+
+    @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
+    def test_bad_fake(self, make_fx_tracing_mode):
+        torch.library.define(
+            "_TestOpaqueObject::bad_fake",
+            "(Tensor x, _TestOpaqueObject_RNGState obj) -> Tensor",
+            tags=torch.Tag.pt2_compliant_tag,
+            lib=self.lib,
+        )
+
+        def f(q, x):
+            torch.ops._TestOpaqueObject.bad_fake(x, q)
+            return x.cos()
+
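+        # Inside a fake kernel the opaque object is a FakeScriptObject, so
+        # reading attributes off it should raise.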
+        def bad_fake1(x, rng_state) -> torch.Tensor:
+            self.assertTrue(isinstance(rng_state, FakeScriptObject))
+            out = x.clone()
+            for i in range(out.numel()):
+                out.view(-1)[i] += rng_state.rng.random()  # bad: accessing attributes
+            return out
+
+        torch.library.register_fake(
+            "_TestOpaqueObject::bad_fake", bad_fake1, lib=self.lib, allow_override=True
+        )
+
+        with self.assertRaisesRegex(
+            AttributeError,
+            "Tried to call __getattr__ with attr",
+        ):
+            make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3))
+
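+        # Mutating attributes on the FakeScriptObject is disallowed as well.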
+        def bad_fake2(x, rng_state) -> torch.Tensor:
+            rng_state.rng = "foo"
+            return torch.empty_like(x)
+
+        torch.library.register_fake(
+            "_TestOpaqueObject::bad_fake", bad_fake2, lib=self.lib, allow_override=True
+        )
+
+        with self.assertRaisesRegex(
+            AttributeError,
+            "Tried to call __setattr__ with attr",
+        ):
+            make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3))
+
+
+instantiate_parametrized_tests(TestOpaqueObject)
+
 
 if __name__ == "__main__":
     run_tests()