@@ -1055,69 +1055,109 @@ def test_vjpvmap(self, device, dtype, op):
1055
1055
skip ('nn.functional.fractional_max_pool2d' ), # Random
1056
1056
skip ('nn.functional.fractional_max_pool3d' ), # Random
1057
1057
1058
- xfail ('_masked.log_softmax' ),
1059
- xfail ('_masked.softmax' ),
1060
- xfail ('_masked.softmin' ),
1061
- xfail ('block_diag' ),
1062
- xfail ('cdist' ),
1063
- xfail ('fft.fft' ),
1064
- xfail ('fft.fft2' ),
1065
- xfail ('fft.fftn' ),
1066
- xfail ('fft.hfft' ),
1067
- xfail ('fft.hfft2' ),
1068
- xfail ('fft.hfftn' ),
1069
- xfail ('fft.ifft' ),
1070
- xfail ('fft.ifft2' ),
1071
- xfail ('fft.ifftn' ),
1072
- xfail ('fft.ihfft' ),
1073
- xfail ('fft.ihfft2' ),
1074
- xfail ('fft.ihfftn' ),
1075
- xfail ('fft.irfft' ),
1076
- xfail ('fft.irfft2' ),
1077
- xfail ('fft.irfftn' ),
1078
- xfail ('fft.rfft' ),
1079
- xfail ('fft.rfft2' ),
1080
- xfail ('fft.rfftn' ),
1081
- xfail ('istft' ),
1082
- xfail ('log_softmax' ),
1058
+ xfail ('__rsub__' , '' ),
1059
+ xfail ('_masked.amax' , '' ),
1060
+ xfail ('_masked.amin' , '' ),
1061
+ xfail ('_masked.log_softmax' , '' ),
1062
+ xfail ('_masked.norm' , '' ),
1063
+ xfail ('_masked.normalize' , '' ),
1064
+ xfail ('_masked.softmax' , '' ),
1065
+ xfail ('_masked.softmin' , '' ),
1066
+ xfail ('amax' , '' ),
1067
+ xfail ('amin' , '' ),
1068
+ xfail ('atan2' , '' ),
1069
+ xfail ('block_diag' , '' ),
1070
+ xfail ('cdist' , '' ),
1071
+ xfail ('cholesky' , '' ),
1072
+ xfail ('cholesky_inverse' , '' ),
1073
+ xfail ('dist' , '' ),
1074
+ xfail ('eig' , '' ),
1075
+ xfail ('fft.fft' , '' ),
1076
+ xfail ('fft.fft2' , '' ),
1077
+ xfail ('fft.fftn' , '' ),
1078
+ xfail ('fft.hfft' , '' ),
1079
+ xfail ('fft.hfft2' , '' ),
1080
+ xfail ('fft.hfftn' , '' ),
1081
+ xfail ('fft.ifft' , '' ),
1082
+ xfail ('fft.ifft2' , '' ),
1083
+ xfail ('fft.ifftn' , '' ),
1084
+ xfail ('fft.ihfft' , '' ),
1085
+ xfail ('fft.ihfft2' , '' ),
1086
+ xfail ('fft.ihfftn' , '' ),
1087
+ xfail ('fft.irfft' , '' ),
1088
+ xfail ('fft.irfft2' , '' ),
1089
+ xfail ('fft.irfftn' , '' ),
1090
+ xfail ('fft.rfft' , '' ),
1091
+ xfail ('fft.rfft2' , '' ),
1092
+ xfail ('fft.rfftn' , '' ),
1093
+ xfail ('istft' , '' ),
1094
+ xfail ('linalg.det' , '' ),
1095
+ xfail ('linalg.eigh' , '' ),
1096
+ xfail ('linalg.eigvalsh' , '' ),
1097
+ xfail ('linalg.matrix_norm' , '' ),
1098
+ xfail ('linalg.norm' , '' ),
1099
+ xfail ('linalg.slogdet' , '' ),
1100
+ xfail ('linalg.vector_norm' , '' ),
1101
+ xfail ('log_softmax' , '' ),
1083
1102
xfail ('log_softmax' , 'dtype' ),
1084
- xfail ('logcumsumexp' ),
1085
- xfail ('nn.functional.batch_norm' ),
1103
+ xfail ('logcumsumexp' , '' ),
1104
+ xfail ('logdet' , '' ),
1105
+ xfail ('lu' , '' ),
1106
+ xfail ('lu_solve' , '' ),
1107
+ xfail ('lu_unpack' , '' ),
1108
+ xfail ('max' , 'binary' ),
1109
+ xfail ('maximum' , '' ),
1110
+ xfail ('min' , 'binary' ),
1111
+ xfail ('minimum' , '' ),
1112
+ xfail ('nanmean' , '' ),
1113
+ xfail ('nansum' , '' ),
1114
+ xfail ('nn.functional.batch_norm' , '' ),
1086
1115
xfail ('nn.functional.batch_norm' , 'without_cudnn' , device_type = 'cuda' ),
1087
- xfail ('nn.functional.bilinear' ),
1088
- xfail ('nn.functional.binary_cross_entropy' ),
1089
- xfail ('nn.functional.binary_cross_entropy_with_logits' , device_type = 'cuda' ),
1090
- xfail ('nn.functional.celu' ),
1091
- xfail ('nn.functional.cross_entropy' ),
1116
+ xfail ('nn.functional.bilinear' , '' ),
1117
+ xfail ('nn.functional.binary_cross_entropy' , '' ),
1118
+ xfail ('nn.functional.binary_cross_entropy_with_logits' , '' ),
1119
+ xfail ('nn.functional.celu' , '' ),
1120
+ xfail ('nn.functional.cross_entropy' , '' ),
1092
1121
xfail ('nn.functional.cross_entropy' , 'mean' ),
1093
1122
xfail ('nn.functional.cross_entropy' , 'none' ),
1094
1123
xfail ('nn.functional.cross_entropy' , 'sum' ),
1095
- xfail ('nn.functional.elu' ),
1096
- xfail ('nn.functional.embedding' ),
1124
+ xfail ('nn.functional.elu' , '' ),
1125
+ xfail ('nn.functional.embedding' , '' ),
1097
1126
xfail ('nn.functional.embedding' , 'functorch' ),
1098
- xfail ('nn.functional.embedding_bag' ),
1099
- xfail ('nn.functional.glu' ),
1100
- xfail ('nn.functional.grid_sample' ),
1101
- xfail ('nn.functional.hardsigmoid' ),
1102
- xfail ('nn.functional.hardswish' ),
1103
- xfail ('nn.functional.huber_loss' ),
1104
- xfail ('nn.functional.instance_norm' ),
1105
- xfail ('nn.functional.layer_norm' ),
1106
- xfail ('nn.functional.leaky_relu' ),
1107
- xfail ('nn.functional.logsigmoid' ),
1108
- xfail ('nn.functional.mse_loss' ),
1109
- xfail ('nn.functional.nll_loss' ),
1127
+ xfail ('nn.functional.embedding_bag' , '' ),
1128
+ xfail ('nn.functional.glu' , '' ),
1129
+ xfail ('nn.functional.grid_sample' , '' ),
1130
+ xfail ('nn.functional.hardsigmoid' , '' ),
1131
+ xfail ('nn.functional.hardswish' , '' ),
1132
+ xfail ('nn.functional.huber_loss' , '' ),
1133
+ xfail ('nn.functional.instance_norm' , '' ),
1134
+ xfail ('nn.functional.layer_norm' , '' ),
1135
+ xfail ('nn.functional.leaky_relu' , '' ),
1136
+ xfail ('nn.functional.logsigmoid' , '' ),
1137
+ xfail ('nn.functional.mse_loss' , '' ),
1138
+ xfail ('nn.functional.nll_loss' , '' ),
1139
+ xfail ('nn.functional.normalize' , '' ),
1110
1140
xfail ('nn.functional.pad' , 'circular' ),
1111
- xfail ('nn.functional.prelu' ),
1112
- xfail ('nn.functional.selu' ),
1113
- xfail ('nn.functional.softmin' ),
1141
+ xfail ('nn.functional.pairwise_distance' , '' ),
1142
+ xfail ('nn.functional.prelu' , '' ),
1143
+ xfail ('nn.functional.selu' , '' ),
1144
+ xfail ('nn.functional.softmin' , '' ),
1114
1145
xfail ('nn.functional.softmin' , 'with_dtype' ),
1115
- xfail ('nn.functional.softplus' ),
1116
- xfail ('put' ),
1117
- xfail ('softmax' ),
1146
+ xfail ('nn.functional.softplus' , '' ),
1147
+ xfail ('norm' , '' ),
1148
+ xfail ('norm' , 'fro' ),
1149
+ xfail ('norm' , 'inf' ),
1150
+ xfail ('polar' , '' ),
1151
+ xfail ('put' , '' ),
1152
+ xfail ('renorm' , '' ),
1153
+ xfail ('softmax' , '' ),
1118
1154
xfail ('softmax' , 'with_dtype' ),
1119
- xfail ('stft' ),
1120
- xfail ('take' ),
1155
+ xfail ('solve' , '' ),
1156
+ xfail ('std_mean' , '' ),
1157
+ xfail ('stft' , '' ),
1158
+ xfail ('symeig' , '' ),
1159
+ xfail ('take' , '' ),
1160
+ xfail ('var_mean' , '' ),
1121
1161
}))
1122
1162
def test_jvpvjp (self , device , dtype , op ):
1123
1163
if not op .supports_autograd :
@@ -1135,28 +1175,40 @@ def test_jvpvjp(self, device, dtype, op):
1135
1175
fn , primals = normalize_op_input_output (op , sample )
1136
1176
result = fn (* primals )
1137
1177
cotangents = tree_map (lambda x : torch .randn_like (x ), result )
1138
- tangents = tree_map (lambda x : torch .randn_like (x ), result )
1139
1178
1140
- _ , vjp_fn = vjp (fn , * primals )
1141
- result = jvp (vjp_fn , (cotangents ,), (tangents ,))
1179
+ primals_tangents = tree_map (lambda x : torch .randn_like (x ), primals )
1180
+ cotangents_tangents = tree_map (lambda x : torch .randn_like (x ), cotangents )
1181
+
1182
+ def push_vjp (primals , cotangents ):
1183
+ _ , vjp_fn = vjp (fn , * primals )
1184
+ return vjp_fn (cotangents )
1185
+
1186
+ result = jvp (push_vjp , (primals , cotangents ), (primals_tangents , cotangents_tangents ))
1142
1187
self .assertEqual (len (result ), 2 )
1143
1188
1144
- def reference (primals , cotangents , tangents ):
1145
- _ , vjp_fn = ref_vjp (fn , * primals )
1189
+ def tree_map2 (fn , first , second ):
1190
+ flat_first , spec_first = tree_flatten (first )
1191
+ flat_second , spec_second = tree_flatten (second )
1192
+ assert spec_first == spec_second
1193
+ flat_result = [fn (f , s ) for f , s in zip (flat_first , flat_second )]
1194
+ return tree_unflatten (flat_result , spec_first )
1195
+
1196
+ def reference (primals , cotangents , primals_tangents , cotangents_tangents ):
1146
1197
with fwAD .dual_level ():
1147
- flat_cotangents , spec = tree_flatten (cotangents )
1148
- flat_tangents , spec = tree_flatten (tangents )
1149
- flat_duals = [fwAD .make_dual (c , t ) for c , t in zip (flat_cotangents , flat_tangents )]
1150
- duals = tree_unflatten (flat_duals , spec )
1151
- result = vjp_fn (duals )
1198
+ primal_duals = tree_map2 (fwAD .make_dual , primals , primals_tangents )
1199
+ _ , vjp_fn = ref_vjp (fn , * primal_duals )
1200
+
1201
+ cotangent_duals = tree_map2 (fwAD .make_dual , cotangents , cotangents_tangents )
1202
+ result = vjp_fn (cotangent_duals )
1203
+
1152
1204
flat_result , spec = tree_flatten (result )
1153
1205
primals_out , tangents_out = zip (* [fwAD .unpack_dual (r ) for r in flat_result ])
1154
1206
tangents_out = [t if t is not None else torch .zeros_like (p )
1155
1207
for p , t in zip (primals_out , tangents_out )]
1156
1208
expected = (tree_unflatten (primals_out , spec ), tree_unflatten (tangents_out , spec ))
1157
1209
return expected
1158
1210
1159
- expected = reference (primals , cotangents , tangents )
1211
+ expected = reference (primals , cotangents , primals_tangents , cotangents_tangents )
1160
1212
self .assertEqual (result , expected )
1161
1213
1162
1214
0 commit comments