Skip to content

Commit b107bab

Browse files
committed
Update op db and fix failing tests
1 parent 1905a2c commit b107bab

File tree

4 files changed

+133
-15
lines changed

4 files changed

+133
-15
lines changed

test/functorch_lagging_op_db.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# To achieve this, we keep our OpInfo library behind that of Pytorch's and
1414
# we periodically update our OpInfo library by regenerating this file
1515
_functorch_lagging_meta = {
16+
('H', ''),
17+
('T', ''),
1618
('__getitem__', ''),
1719
('__radd__', ''),
1820
('__rand__', ''),
@@ -24,6 +26,10 @@
2426
('__rpow__', ''),
2527
('__rsub__', ''),
2628
('__rxor__', ''),
29+
('_masked.amax', ''),
30+
('_masked.amin', ''),
31+
('_masked.prod', ''),
32+
('_masked.sum', ''),
2733
('abs', ''),
2834
('acos', ''),
2935
('acosh', ''),
@@ -50,19 +56,26 @@
5056
('atan2', ''),
5157
('atanh', ''),
5258
('baddbmm', ''),
59+
('bfloat16', ''),
60+
('bincount', ''),
5361
('bitwise_and', ''),
5462
('bitwise_left_shift', ''),
5563
('bitwise_not', ''),
5664
('bitwise_right_shift', ''),
5765
('block_diag', ''),
5866
('bmm', ''),
67+
('bool', ''),
5968
('broadcast_tensors', ''),
6069
('broadcast_to', ''),
70+
('bucketize', ''),
71+
('byte', ''),
6172
('cat', ''),
6273
('cdist', ''),
6374
('ceil', ''),
75+
('char', ''),
6476
('cholesky', ''),
6577
('cholesky_inverse', ''),
78+
('cholesky_solve', ''),
6679
('chunk', ''),
6780
('clamp', ''),
6881
('clamp', 'scalar'),
@@ -94,10 +107,12 @@
94107
('div', 'no_rounding_mode'),
95108
('div', 'trunc_rounding'),
96109
('dot', ''),
110+
('double', ''),
97111
('dsplit', ''),
98112
('dstack', ''),
99113
('eig', ''),
100114
('einsum', ''),
115+
('empty_like', ''),
101116
('eq', ''),
102117
('erf', ''),
103118
('erfc', ''),
@@ -108,19 +123,28 @@
108123
('expand_as', ''),
109124
('expm1', ''),
110125
('fft.fft', ''),
126+
('fft.fft2', ''),
111127
('fft.fftn', ''),
112128
('fft.hfft', ''),
129+
('fft.hfft2', ''),
130+
('fft.hfftn', ''),
113131
('fft.ifft', ''),
132+
('fft.ifft2', ''),
114133
('fft.ifftn', ''),
115134
('fft.ihfft', ''),
135+
('fft.ihfft2', ''),
136+
('fft.ihfftn', ''),
116137
('fft.irfft', ''),
138+
('fft.irfft2', ''),
117139
('fft.irfftn', ''),
118140
('fft.rfft', ''),
141+
('fft.rfft2', ''),
119142
('fft.rfftn', ''),
120143
('fill_', ''),
121144
('flip', ''),
122145
('fliplr', ''),
123146
('flipud', ''),
147+
('float', ''),
124148
('float_power', ''),
125149
('floor', ''),
126150
('floor_divide', ''),
@@ -130,12 +154,15 @@
130154
('fmod', 'autodiffed'),
131155
('frac', ''),
132156
('frexp', ''),
157+
('full_like', ''),
133158
('gather', ''),
134159
('ge', ''),
135160
('geqrf', ''),
136161
('gradient', ''),
137162
('gt', ''),
163+
('half', ''),
138164
('histogram', ''),
165+
('histogramdd', ''),
139166
('hsplit', ''),
140167
('hstack', ''),
141168
('hypot', ''),
@@ -151,8 +178,15 @@
151178
('index_put', ''),
152179
('index_select', ''),
153180
('inner', ''),
181+
('int', ''),
154182
('inverse', ''),
183+
('isfinite', ''),
155184
('isin', ''),
185+
('isinf', ''),
186+
('isnan', ''),
187+
('isneginf', ''),
188+
('isposinf', ''),
189+
('isreal', ''),
156190
('kron', ''),
157191
('kthvalue', ''),
158192
('le', ''),
@@ -171,6 +205,7 @@
171205
('linalg.inv', ''),
172206
('linalg.inv_ex', ''),
173207
('linalg.lstsq', ''),
208+
('linalg.lstsq', 'grad_oriented'),
174209
('linalg.matrix_norm', ''),
175210
('linalg.matrix_power', ''),
176211
('linalg.matrix_rank', ''),
@@ -179,6 +214,7 @@
179214
('linalg.norm', ''),
180215
('linalg.pinv', ''),
181216
('linalg.pinv', 'hermitian'),
217+
('linalg.pinv', 'singular'),
182218
('linalg.qr', ''),
183219
('linalg.slogdet', ''),
184220
('linalg.solve', ''),
@@ -199,10 +235,13 @@
199235
('logical_not', ''),
200236
('logit', ''),
201237
('logsumexp', ''),
238+
('long', ''),
202239
('lt', ''),
203240
('lu', ''),
204241
('lu_solve', ''),
205242
('lu_unpack', ''),
243+
('mH', ''),
244+
('mT', ''),
206245
('masked_fill', ''),
207246
('masked_scatter', ''),
208247
('masked_select', ''),
@@ -238,14 +277,19 @@
238277
('ne', ''),
239278
('neg', ''),
240279
('nextafter', ''),
280+
('nn.functional.adaptive_avg_pool1d', ''),
241281
('nn.functional.adaptive_avg_pool2d', ''),
282+
('nn.functional.adaptive_avg_pool3d', ''),
283+
('nn.functional.avg_pool1d', ''),
242284
('nn.functional.avg_pool2d', ''),
285+
('nn.functional.avg_pool3d', ''),
243286
('nn.functional.batch_norm', ''),
244287
('nn.functional.batch_norm', 'without_cudnn'),
245288
('nn.functional.conv2d', ''),
246289
('nn.functional.conv_transpose2d', ''),
247290
('nn.functional.cosine_similarity', ''),
248291
('nn.functional.dropout', ''),
292+
('nn.functional.embedding', ''),
249293
('nn.functional.gelu', ''),
250294
('nn.functional.grid_sample', ''),
251295
('nn.functional.hardshrink', ''),
@@ -270,6 +314,9 @@
270314
('nn.functional.pad', 'constant'),
271315
('nn.functional.pad', 'reflect'),
272316
('nn.functional.pad', 'replicate'),
317+
('nn.functional.pairwise_distance', ''),
318+
('nn.functional.pixel_shuffle', ''),
319+
('nn.functional.pixel_unshuffle', ''),
273320
('nn.functional.relu', ''),
274321
('nn.functional.relu6', ''),
275322
('nn.functional.softplus', ''),
@@ -278,6 +325,7 @@
278325
('norm', 'fro'),
279326
('norm', 'inf'),
280327
('norm', 'nuc'),
328+
('ones_like', ''),
281329
('ormqr', ''),
282330
('outer', ''),
283331
('permute', ''),
@@ -295,6 +343,7 @@
295343
('qr', ''),
296344
('quantile', ''),
297345
('rad2deg', ''),
346+
('randn_like', ''),
298347
('ravel', ''),
299348
('real', ''),
300349
('reciprocal', ''),
@@ -319,6 +368,7 @@
319368
('scatter_add', ''),
320369
('select', ''),
321370
('sgn', ''),
371+
('short', ''),
322372
('sigmoid', ''),
323373
('sign', ''),
324374
('signbit', ''),
@@ -386,6 +436,7 @@
386436
('where', ''),
387437
('xlogy', ''),
388438
('zero_', ''),
439+
('zeros_like', ''),
389440
}
390441

391442

test/test_ops.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,7 @@ def vjp_of_vjp(*args_and_cotangents):
304304
for loop_out, batched_out in \
305305
get_fallback_and_vmap_exhaustive(vjp_of_vjp, args_and_cotangents, {}):
306306
self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)
307-
308-
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
309-
@skipOps('TestOperators', 'test_vmapvjp', vjp_fail.union({
307+
vmapvjp_fail = vjp_fail.union({
310308
# All of the following are bugs and need to be fixed
311309
xfail('clamp', ''),
312310
xfail('diag_embed'),
@@ -360,7 +358,16 @@ def vjp_of_vjp(*args_and_cotangents):
360358
xfail('nanmean'),
361359
xfail('block_diag'),
362360
xfail('nn.functional.dropout'),
363-
}))
361+
xfail('double'),
362+
xfail('fft.fft2'),
363+
xfail('fft.ifft2'),
364+
xfail('fft.ihfft2'),
365+
xfail('fft.ihfftn'),
366+
xfail('fft.rfft2'),
367+
xfail('_masked.prod'), # calls aten::item
368+
})
369+
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
370+
@skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail)
364371
def test_vmapvjp(self, device, dtype, op):
365372
# These are too annoying to put into the list above
366373
if op.name in {'nn.functional.linear', 'nn.functional.conv2d'}:
@@ -382,7 +389,7 @@ def test_vmapvjp(self, device, dtype, op):
382389
self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)
383390

384391
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
385-
@skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', {
392+
@skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
386393
xfail('nn.functional.pad', 'constant'),
387394
xfail('view_as_complex'),
388395
xfail('__getitem__'),
@@ -493,7 +500,22 @@ def test_vmapvjp(self, device, dtype, op):
493500
xfail('block_diag'),
494501
xfail('nn.functional.dropout'),
495502
xfail('nn.functional.batch_norm'),
496-
})
503+
xfail('_masked.amax'),
504+
xfail('_masked.amin'),
505+
xfail('_masked.sum'),
506+
xfail('_masked.prod'),
507+
xfail('cholesky_solve'),
508+
xfail('double'),
509+
xfail('fft.fft2'),
510+
xfail('fft.ifft2'),
511+
xfail('fft.ihfft2'),
512+
xfail('fft.ihfftn'),
513+
xfail('fft.rfft2'),
514+
xfail('nn.functional.adaptive_avg_pool1d'),
515+
xfail('nn.functional.adaptive_avg_pool3d'),
516+
xfail('nn.functional.avg_pool3d'),
517+
xfail('nn.functional.embedding'),
518+
}))
497519
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
498520
# These are too annoying to put into the list above
499521
if op.name in {'nn.functional.linear', 'nn.functional.conv2d'}:
@@ -522,6 +544,8 @@ def test():
522544

523545
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
524546
@skipOps('TestOperators', 'test_vjpvmap', vjp_fail.union({
547+
# fallback path doesn't work
548+
xfail('H'),
525549
# All of the following are bugs and need to be fixed
526550
xfail('__getitem__'),
527551
xfail('clamp', ''),
@@ -543,6 +567,7 @@ def test():
543567
xfail('lu_unpack'),
544568
xfail('matrix_exp'),
545569
xfail('view_as_complex'),
570+
xfail('double'),
546571
}))
547572
def test_vjpvmap(self, device, dtype, op):
548573
# NB: there is no vjpvmap_has_batch_rule test because that is almost

test/test_pythonkey.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ class TestPythonKeyOperatorsOpInfo(TestCase):
197197
xfail('nn.functional.dropout'),
198198
xfail('linalg.eigvals'),
199199
xfail('nn.functional.pad', 'circular'),
200+
xfail('empty_like'), # randomness
201+
xfail('randn_like'), # randomness
200202
})
201203
def test_make_fx_exhaustive(self, device, dtype, op):
202204

0 commit comments

Comments
 (0)