
Commit 46c9b89

Remove some xfails for transform coverage (#910)

1 parent d6b7f86 · commit 46c9b89

4 files changed, 5 insertions(+), 21 deletions(-)

functorch/csrc/BatchRulesLinearAlgebra.cpp (1 addition, 1 deletion)

@@ -197,7 +197,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   VARIADIC_BDIMS_BOXED(symeig);
   VARIADIC_BDIMS_BOXED(triangular_solve);

-  VARIADIC_BDIMS_BOXED(_det_lu_based_helper);
+  VARIADIC_BDIMS_BOXED(_linalg_det);
   VARIADIC_BDIMS_BOXED(_lu_with_info);
 }
 }}
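
The registration moves from _det_lu_based_helper to _linalg_det, presumably because that is the primitive torch.linalg.det now decomposes into in PyTorch core. A minimal sketch of the behaviour this batch rule is meant to cover, assuming a functorch build that includes it; the per-sample loop comparison is illustrative and not taken from the commit:

import torch
from functorch import vmap

x = torch.randn(8, 3, 3)                                   # batch of square matrices
batched = vmap(torch.linalg.det)(x)                        # exercises the determinant batch rule
expected = torch.stack([torch.linalg.det(m) for m in x])   # per-sample reference loop
assert torch.allclose(batched, expected)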

functorch/csrc/PyTorchOperatorHacks.cpp (1 addition, 1 deletion)

@@ -212,7 +212,7 @@ Tensor binary_cross_entropy_with_logits_hack(
   const Tensor& pos_weight = c10::value_or_else(pos_weight_opt, [] {return Tensor();});

   Tensor loss;
-  auto max_val = (-input).clamp_min_(0);
+  auto max_val = (-input).clamp_min(0);
   if (pos_weight.defined()) {
     // pos_weight need to be broadcasted, thus mul(target) is not inplace.
     auto log_weight = (pos_weight - 1).mul(target).add_(1);
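
The only change here swaps the in-place clamp_min_ for the out-of-place clamp_min; in-place mutation inside a composite like this can be awkward under functorch's vmap and forward-mode AD transforms, which lines up with the binary_cross_entropy_with_logits xfails removed from the tests below. A rough sketch of the kind of composition those removed xfails exercise (the wrapper names are illustrative, not from the test suite):

import torch
import torch.nn.functional as F
from functorch import vmap, jvp

logits = torch.randn(4, 3)
targets = torch.rand(4, 3)
tangents = torch.randn_like(logits)

def per_sample_jvp(logit, target, tangent):
    # forward-mode derivative of the loss w.r.t. the logits for one sample
    loss, loss_tangent = jvp(lambda z: F.binary_cross_entropy_with_logits(z, target),
                             (logit,), (tangent,))
    return loss, loss_tangent

losses, loss_tangents = vmap(per_sample_jvp)(logits, targets, tangents)  # shape (4,) each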

test/test_ops.py (2 additions, 19 deletions)

@@ -379,11 +379,6 @@ def wrapped_fn(*args, **kwargs):
         skip('svd_lowrank', ''),  # fails on cuda, runs okay on cpu
         skip('nn.functional.dropout2d', ''),  # fails on cuda, runs okay on cpu

-        # The following don't have a forward-mode AD formula in PyTorch core
-        # (check derivatives.yaml).
-        xfail('var_mean'),
-        xfail('std_mean'),
-
         # =============================================
         # NB: The above failures also fail using PyTorch core's
         # forward-mode AD and vmap.
@@ -674,14 +669,11 @@ def test_vmapvjp(self, device, dtype, op):
         # skip because this is flaky depending on what the max_norm is!
         skip('nn.functional.embedding', ''),
         xfail('nn.functional.soft_margin_loss', ''),
-        xfail('nn.functional.binary_cross_entropy_with_logits', ''),
         xfail('linalg.householder_product'),
         xfail('tensor_split'),
         xfail('quantile'),
-        xfail('var_mean'),
         xfail('as_strided'),
         xfail('nn.functional.gaussian_nll_loss'),
-        xfail('std_mean'),
         xfail('scatter'),
         xfail('matrix_exp'),
         xfail('nanquantile'),
@@ -765,6 +757,8 @@ def test_vmapjvpall(self, device, dtype, op):
         xfail('nn.functional.max_pool3d'),
         xfail('vdot'),
         xfail('linalg.cross'),
+        xfail('nanmean'),
+        xfail('nansum'),
         xfail('nn.functional.feature_alpha_dropout', 'without_train'),
         xfail('linalg.lu_factor', ''),
         xfail('nn.functional.dropout2d', ''),
@@ -782,7 +776,6 @@ def test_vmapjvpall(self, device, dtype, op):
         xfail('nn.functional.smooth_l1_loss', ''),
         xfail('nn.functional.max_unpool2d', 'grad'),
         xfail('nn.functional.soft_margin_loss', ''),
-        xfail('nn.functional.binary_cross_entropy_with_logits', ''),
         xfail('nn.functional.max_unpool1d', 'grad'),
         xfail('nn.functional.embedding', ''),
         xfail('lu_unpack'),
@@ -1044,23 +1037,16 @@ def get_vjp(cotangents, *primals):
         # RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
         # this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
         xfail('normal', ''),
-        xfail('_masked.amax', ''),
-        xfail('_masked.amin', ''),
         xfail('_masked.log_softmax', ''),
         xfail('_masked.softmax', ''),
         xfail('_masked.softmin', ''),
-        xfail('amax', ''),
-        xfail('amin', ''),
         xfail('cdist', ''),
         xfail('cholesky', ''),
         xfail('eig', ''),
         xfail('linalg.det', ''),
-        xfail('linalg.matrix_norm', ''),
         xfail('linalg.slogdet', ''),
         xfail('logcumsumexp', ''),
         xfail('logdet', ''),
-        xfail('nanmean', ''),
-        xfail('nansum', ''),
         xfail('nn.functional.embedding_bag', ''),
         xfail('nn.functional.grid_sample', ''),
         xfail('nn.functional.hardsigmoid', ''),
@@ -1070,9 +1056,7 @@ def get_vjp(cotangents, *primals):
         xfail('nn.functional.softmin', ''),
         xfail('nn.functional.softmin', 'with_dtype'),
         xfail('renorm', ''),
-        xfail('std_mean', ''),
         xfail('symeig', ''),
-        xfail('var_mean', ''),
         xfail('nn.functional.feature_alpha_dropout', 'with_train'),
         xfail('nn.functional.kl_div', ''),
         xfail('pca_lowrank', ''),
@@ -1090,7 +1074,6 @@ def get_vjp(cotangents, *primals):
         xfail('scatter_reduce', 'mean'),
         xfail('scatter_reduce', 'prod'),
         skip('linalg.householder_product', '', device_type='cuda'),  # flaky, I'm not sure why
-        xfail('nn.functional.binary_cross_entropy_with_logits'),
     }))
     def test_jvpvjp(self, device, dtype, op):
         if not op.supports_autograd:
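
Most of these deletions drop expected-failure markers for ops (var_mean, std_mean, amax/amin, binary_cross_entropy_with_logits, and friends) whose transform compositions are now expected to pass, while the two added nanmean/nansum entries record a composition that still fails under vmap-of-jvp. As a hedged illustration of what dropping, say, xfail('var_mean') implies, a forward-mode AD call like the following should now run rather than raise (a sketch only, not lifted from the test file):

import torch
from functorch import jvp

x = torch.randn(5, 3)
t = torch.randn_like(x)

# torch.var_mean returns a (var, mean) pair, so jvp returns a matching pair of tangents
(var, mean), (var_t, mean_t) = jvp(torch.var_mean, (x,), (t,))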

test/test_vmap.py (1 addition, 0 deletions)

@@ -3177,6 +3177,7 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('histogram'),
         xfail('index_fill'),
         xfail('nansum'),
+        xfail('nanmean'),
         # `index_put` OpInfo in pytorch/pytorch has
         # masked index as input which is not supported
         xfail('index_put', ''),
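
The new xfail('nanmean') mirrors the existing nansum entry: test_vmap_exhaustive still expects vmap over torch.nanmean to fail until a batch rule lands. A small sketch of the reference behaviour such a rule would have to match (illustrative only):

import torch
from functorch import vmap

x = torch.randn(6, 4)
x[0, 0] = float('nan')
reference = torch.stack([torch.nanmean(row) for row in x])  # per-sample loop reference
# vmap(torch.nanmean)(x) is the call marked xfail until a batch rule exists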
