@@ -2969,6 +2969,7 @@ def addcmul_addcdiv_sample_generator(op, device, dtype, requires_grad, **kwargs)
2969
2969
DecorateInfo (
2970
2970
pytest .mark .xfail ,
2971
2971
"test_vjp_correctness" ,
2972
+ devicetypes = (devices .DeviceType .CUDA ,),
2972
2973
),
2973
2974
),
2974
2975
)
@@ -3832,10 +3833,6 @@ def expand_error_generator(op, device, *, dtype=torch.float32, **kwargs):
3832
3833
sample_input_generator = expand_sample_generator ,
3833
3834
error_input_generator = expand_error_generator ,
3834
3835
torch_reference = torch .Tensor .expand ,
3835
- test_directives = (
3836
- # vjp not yet implemented
3837
- DecorateInfo (pytest .mark .xfail , "test_vjp_correctness" ),
3838
- ),
3839
3836
)
3840
3837
shape_ops .append (expand_opinfo )
3841
3838
@@ -3880,10 +3877,6 @@ def expand_as_error_generator(op, device, *, dtype=torch.float32, **kwargs):
3880
3877
sample_input_generator = expand_as_sample_generator ,
3881
3878
error_input_generator = expand_as_error_generator ,
3882
3879
torch_reference = torch .Tensor .expand_as ,
3883
- test_directives = (
3884
- # vjp not yet implemented
3885
- DecorateInfo (pytest .mark .xfail , "test_vjp_correctness" ),
3886
- ),
3887
3880
)
3888
3881
shape_ops .append (expand_as_opinfo )
3889
3882
@@ -4319,12 +4312,6 @@ def make_nd_idx(dim_length: int, indices: int, ndim: int):
4319
4312
),
4320
4313
DecorateInfo (pytest .mark .xfail , "test_vjp_correctness" , active_if = IS_WINDOWS ),
4321
4314
DecorateInfo (pytest .mark .xfail , "test_phantom_grad_vs_torch_consistency" , active_if = IS_WINDOWS ),
4322
- # TODO: https://github.com/Lightning-AI/lightning-thunder/issues/841
4323
- # check_slice_value(p0, slice(1, 3, 1)) in prologue trace fails
4324
- DecorateInfo (
4325
- pytest .mark .xfail ,
4326
- "test_vjp_correctness" ,
4327
- ),
4328
4315
),
4329
4316
)
4330
4317
shape_ops .append (getitem_opinfo )
0 commit comments