@@ -1019,6 +1019,120 @@ def test_vjpvmap(self, device, dtype, op):
1019
1019
1020
1020
self .assertEqual (result_vjps , expected_vjps )
1021
1021
1022
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@skipOps('TestOperators', 'test_jvpvjp', vjp_fail.union({
    # These are weirdly non-deterministic
    skip('nn.functional.conv2d', '', device_type='cpu'),
    skip('nn.functional.conv2d', 'no_bias', device_type='cpu'),
    skip('nn.functional.conv2d', 'stride_no_bias', device_type='cpu'),
    skip('nn.functional.conv2d', 'stride_padding_no_bias', device_type='cpu'),
    skip('nn.functional.fractional_max_pool2d'),  # Random
    skip('nn.functional.fractional_max_pool3d'),  # Random

    xfail('_masked.log_softmax'),
    xfail('_masked.softmax'),
    xfail('_masked.softmin'),
    xfail('block_diag'),
    xfail('cdist'),
    xfail('fft.fft'),
    xfail('fft.fft2'),
    xfail('fft.fftn'),
    xfail('fft.hfft'),
    xfail('fft.hfft2'),
    xfail('fft.hfftn'),
    xfail('fft.ifft'),
    xfail('fft.ifft2'),
    xfail('fft.ifftn'),
    xfail('fft.ihfft'),
    xfail('fft.ihfft2'),
    xfail('fft.ihfftn'),
    xfail('fft.irfft'),
    xfail('fft.irfft2'),
    xfail('fft.irfftn'),
    xfail('fft.rfft'),
    xfail('fft.rfft2'),
    xfail('fft.rfftn'),
    xfail('istft'),
    xfail('log_softmax'),
    xfail('log_softmax', 'dtype'),
    xfail('logcumsumexp'),
    xfail('nn.functional.batch_norm'),
    xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
    xfail('nn.functional.bilinear'),
    xfail('nn.functional.binary_cross_entropy'),
    xfail('nn.functional.binary_cross_entropy_with_logits', device_type='cuda'),
    xfail('nn.functional.celu'),
    xfail('nn.functional.cross_entropy'),
    xfail('nn.functional.cross_entropy', 'mean'),
    xfail('nn.functional.cross_entropy', 'none'),
    xfail('nn.functional.cross_entropy', 'sum'),
    xfail('nn.functional.elu'),
    xfail('nn.functional.embedding'),
    xfail('nn.functional.embedding', 'functorch'),
    xfail('nn.functional.embedding_bag'),
    xfail('nn.functional.glu'),
    xfail('nn.functional.grid_sample'),
    xfail('nn.functional.hardsigmoid'),
    xfail('nn.functional.hardswish'),
    xfail('nn.functional.huber_loss'),
    xfail('nn.functional.instance_norm'),
    xfail('nn.functional.layer_norm'),
    xfail('nn.functional.leaky_relu'),
    xfail('nn.functional.logsigmoid'),
    xfail('nn.functional.mse_loss'),
    xfail('nn.functional.nll_loss'),
    xfail('nn.functional.pad', 'circular'),
    xfail('nn.functional.prelu'),
    xfail('nn.functional.selu'),
    xfail('nn.functional.softmin'),
    xfail('nn.functional.softmin', 'with_dtype'),
    xfail('nn.functional.softplus'),
    xfail('put'),
    xfail('softmax'),
    xfail('softmax', 'with_dtype'),
    xfail('stft'),
    xfail('take'),
}))
def test_jvpvjp(self, device, dtype, op):
    """Check ``jvp`` applied to a ``vjp`` function against a reference.

    For every sample input of ``op``, the function returned by ``vjp`` is
    pushed through ``jvp`` with random cotangents (as primals) and random
    tangents. The resulting ``(primals_out, tangents_out)`` pair is compared
    with a reference built from ``ref_vjp`` evaluated on forward-mode dual
    tensors (``fwAD``).

    The test is skipped for ops without autograd support and for in-place
    variants (not implemented yet).
    """
    # skipTest raises, so no explicit ``return`` is needed after it.
    if not op.supports_autograd:
        self.skipTest("Skipped! Autograd not supported.")

    # TODO: test in-place
    if is_inplace(op, op.get_op()):
        self.skipTest("Skipped! NYI: inplace-testing not supported.")

    def reference(fn, primals, cotangents, tangents):
        """Return ``(primals_out, tangents_out)`` of the vjp via dual tensors."""
        _, ref_vjp_fn = ref_vjp(fn, *primals)
        with fwAD.dual_level():
            flat_cotangents, _ = tree_flatten(cotangents)
            flat_tangents, input_spec = tree_flatten(tangents)
            # Pair each cotangent with its tangent as a forward-mode dual tensor.
            flat_duals = [fwAD.make_dual(c, t)
                          for c, t in zip(flat_cotangents, flat_tangents)]
            duals = tree_unflatten(flat_duals, input_spec)
            flat_out, out_spec = tree_flatten(ref_vjp_fn(duals))
            primals_out, tangents_out = zip(*[fwAD.unpack_dual(r) for r in flat_out])
            # unpack_dual may yield a None tangent; substitute explicit zeros
            # so the comparison below is tensor-vs-tensor.
            tangents_out = [t if t is not None else torch.zeros_like(p)
                            for p, t in zip(primals_out, tangents_out)]
            expected = (tree_unflatten(primals_out, out_spec),
                        tree_unflatten(tangents_out, out_spec))
        return expected

    # Only request samples once we know the test will actually run.
    samples = op.sample_inputs(device, dtype, requires_grad=True)

    for sample in samples:
        fn, primals = normalize_op_input_output(op, sample)
        output = fn(*primals)
        # Random cotangents/tangents shaped like the op's output.
        cotangents = tree_map(torch.randn_like, output)
        tangents = tree_map(torch.randn_like, output)

        _, vjp_fn = vjp(fn, *primals)
        result = jvp(vjp_fn, (cotangents,), (tangents,))
        self.assertEqual(len(result), 2)

        expected = reference(fn, primals, cotangents, tangents)
        self.assertEqual(result, expected)
1135
+
1022
1136
1023
1137
class InplaceError (Exception ):
1024
1138
def __repr__ (self ):
0 commit comments