Commit 20e9216

Samantha Andow authored
Add extremal testing for forward over reverse using decompositions (#818)
* add extremal testing
* make nll_loss and cross_entropy only test sane things
* make argnum implicit
* fix nits
* fix devices
1 parent 46bbdce commit 20e9216
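
The new tests exercise "forward over reverse": the Jacobian of an operator's VJP is computed once with forward-mode AD (jacfwd) and once with reverse-mode AD (jacrev), and the two results must agree. Below is a minimal standalone sketch of that check, assuming functorch's vjp/jacfwd/jacrev API; torch.sin, the variable names, and the assert_close call are illustrative stand-ins and are not part of this commit:

import torch
from functorch import vjp, jacfwd, jacrev

def f(x):
    return torch.sin(x)

primal = torch.randn(3)
cotangents = torch.randn(3)

def get_vjp(cotangents, primal):
    # vjp returns (output, vjp_fn); applying vjp_fn to the cotangents
    # gives the reverse-mode gradients with respect to the primal.
    _, vjp_fn = vjp(f, primal)
    return vjp_fn(cotangents)

# Differentiate the VJP with forward-mode (forward over reverse) and with
# reverse-mode (reverse over reverse); correct formulas make these agree.
jac_fwd = jacfwd(get_vjp, argnums=(0, 1))(cotangents, primal)
jac_rev = jacrev(get_vjp, argnums=(0, 1))(cotangents, primal)
torch.testing.assert_close(jac_fwd, jac_rev)

The committed helper _compare_jacobians_of_vjp in the diff below follows this pattern, but runs it inside the test class and, in the new extremal tests, over large-magnitude inputs.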

File tree

1 file changed: +183 -22 lines changed


test/test_ops.py

Lines changed: 183 additions & 22 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import itertools
+
 from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors
 import torch
 from torch import Tensor
@@ -1102,6 +1104,27 @@ def test_vjpvmap(self, device, dtype, op):

         self.assertEqual(result_vjps, expected_vjps)

+    def _compare_jacobians_of_vjp(self, fn, cotangents_and_primals, argnums=None, atol_rtol=None):
+        if argnums is None:
+            argnums = tuple(range(len(cotangents_and_primals)))
+
+        def get_vjp(cotangents, *primals):
+            _, vjp_fn = vjp(fn, *primals)
+            return vjp_fn(cotangents)
+
+        jacobian_jvp = jacfwd(get_vjp, argnums)(*cotangents_and_primals)
+        jacobian_vjp = jacrev(get_vjp, argnums)(*cotangents_and_primals)
+
+        # For dtype changing operations, the jacobians have different dtype.
+        jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp)
+        jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp)
+
+        if atol_rtol is not None:
+            (atol, rtol) = atol_rtol
+            self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol)
+        else:
+            self.assertEqual(jacobian_jvp, jacobian_vjp)
+
     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestOperators', 'test_jvpvjp', vjp_fail.union({
         # These are weirdly non-deterministic
@@ -1223,24 +1246,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
             expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
             return expected

-        def compare_jacobians(cotangents_and_primals, in_dims, atol_rtol):
-            def get_vjp(cotangents, *primals):
-                _, vjp_fn = vjp(fn, *primals)
-                return vjp_fn(cotangents)
-
-            jacobian_jvp = jacfwd(get_vjp, in_dims)(*cotangents_and_primals)
-            jacobian_vjp = jacrev(get_vjp, in_dims)(*cotangents_and_primals)
-
-            # For dtype changing operations, the jacobians have different dtype.
-            jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp)
-            jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp)
-
-            if atol_rtol is not None:
-                (atol, rtol) = atol_rtol
-                self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol)
-            else:
-                self.assertEqual(jacobian_jvp, jacobian_vjp)
-
         # HACK: obviously pytorch should also have the same coverage
         # For things that do have the same coverage, we test that jvp x vjp
         # are the same between PyTorch and functorch. For things that don't,
@@ -1261,16 +1266,172 @@ def is_differentiable(t):
                 return isinstance(t, torch.Tensor) and t.dtype == torch.float32
             args = (cotangents, *primals)
             if op.name == 'nn.functional.binary_cross_entropy':
-                in_dims = (0, 1)  # targets is float32 but isn't differentiable
-                atol_rtol = 1.5E-4, 1.3e-06
+                argnums = (0, 1)  # targets is float32 but isn't differentiable
+                atol_rtol = 1.5e-4, 1.3e-06
             else:
-                in_dims = tuple(i for i in range(len(args)) if is_differentiable(args[i]))
+                argnums = tuple(i for i in range(len(args)) if is_differentiable(args[i]))
                 atol_rtol = None
-            compare_jacobians(args, in_dims, atol_rtol)
+            self._compare_jacobians_of_vjp(fn, args, argnums, atol_rtol)
         else:
             expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
             self.assertEqual(result, expected)

+    def _make_extremal_inputs(self, shape, device):
+        if shape == None:
+            return (None,)
+        return (
+            torch.full(shape, -1000., device=device),
+            torch.zeros(shape, device=device),
+            torch.full(shape, 1000., device=device),
+        )
+
+    def _arg_and_kwarg_options(self, args_options, kwargs_options):
+        return itertools.product(*args_options, kwargs_options)
+
+    def test_extremal_numerics_nll_loss(self, device):
+        N, C = 3, 4
+        d1, d2, d3 = 5, 6, 7
+        shapes = (
+            ((N, C), (N,), (C,)),
+            ((N, C), (N,), None),
+            ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
+            ((N, C, d1, d2, d3), (N, d1, d2, d3), None),
+        )
+        kwargs_options = ({'ignore_index': 0, 'reduction': 'mean'}, {'reduction': 'sum'}, {'reduction': 'none'}, {})
+        for input_shape, target_shape, weight_shape in shapes:
+            input_options = self._make_extremal_inputs(input_shape, device)
+            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
+                if weight_shape is None:
+                    weight = None
+                else:
+                    weight = torch.randn(weight_shape, device=device)
+                target = torch.randint(0, C, target_shape, device=device)
+                target[0] = 1  # since we're ignoring index 0, at least one element must be non-zero
+
+                fn = functools.partial(torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs)
+                result = fn(input)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input))
+
+    def test_extremal_numerics_l1_loss(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        kwargs_options = ({'reduction': 'sum'}, {'reduction': 'none'}, {})
+        for shape in shapes:
+            input_options = self._make_extremal_inputs(shape, device)
+            target_options = self._make_extremal_inputs(shape, device)
+            for input, target, kwargs in self._arg_and_kwarg_options((input_options, target_options), kwargs_options):
+                result = torch.nn.functional.l1_loss(input, target)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(torch.nn.functional.l1_loss, (cotangents, input, target))
+
+    def test_extremal_numerics_mse_loss(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        kwargs_options = ({'reduction': 'sum'}, {'reduction': 'none'}, {})
+        for shape in shapes:
+            input_options = self._make_extremal_inputs(shape, device)
+            target_options = self._make_extremal_inputs(shape, device)
+            for input, target, kwargs in self._arg_and_kwarg_options((input_options, target_options), kwargs_options):
+                result = torch.nn.functional.mse_loss(input, target)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(torch.nn.functional.mse_loss, (cotangents, input, target))
+
+    def test_extremal_numerics_softmax(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        kwargs_options = ({'dim': 1}, {})
+        for shape in shapes:
+            input_options = self._make_extremal_inputs(shape, device)
+            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
+                result = torch.nn.functional.softmax(input)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(torch.nn.functional.softmax, (cotangents, input))
+
+
+    def test_extremal_numerics_log_softmax(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        kwargs_options = ({'dim': 1}, {})
+        for shape in shapes:
+            input_options = self._make_extremal_inputs(shape, device)
+            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
+                result = torch.nn.functional.log_softmax(input)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(torch.nn.functional.log_softmax, (cotangents, input))
+
+    def test_extremal_numerics_cross_entropy(self, device):
+        N, C = 3, 4
+        d1, d2, d3 = 5, 6, 7
+        shapes = (
+            ((N, C), (N,), (C,)),
+            ((N, C), (N,), None),
+            ((N, C), (N, C), (C,)),
+            ((N, C), (N, C), None),
+            ((C,), (), (C,)),
+            ((C,), (), None),
+            ((C,), (C,), (C,)),
+            ((C,), (C,), None),
+            ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
+            ((N, C, d1, d2, d3), (N, d1, d2, d3), None),
+            ((N, C, d1, d2, d3), (N, C, d1, d2, d3), (C,)),
+            ((N, C, d1, d2, d3), (N, C, d1, d2, d3), None),
+        )
+        for input_shape, target_shape, weight_shape in shapes:
+            input_options = self._make_extremal_inputs(input_shape, device)
+            kwargs_options = [{'reduction': 'sum'}, {'reduction': 'none'}, {}]
+            if input_shape != target_shape:
+                kwargs_options.append({'ignore_index': 0, 'reduction': 'mean'})
+
+            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
+                if weight_shape is None:
+                    weight = None
+                else:
+                    weight = torch.randn(weight_shape, device=device)
+
+                if input_shape == target_shape:
+                    target = torch.rand(target_shape, device=device)
+                elif len(target_shape) == 0:
+                    target = torch.tensor(1, device=device)  # must be non-zero since ignore_index may be 0
+                else:
+                    target = torch.randint(0, C, target_shape, device=device)
+
+                fn = functools.partial(torch.nn.functional.cross_entropy, target=target, weight=weight, **kwargs)
+                result = fn(input)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input), atol_rtol=(1e-4, 1e-5))
+
+    def test_extremal_numerics_binary_cross_entropy(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        for shape in shapes:
+            weight_options = self._make_extremal_inputs(shape, device)
+            kwargs_options = [{'reduction': 'sum'}, {'reduction': 'none'}, {}]
+
+            for weight, kwargs in self._arg_and_kwarg_options((weight_options,), kwargs_options):
+                input = torch.rand(shape, device=device)
+                target = torch.rand(shape, device=device)
+                fn = functools.partial(torch.nn.functional.binary_cross_entropy, target=target, weight=weight, **kwargs)
+                result = fn(input)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input), atol_rtol=(1e-4, 2e-5))
+
+    def test_extremal_numerics_layer_norm(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        for shape in shapes:
+            input_options = self._make_extremal_inputs(shape, device)
+            normalized_shape = shape[1:]
+            weight_options = self._make_extremal_inputs(normalized_shape, device)
+            bias_options = self._make_extremal_inputs(normalized_shape, device)
+
+            for input, bias, weight in self._arg_and_kwarg_options((input_options, bias_options, weight_options), ()):
+                def fn(input, weight, bias):
+                    return torch.nn.functional.layer_norm(input, normalized_shape, weight=weight, bias=bias)
+                result = fn(input, weight, bias)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias))
+
     @ops(filter(lambda op: op.name == "nn.functional.group_norm", functorch_lagging_op_db + additional_op_db),
          allowed_dtypes=(torch.float32, torch.double))  # TODO: generalize
     def test_group_norm_backward(self, device, dtype, op):
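
The "extremal" inputs the new tests feed in are deliberately simple: tensors filled with -1000, 0, and 1000, expanded against each kwargs variant with itertools.product. The following sketch shows that enumeration pattern in isolation; the helper names mirror _make_extremal_inputs and _arg_and_kwarg_options from the diff, but the snippet is a standalone approximation (CPU device, hard-coded shape, mse_loss as the example op), not the committed test code:

import itertools
import torch

def make_extremal_inputs(shape, device='cpu'):
    # Values large enough to stress saturation/overflow in backward formulas, plus all zeros.
    if shape is None:
        return (None,)
    return (
        torch.full(shape, -1000., device=device),
        torch.zeros(shape, device=device),
        torch.full(shape, 1000., device=device),
    )

def arg_and_kwarg_options(args_options, kwargs_options):
    # Cartesian product: every combination of per-argument options with each kwargs dict.
    return itertools.product(*args_options, kwargs_options)

input_options = make_extremal_inputs((3, 4))
target_options = make_extremal_inputs((3, 4))
kwargs_options = ({'reduction': 'sum'}, {'reduction': 'none'}, {})
for inp, target, kwargs in arg_and_kwarg_options((input_options, target_options), kwargs_options):
    loss = torch.nn.functional.mse_loss(inp, target, **kwargs)

In the committed tests, each generated case is then passed through _compare_jacobians_of_vjp, so the forward-over-reverse decompositions are checked exactly where overflow and saturation are most likely to surface.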
