Skip to content

Commit 26d9b00

Browse files
committed
Enable interpolate VJP tests, add PrimID mapping for ceil and floor
1 parent d001dbe commit 26d9b00

File tree

2 files changed

+4
-14
lines changed

2 files changed

+4
-14
lines changed

thunder/core/transforms.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,6 +1660,8 @@ def zeros_like(x):
16601660
prims.PrimIDs.COPY_: lambda x, y, grad_enabled: (prims.copy_(x, y, grad_enabled=grad_enabled), tuple()),
16611661
prims.PrimIDs.CLONE: lambda x: (prims.clone(x), tuple()),
16621662
prims.PrimIDs.BITCAST: lambda x, dtype: (prims.bitcast(x, dtype), (x.dtype,)),
1663+
prims.PrimIDs.CEIL: lambda x: (prims.ceil(x), (x,)),
1664+
prims.PrimIDs.FLOOR: lambda x: (prims.floor(x), (x,)),
16631665
}
16641666

16651667

@@ -1691,6 +1693,8 @@ def zeros_like(x):
16911693
prims.PrimIDs.COPY_: lambda g: (g, None),
16921694
prims.PrimIDs.CLONE: lambda g: g,
16931695
prims.PrimIDs.BITCAST: lambda x_dtype, g: (prims.bitcast(g, x_dtype), None),
1696+
prims.PrimIDs.CEIL: lambda x, g: g * 0,
1697+
prims.PrimIDs.FLOOR: lambda x, g: g * 0,
16941698
}
16951699

16961700

thunder/tests/opinfos.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9839,20 +9839,6 @@ def interpolate_error_generator(op, device, dtype=torch.float32, **kwargs):
98399839
error_input_generator=interpolate_error_generator,
98409840
torch_reference=torch.nn.functional.interpolate,
98419841
dtypes=(datatypes.floating,),
9842-
test_directives=(
9843-
# PyTorch does not support CPU Half upsample used in interpolate
9844-
DecorateInfo(
9845-
pytest.mark.xfail,
9846-
"test_core_vs_torch_consistency",
9847-
dtypes=(datatypes.float16,),
9848-
devicetypes=(devices.DeviceType.CPU,),
9849-
),
9850-
# This should be fixed now; TODO re-enable and test
9851-
DecorateInfo(
9852-
pytest.mark.xfail,
9853-
"test_vjp_correctness",
9854-
),
9855-
),
98569842
)
98579843
nn_ops.append(interpolate_opinfo)
98589844

0 commit comments

Comments
 (0)