@@ -82,7 +82,33 @@ def is_differentiable_arg(arg):
 def normalize_op_input_output2(f, args, kwargs, output_process_fn_grad=None, requires_grad=True):
     flat_args, args_spec = tree_flatten(args)
     diff_argnums = tuple(i for i, arg in enumerate(flat_args) if diff_arg(arg, requires_grad=requires_grad))
+    assert len(diff_argnums) > 0
+    primals = tuple(flat_args[i] for i in diff_argnums)
+
+    @functools.wraps(f)
+    def wrapped(*primals):
+        _args = list(flat_args)
+        for num, arg in zip(diff_argnums, primals):
+            _args[num] = arg
+        _args = tree_unflatten(_args, args_spec)
+        result = f(*_args, **kwargs)
+        if output_process_fn_grad is not None:
+            result = output_process_fn_grad(result)
+        if isinstance(result, tuple):
+            # TODO: Remove the following hack for namedtuples
+            result = tuple(result)
+            result = tuple(r for r in result if torch.is_floating_point(r))
+            assert len(result) > 0
+        return result
+    return wrapped, primals
+

+# TODO: consolidate with normalize_op_input_output2
+def normalize_op_input_output3(f, args, kwargs, sample_args, output_process_fn_grad=None):
+    flat_args, args_spec = tree_flatten(args)
+    flat_sample_args, _ = tree_flatten(sample_args)
+    diff_argnums = tuple(i for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args))
+                         if diff_arg(sample, requires_grad=True))
     assert len(diff_argnums) > 0
     primals = tuple(flat_args[i] for i in diff_argnums)

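A rough usage sketch of normalize_op_input_output2 (illustration only, not part of the patch), assuming diff_arg selects only floating-point tensors that require grad and that functorch's vjp is in scope as elsewhere in this file:

    # Hypothetical example: expose only the differentiable tensor input of
    # torch.add as a positional primal; the scalar 2.0 stays baked in.
    x = torch.randn(3, requires_grad=True)
    fn, primals = normalize_op_input_output2(torch.add, (x, 2.0), {})  # primals == (x,)
    out, vjp_fn = vjp(fn, *primals)          # functorch vjp over the tensor arg only
    grad_x, = vjp_fn(torch.ones_like(out))   # d(x + 2)/dx == ones
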
@@ -128,10 +154,43 @@ def ref_jvp(f, primals, tangents):
     primals_out, tangents_out = zip(*(fwAD.unpack_dual(d) for d in result_duals))
     return tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec)

-# Returns a new function g(*args, *cotangents) that computes vjps and
-# sample (*args, *cotangents)

+def get_sample_cotangents(f, sample):
+    fn, primals = normalize_op_input_output(f, sample)
+    output = fn(*primals)
+    if isinstance(output, tuple):
+        # TODO: Remove the following hack for torch.return_types
+        output = tuple(output)
+    return tree_map(torch.randn_like, output)
+
+
+# returns a new function g(*args, *cotangents)
+# that computes vjps and (*args, cotangents)
+def get_vjp_fn_and_args_with_cotangents(f, sample, cotangents):
+    args = tuple([sample.input] + list(sample.args))
+    kwargs = sample.kwargs
+    flat_args, args_spec = tree_flatten(args)
+    flat_cotangents, cotangents_spec = tree_flatten(cotangents)
+
+    @functools.wraps(f)
+    def wrapped(*args):
+        assert len(args) == len(flat_args) + len(flat_cotangents)
+        actual_args = args[:len(flat_args)]
+        cotangents = args[len(flat_args):]
+        actual_args = tree_unflatten(actual_args, args_spec)
+        cotangents = tree_unflatten(cotangents, cotangents_spec)
+
+        fn, primals = normalize_op_input_output3(f, actual_args, kwargs,
+                                                 flat_args,
+                                                 sample.output_process_fn_grad)
+        _, vjp_fn = vjp(fn, *primals)
+        return vjp_fn(cotangents)
+
+    return wrapped, tuple(flat_args + flat_cotangents)

+
+# Returns a new function g(*args, *cotangents) that computes vjps and
+# sample (*args, *cotangents)
 def get_vjpfull_variant(f, sample):
     fn, primals = normalize_op_input_output(f, sample)
     result = fn(*primals)
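The intent of the two new helpers above, sketched outside the diff (op and sample stand for an OpInfo entry and one of its SampleInputs, as in the tests further down): get_sample_cotangents draws fixed random cotangents shaped like the op's output, and get_vjp_fn_and_args_with_cotangents closes over them so the returned function takes one flat tuple of positional arguments that the vmap tests can batch over.

    # Hypothetical example, mirroring the updated test loops below.
    cotangents = get_sample_cotangents(op, sample)   # randn_like the op's output
    fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
    # args is the flattened sample inputs followed by the flattened cotangents;
    # fn(*args) runs the vjp of op at the sample inputs and applies it to the
    # cotangents, returning the gradients.
    grads = fn(*args)
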
@@ -438,6 +497,10 @@ def vjp_of_vjp(*args_and_cotangents):
                     get_fallback_and_vmap_exhaustive(vjp_of_vjp, args_and_cotangents, {}):
                 self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)
     vmapvjp_fail = vjp_fail.union({
+        # The following are not bugs and are expected behavior
+        xfail('fill_'),  # Not possible, wontfix
+        xfail('masked_select'),  # Not possible due to dynamic shapes
+
         # All of the following are bugs and need to be fixed
         xfail('diag_embed'),
         xfail('eig'),
@@ -447,7 +510,6 @@ def vjp_of_vjp(*args_and_cotangents):
         xfail('fft.rfft'),
         xfail('fft.rfft'),
         xfail('fft.rfftn'),
-        xfail('cdist'),
         xfail('fmax'),
         xfail('fmin'),
         xfail('index_copy'),
@@ -487,6 +549,9 @@ def vjp_of_vjp(*args_and_cotangents):
         xfail('nn.functional.fractional_max_pool3d'),
         xfail('as_strided'),
         xfail('nn.functional.fractional_max_pool2d'),
+        xfail('__getitem__'),
+        xfail('index_put'),
+        xfail('lu_solve'),
     })

     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@@ -504,7 +569,8 @@ def test_vmapvjp(self, device, dtype, op):
             return

         for sample in samples:
-            fn, args = get_vjpfull_variant(op, sample)
+            cotangents = get_sample_cotangents(op, sample)
+            fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
             for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}):
                 self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)

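What the updated loop checks, as a standalone sketch (hypothetical, not part of the patch): vmapping the vjp function over a batched copy of one input should agree with running it example by example. get_fallback_and_vmap_exhaustive does this exhaustively over in_dims combinations; the sketch below batches only the first flattened argument (sample.input, a tensor) and assumes the remaining arguments are accepted with in_dims=None.

    B = 3
    batched_x = torch.randn(B, *args[0].shape)
    # expected: run the vjp once per example
    expected = [fn(batched_x[i], *args[1:]) for i in range(B)]
    # actual: vmap the vjp over the leading dim of the first argument only
    batched = vmap(fn, in_dims=(0,) + (None,) * (len(args) - 1))(batched_x, *args[1:])
    for i in range(B):
        for e, b in zip(expected[i], batched):
            assert torch.allclose(e, b[i], atol=1e-4, rtol=1e-4)
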
@@ -796,13 +862,17 @@ def test_vmapvjp_has_batch_rule(self, device, dtype, op):

         def test():
             for sample in samples:
-                fn, args = get_vjpfull_variant(op, sample)
-                for _ in get_fallback_and_vmap_exhaustive(fn, args, {}, compute_loop_out=False):
+                cotangents = get_sample_cotangents(op, sample)
+                fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
+                for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
+                        fn, args, {}, compute_loop_out=False):
                     pass
                 for a_op in op.aliases:
-                    fn, args = get_vjpfull_variant(a_op, sample)
-                    for _ in get_fallback_and_vmap_exhaustive(fn, args, {}, compute_loop_out=False):
+                    fn, args = get_vjp_fn_and_args_with_cotangents(a_op, sample, cotangents)
+                    for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
+                            fn, args, {}, compute_loop_out=False):
                         pass
+
         check_vmap_fallback(self, test, op, dry_run=False)

     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@@ -823,7 +893,6 @@ def test():
         xfail('linalg.multi_dot'),
         xfail('vstack'),
         xfail('nn.functional.batch_norm'),
-        xfail('cdist'),
         xfail('lu_solve'),
         xfail('lu_unpack'),
         xfail('matrix_exp'),