Skip to content

Commit 4cc9705

Browse files
authored
tests: clean up xfails in VJP tests (#2578)
1 parent dc83ea4 commit 4cc9705

File tree

1 file changed

+1
-14
lines changed

1 file changed

+1
-14
lines changed

thunder/tests/opinfos.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2969,6 +2969,7 @@ def addcmul_addcdiv_sample_generator(op, device, dtype, requires_grad, **kwargs)
29692969
DecorateInfo(
29702970
pytest.mark.xfail,
29712971
"test_vjp_correctness",
2972+
devicetypes=(devices.DeviceType.CUDA,),
29722973
),
29732974
),
29742975
)
@@ -3832,10 +3833,6 @@ def expand_error_generator(op, device, *, dtype=torch.float32, **kwargs):
38323833
sample_input_generator=expand_sample_generator,
38333834
error_input_generator=expand_error_generator,
38343835
torch_reference=torch.Tensor.expand,
3835-
test_directives=(
3836-
# vjp not yet implemented
3837-
DecorateInfo(pytest.mark.xfail, "test_vjp_correctness"),
3838-
),
38393836
)
38403837
shape_ops.append(expand_opinfo)
38413838

@@ -3880,10 +3877,6 @@ def expand_as_error_generator(op, device, *, dtype=torch.float32, **kwargs):
38803877
sample_input_generator=expand_as_sample_generator,
38813878
error_input_generator=expand_as_error_generator,
38823879
torch_reference=torch.Tensor.expand_as,
3883-
test_directives=(
3884-
# vjp not yet implemented
3885-
DecorateInfo(pytest.mark.xfail, "test_vjp_correctness"),
3886-
),
38873880
)
38883881
shape_ops.append(expand_as_opinfo)
38893882

@@ -4319,12 +4312,6 @@ def make_nd_idx(dim_length: int, indices: int, ndim: int):
43194312
),
43204313
DecorateInfo(pytest.mark.xfail, "test_vjp_correctness", active_if=IS_WINDOWS),
43214314
DecorateInfo(pytest.mark.xfail, "test_phantom_grad_vs_torch_consistency", active_if=IS_WINDOWS),
4322-
# TODO: https://github.com/Lightning-AI/lightning-thunder/issues/841
4323-
# check_slice_value(p0, slice(1, 3, 1)) in prologue trace fails
4324-
DecorateInfo(
4325-
pytest.mark.xfail,
4326-
"test_vjp_correctness",
4327-
),
43284315
),
43294316
)
43304317
shape_ops.append(getitem_opinfo)

0 commit comments

Comments
 (0)