
Commit eace82a

Remove removable randomness skips (#953)
The jvp and vjp transforms should not change randomness behavior; e.g. dropout under vjp and dropout under regular PyTorch autograd should produce the same values. vmap, however, does change randomness behavior. This PR removes a number of randomness skips from jvp-only and vjp-only tests, and also fixes our implementation of dropout so that it maintains the above property.

Test Plan:
- run tests
1 parent f4a3d5a commit eace82a
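
As a rough illustration of the property described above (a minimal sketch, not part of the commit; it assumes the top-level functorch vjp API and torch.nn.functional.dropout), dropout computed under vjp should consume the RNG exactly like the plain eager call:

import torch
import torch.nn.functional as F
from functorch import vjp

x = torch.randn(3)

def f(t):
    # training=True so dropout actually samples a random mask
    return F.dropout(t, p=0.5, training=True)

torch.manual_seed(0)
out_eager = f(x)                  # plain eager/autograd path

torch.manual_seed(0)
out_vjp, vjp_fn = vjp(f, x)       # same function under the vjp transform

# The invariant this commit enforces: vjp does not perturb the RNG stream,
# so both calls should see the identical dropout mask.
assert torch.allclose(out_eager, out_vjp)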

File tree
functorch/csrc/PyTorchOperatorHacks.cpp
test/test_ops.py

2 files changed: +13 -42 lines changed

functorch/csrc/PyTorchOperatorHacks.cpp

Lines changed: 5 additions & 6 deletions
@@ -295,13 +295,12 @@ Ctype<inplace> _dropout_impl(T& input, double p, bool train) {
   // NB: THIS WAS CHANGED FROM THE ORIGINAL
   Tensor noise;
   if (feature_dropout) {
-    auto prob = make_feature_noise(input);
-    prob.fill_(1 - p);
-    noise = at::bernoulli(prob);
+    auto empty = make_feature_noise(input);
+    noise = at::bernoulli(empty, 1 - p);
   } else {
-    // NB: it is important that this is at::full and not at::full_like
-    auto prob = at::full({}, 1 - p, input.options()).expand(input.sizes());
-    noise = at::bernoulli(prob);
+    // NB: it is important that this is at::empty and not at::empty_like
+    auto empty = at::empty({}, input.options()).expand(input.sizes());
+    noise = at::bernoulli(empty, 1 - p);
   }
 
   if (alpha_dropout) {
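
The switch from at::bernoulli(prob) to at::bernoulli(empty, 1 - p) matters because the tensor-of-probabilities and scalar-probability overloads of bernoulli can draw from the random number generator differently, and only the latter lines up with how PyTorch core's dropout samples its mask. Below is a rough Python analogue (a sketch only; the in-place bernoulli_ call with a scalar p stands in for the C++ overload used in the diff):

import torch

p = 0.5

torch.manual_seed(0)
# Old path: bernoulli over a tensor filled with the keep-probability.
mask_old = torch.bernoulli(torch.full((4,), 1 - p))

torch.manual_seed(0)
# New path (approximated): scalar-probability sampling, standing in for
# at::bernoulli(empty, 1 - p) in the C++ diff above.
mask_new = torch.empty(4).bernoulli_(1 - p)

# The two overloads may consume the generator differently, so the masks
# are not guaranteed to match even with an identical seed.
print(mask_old)
print(mask_new)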

test/test_ops.py

Lines changed: 8 additions & 36 deletions
@@ -299,27 +299,17 @@ def is_inplace(op, variant):
 
 
 vjp_fail = {
-    skip('nn.functional.dropout'),  # randomness testing artifact
-    skip('nn.functional.rrelu'),  # randomness testing artifact
-    skip('bernoulli'),  # randomness testing artifact
-    skip('normal', ''),  # randomness testing artifact
-    skip('normal', 'number_mean'),  # randomness testing artifact
     xfail('tensor_split'),
     xfail('to_sparse'),
     xfail('nn.functional.ctc_loss'),
-    skip('nn.functional.feature_alpha_dropout', 'with_train'),  # fails on cuda, runs okay on cpu
-    skip('nn.functional.feature_alpha_dropout', 'without_train'),  # fails on cuda, runs okay on cpu
     skip('pca_lowrank', ''),  # fails on cuda, runs okay on cpu
     skip('svd_lowrank', ''),  # fails on cuda, runs okay on cpu
-    skip('nn.functional.dropout2d', ''),  # fails on cuda, runs okay on cpu
 }
 
 
 class TestOperators(TestCase):
     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestOperators', 'test_grad', vjp_fail.union({
-        skip('nn.functional.fractional_max_pool2d'),  # fails on cuda, runs okay on cpu
-        skip('nn.functional.fractional_max_pool3d'),  # fails on cuda, runs okay on cpu
         xfail('linalg.eig'),  # diagonal_scatter does not support complex
     }))
     @opsToleranceOverride('TestOperators', 'test_grad', (
@@ -368,16 +358,9 @@ def wrapped_fn(*args, **kwargs):
 
     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestOperators', 'test_jvp', set({
-        skip('nn.functional.dropout'),  # randomness testing artifact; not actually a problem
-        skip('nn.functional.rrelu'),  # randomness testing artifact; not actually a problem
-        skip('nn.functional.fractional_max_pool2d'),  # fails on cuda, runs okay on cpu
-        skip('nn.functional.fractional_max_pool3d'),  # fails on cuda, runs okay on cpu
         skip('nn.functional.max_pool1d'),  # fails on cpu, runs okay on cuda
-        skip('nn.functional.feature_alpha_dropout', 'with_train'),  # fails on cuda, runs okay on cpu
-        skip('nn.functional.feature_alpha_dropout', 'without_train'),  # fails on cuda, runs okay on cpu
         skip('pca_lowrank', ''),  # fails on cuda, runs okay on cpu
         skip('svd_lowrank', ''),  # fails on cuda, runs okay on cpu
-        skip('nn.functional.dropout2d', ''),  # fails on cuda, runs okay on cpu
 
         # =============================================
         # NB: The above failures also fail using PyTorch core's
@@ -389,8 +372,6 @@ def wrapped_fn(*args, **kwargs):
         # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
         xfail('tensor_split'),
 
-        skip('bernoulli'),  # cuda set seed randomness issues
-
         # BUG: runs and produces numerical differences
         skip('nn.functional.max_unpool1d'),  # fails everywhere except on mac
         skip('nn.functional.max_unpool2d'),  # fails everywhere except on windows
@@ -435,12 +416,7 @@ def test_jvp(self, device, dtype, op):
 
     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestOperators', 'test_vjp', vjp_fail.union({
-        skip('nn.functional.fractional_max_pool2d'),  # fails on cpu, runs okay on cuda
-        skip('nn.functional.fractional_max_pool3d'),  # fails on cpu, runs okay on cuda
-        xfail('nn.functional.feature_alpha_dropout', 'with_train'),
         xfail('pca_lowrank', ''),
-        xfail('nn.functional.dropout2d', ''),
-        xfail('nn.functional.feature_alpha_dropout', 'without_train'),
         xfail('svd_lowrank', ''),
     }))
     @opsToleranceOverride('TestOperators', 'test_vjp', (
@@ -484,8 +460,6 @@ def _test(_op):
     @skipOps('TestOperators', 'test_vjpvjp', vjp_fail.union({
         skip('nn.functional.max_unpool1d'),  # Flaky
         skip('nn.functional.max_unpool2d'),  # Flaky
-        skip('nn.functional.fractional_max_pool2d'),  # randomness
-        skip('nn.functional.fractional_max_pool3d'),  # randomness
     }))
     @opsToleranceOverride('TestOperators', 'test_vjpvjp', (
         tol1('nn.functional.conv_transpose3d',
@@ -576,7 +550,11 @@ def vjp_of_vjp(*args_and_cotangents):
         skip('bernoulli'),  # randomness
         skip('normal', ''),  # randomness
         skip('normal', 'number_mean'),  # randomness
-        xfail('nn.functional.dropout'),  # randomness
+        skip('nn.functional.rrelu'),  # randomness
+        skip('nn.functional.feature_alpha_dropout', 'with_train'),  # randomness
+        skip('nn.functional.feature_alpha_dropout', 'without_train'),  # randomness
+        skip('nn.functional.dropout'),  # randomness
+        skip('nn.functional.dropout2d'),  # randomness
         xfail('as_strided'),  # as_strided is too wild for us to support, wontfix
         xfail('index_put', ''),  # not possible due to dynamic shapes; we support a subset
         xfail('masked_scatter'),  # dynamic
@@ -934,6 +912,9 @@ def test():
         skip('bernoulli', ''),  # vjpvmap testing can't handle randomness
         skip('normal', ''),  # vjpvmap testing can't handle randomness
         skip('normal', 'number_mean'),  # vjpvmap testing can't handle randomness
+        skip('nn.functional.rrelu'),  # randomness
+        skip('nn.functional.feature_alpha_dropout', 'with_train'),  # randomness
+        skip('nn.functional.feature_alpha_dropout', 'without_train'),  # randomness
 
         # fallback path doesn't work
         # All of the following are bugs and need to be fixed
@@ -951,8 +932,6 @@ def test():
         xfail('nn.functional.dropout2d', ''),
         xfail('svd_lowrank', ''),
         xfail('pca_lowrank', ''),
-        xfail('nn.functional.feature_alpha_dropout', 'without_train'),
-        xfail('nn.functional.feature_alpha_dropout', 'with_train'),
         xfail('clamp'),
         # something weird happening with channels_last
         xfail('bfloat16'),
@@ -1025,10 +1004,6 @@ def get_vjp(cotangents, *primals):
 
     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestOperators', 'test_jvpvjp', vjp_fail.union({
-        # These are weirdly non-deterministic
-        skip('nn.functional.fractional_max_pool2d'),  # Random
-        skip('nn.functional.fractional_max_pool3d'),  # Random
-
         # RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
         # this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
         xfail('normal', ''),
@@ -1049,11 +1024,8 @@ def get_vjp(cotangents, *primals):
         xfail('nn.functional.softmin', 'with_dtype'),
         xfail('renorm', ''),
         xfail('symeig', ''),
-        xfail('nn.functional.feature_alpha_dropout', 'with_train'),
         skip('nn.functional.kl_div', ''),  # will pass when linux cpu binaries update
         xfail('pca_lowrank', ''),
-        xfail('nn.functional.dropout2d', ''),
-        xfail('nn.functional.feature_alpha_dropout', 'without_train'),
         xfail('svd_lowrank', ''),
         xfail('nn.functional.multilabel_margin_loss', ''),
         xfail('nn.functional.multilabel_soft_margin_loss', ''),
