diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 0d1a930ca8..500431ce33 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -2574,6 +2574,8 @@ def uniform_backward(primal, minval, maxval, g): prims.PrimIDs.BITWISE_XOR, prims.PrimIDs.SIGNBIT, prims.PrimIDs.FULL, + prims.PrimIDs.FLOOR, + prims.PrimIDs.CEIL, } diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index b7123c6d30..5cebec957f 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -9839,20 +9839,6 @@ def interpolate_error_generator(op, device, dtype=torch.float32, **kwargs): error_input_generator=interpolate_error_generator, torch_reference=torch.nn.functional.interpolate, dtypes=(datatypes.floating,), - test_directives=( - # PyTorch does not support CPU Half upsample used in interpolate - DecorateInfo( - pytest.mark.xfail, - "test_core_vs_torch_consistency", - dtypes=(datatypes.float16,), - devicetypes=(devices.DeviceType.CPU,), - ), - # This should be fixed now; TODO re-enable and test - DecorateInfo( - pytest.mark.xfail, - "test_vjp_correctness", - ), - ), ) nn_ops.append(interpolate_opinfo)