Skip to content

Commit 55477e6

Browse files
committed
Enable interpolate VJP tests, add PrimID mapping for ceil and floor
1 parent d001dbe commit 55477e6

File tree

2 files changed

+2
-14
lines changed

2 files changed

+2
-14
lines changed

thunder/core/transforms.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2574,6 +2574,8 @@ def uniform_backward(primal, minval, maxval, g):
25742574
prims.PrimIDs.BITWISE_XOR,
25752575
prims.PrimIDs.SIGNBIT,
25762576
prims.PrimIDs.FULL,
2577+
prims.PrimIDs.FLOOR,
2578+
prims.PrimIDs.CEIL,
25772579
}
25782580

25792581

thunder/tests/opinfos.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9839,20 +9839,6 @@ def interpolate_error_generator(op, device, dtype=torch.float32, **kwargs):
98399839
error_input_generator=interpolate_error_generator,
98409840
torch_reference=torch.nn.functional.interpolate,
98419841
dtypes=(datatypes.floating,),
9842-
test_directives=(
9843-
# PyTorch does not support CPU Half upsample used in interpolate
9844-
DecorateInfo(
9845-
pytest.mark.xfail,
9846-
"test_core_vs_torch_consistency",
9847-
dtypes=(datatypes.float16,),
9848-
devicetypes=(devices.DeviceType.CPU,),
9849-
),
9850-
# This should be fixed now; TODO re-enable and test
9851-
DecorateInfo(
9852-
pytest.mark.xfail,
9853-
"test_vjp_correctness",
9854-
),
9855-
),
98569842
)
98579843
nn_ops.append(interpolate_opinfo)
98589844

0 commit comments

Comments
 (0)