# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+import itertools
+
from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors
import torch
from torch import Tensor
@@ -1102,6 +1104,27 @@ def test_vjpvmap(self, device, dtype, op):

            self.assertEqual(result_vjps, expected_vjps)

+    def _compare_jacobians_of_vjp(self, fn, cotangents_and_primals, argnums=None, atol_rtol=None):
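+        # Cross-check used by several tests below: the Jacobian of the op's VJP
+        # computed with forward-mode AD (jacfwd) should match the one computed
+        # with reverse-mode AD (jacrev), up to the given tolerances.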
+        if argnums is None:
+            argnums = tuple(range(len(cotangents_and_primals)))
+
+        def get_vjp(cotangents, *primals):
+            _, vjp_fn = vjp(fn, *primals)
+            return vjp_fn(cotangents)
+
+        jacobian_jvp = jacfwd(get_vjp, argnums)(*cotangents_and_primals)
+        jacobian_vjp = jacrev(get_vjp, argnums)(*cotangents_and_primals)
+
+        # For dtype-changing operations, the jacobians have different dtype.
+        jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp)
+        jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp)
+
+        if atol_rtol is not None:
+            (atol, rtol) = atol_rtol
+            self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol)
+        else:
+            self.assertEqual(jacobian_jvp, jacobian_vjp)
+
    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestOperators', 'test_jvpvjp', vjp_fail.union({
        # These are weirdly non-deterministic
@@ -1223,24 +1246,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
                expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
                return expected

-            def compare_jacobians(cotangents_and_primals, in_dims, atol_rtol):
-                def get_vjp(cotangents, *primals):
-                    _, vjp_fn = vjp(fn, *primals)
-                    return vjp_fn(cotangents)
-
-                jacobian_jvp = jacfwd(get_vjp, in_dims)(*cotangents_and_primals)
-                jacobian_vjp = jacrev(get_vjp, in_dims)(*cotangents_and_primals)
-
-                # For dtype-changing operations, the jacobians have different dtype.
-                jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp)
-                jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp)
-
-                if atol_rtol is not None:
-                    (atol, rtol) = atol_rtol
-                    self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol)
-                else:
-                    self.assertEqual(jacobian_jvp, jacobian_vjp)
-
            # HACK: obviously pytorch should also have the same coverage
            # For things that do have the same coverage, we test that jvp x vjp
            # are the same between PyTorch and functorch. For things that don't,
@@ -1261,16 +1266,172 @@ def is_differentiable(t):
                    return isinstance(t, torch.Tensor) and t.dtype == torch.float32
                args = (cotangents, *primals)
                if op.name == 'nn.functional.binary_cross_entropy':
-                    in_dims = (0, 1)  # targets is float32 but isn't differentiable
-                    atol_rtol = 1.5E-4, 1.3e-06
+                    argnums = (0, 1)  # targets is float32 but isn't differentiable
+                    atol_rtol = 1.5e-4, 1.3e-06
                else:
-                    in_dims = tuple(i for i in range(len(args)) if is_differentiable(args[i]))
+                    argnums = tuple(i for i in range(len(args)) if is_differentiable(args[i]))
                    atol_rtol = None
-                compare_jacobians(args, in_dims, atol_rtol)
+                self._compare_jacobians_of_vjp(fn, args, argnums, atol_rtol)
            else:
                expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
                self.assertEqual(result, expected)

+    def _make_extremal_inputs(self, shape, device):
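+        # Inputs filled with large-magnitude values and zeros, used to stress the
+        # numerical stability of the op's backward formulas.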
+        if shape is None:
+            return (None,)
+        return (
+            torch.full(shape, -1000., device=device),
+            torch.zeros(shape, device=device),
+            torch.full(shape, 1000., device=device),
+        )
+
+    def _arg_and_kwarg_options(self, args_options, kwargs_options):
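+        # Cartesian product of the per-argument option tuples with the kwargs options;
+        # each yielded tuple is (*chosen_args, chosen_kwargs).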
+        return itertools.product(*args_options, kwargs_options)
+
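+    # The tests below feed extreme-valued inputs (see _make_extremal_inputs) to ops
+    # with numerically sensitive backward formulas and cross-check jacfwd vs. jacrev
+    # of the VJP via _compare_jacobians_of_vjp.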
+    def test_extremal_numerics_nll_loss(self, device):
+        N, C = 3, 4
+        d1, d2, d3 = 5, 6, 7
+        shapes = (
+            ((N, C), (N,), (C,)),
+            ((N, C), (N,), None),
+            ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
+            ((N, C, d1, d2, d3), (N, d1, d2, d3), None),
+        )
+        kwargs_options = ({'ignore_index': 0, 'reduction': 'mean'}, {'reduction': 'sum'}, {'reduction': 'none'}, {})
+        for input_shape, target_shape, weight_shape in shapes:
+            input_options = self._make_extremal_inputs(input_shape, device)
+            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
+                if weight_shape is None:
+                    weight = None
+                else:
+                    weight = torch.randn(weight_shape, device=device)
+                target = torch.randint(0, C, target_shape, device=device)
+                target[0] = 1  # since we're ignoring index 0, at least one element must be non-zero
+
+                fn = functools.partial(torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs)
+                result = fn(input)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input))
+
+    def test_extremal_numerics_l1_loss(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        kwargs_options = ({'reduction': 'sum'}, {'reduction': 'none'}, {})
+        for shape in shapes:
+            input_options = self._make_extremal_inputs(shape, device)
+            target_options = self._make_extremal_inputs(shape, device)
+            for input, target, kwargs in self._arg_and_kwarg_options((input_options, target_options), kwargs_options):
+                fn = functools.partial(torch.nn.functional.l1_loss, **kwargs)
+                result = fn(input, target)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input, target))
+
+    def test_extremal_numerics_mse_loss(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        kwargs_options = ({'reduction': 'sum'}, {'reduction': 'none'}, {})
+        for shape in shapes:
+            input_options = self._make_extremal_inputs(shape, device)
+            target_options = self._make_extremal_inputs(shape, device)
+            for input, target, kwargs in self._arg_and_kwarg_options((input_options, target_options), kwargs_options):
+                fn = functools.partial(torch.nn.functional.mse_loss, **kwargs)
+                result = fn(input, target)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input, target))
+
+    def test_extremal_numerics_softmax(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        kwargs_options = ({'dim': 1}, {})
+        for shape in shapes:
+            input_options = self._make_extremal_inputs(shape, device)
+            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
+                fn = functools.partial(torch.nn.functional.softmax, **kwargs)
+                result = fn(input)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input))
+
+    def test_extremal_numerics_log_softmax(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        kwargs_options = ({'dim': 1}, {})
+        for shape in shapes:
+            input_options = self._make_extremal_inputs(shape, device)
+            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
+                fn = functools.partial(torch.nn.functional.log_softmax, **kwargs)
+                result = fn(input)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input))
+
+    def test_extremal_numerics_cross_entropy(self, device):
+        N, C = 3, 4
+        d1, d2, d3 = 5, 6, 7
+        shapes = (
+            ((N, C), (N,), (C,)),
+            ((N, C), (N,), None),
+            ((N, C), (N, C), (C,)),
+            ((N, C), (N, C), None),
+            ((C,), (), (C,)),
+            ((C,), (), None),
+            ((C,), (C,), (C,)),
+            ((C,), (C,), None),
+            ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
+            ((N, C, d1, d2, d3), (N, d1, d2, d3), None),
+            ((N, C, d1, d2, d3), (N, C, d1, d2, d3), (C,)),
+            ((N, C, d1, d2, d3), (N, C, d1, d2, d3), None),
+        )
+        for input_shape, target_shape, weight_shape in shapes:
+            input_options = self._make_extremal_inputs(input_shape, device)
+            kwargs_options = [{'reduction': 'sum'}, {'reduction': 'none'}, {}]
+            if input_shape != target_shape:
+                kwargs_options.append({'ignore_index': 0, 'reduction': 'mean'})
+
+            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
+                if weight_shape is None:
+                    weight = None
+                else:
+                    weight = torch.randn(weight_shape, device=device)
+
+                if input_shape == target_shape:
+                    target = torch.rand(target_shape, device=device)
+                elif len(target_shape) == 0:
+                    target = torch.tensor(1, device=device)  # must be non-zero since ignore_index may be 0
+                else:
+                    target = torch.randint(0, C, target_shape, device=device)
+
+                fn = functools.partial(torch.nn.functional.cross_entropy, target=target, weight=weight, **kwargs)
+                result = fn(input)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input), atol_rtol=(1e-4, 1e-5))
+
+    def test_extremal_numerics_binary_cross_entropy(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        for shape in shapes:
+            weight_options = self._make_extremal_inputs(shape, device)
+            kwargs_options = [{'reduction': 'sum'}, {'reduction': 'none'}, {}]
+
+            for weight, kwargs in self._arg_and_kwarg_options((weight_options,), kwargs_options):
+                input = torch.rand(shape, device=device)
+                target = torch.rand(shape, device=device)
+                fn = functools.partial(torch.nn.functional.binary_cross_entropy, target=target, weight=weight, **kwargs)
+                result = fn(input)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input), atol_rtol=(1e-4, 2e-5))
+
+    def test_extremal_numerics_layer_norm(self, device):
+        N, C, H, W = 3, 4, 5, 6
+        shapes = ((N, C), (N, C, H), (N, C, H, W))
+        for shape in shapes:
+            input_options = self._make_extremal_inputs(shape, device)
+            normalized_shape = shape[1:]
+            weight_options = self._make_extremal_inputs(normalized_shape, device)
+            bias_options = self._make_extremal_inputs(normalized_shape, device)
+
+            for input, bias, weight in itertools.product(input_options, bias_options, weight_options):
+                def fn(input, weight, bias):
+                    return torch.nn.functional.layer_norm(input, normalized_shape, weight=weight, bias=bias)
+                result = fn(input, weight, bias)
+                cotangents = torch.randn_like(result, device=device)
+                self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias))
+
    @ops(filter(lambda op: op.name == "nn.functional.group_norm", functorch_lagging_op_db + additional_op_db),
         allowed_dtypes=(torch.float32, torch.double))  # TODO: generalize
    def test_group_norm_backward(self, device, dtype, op):