Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 091d999

Browse files
authored
Remove cross_entropy monkey patch (#901)
We don't need it anymore because we upstreamed the changes into pytorch/pytorch a while ago.
1 parent 6452ee3 commit 091d999

File tree

1 file changed

+0
-20
lines changed

1 file changed

+0
-20
lines changed

functorch/_src/monkey_patching.py

Lines changed: 0 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -62,26 +62,6 @@ def _functorch_str(tensor, *, tensor_contents=None):
6262
torch._tensor_str._str = _functorch_str
6363

6464

65-
# Keep a handle on the stock implementation before we install the shim below.
_old_cross_entropy = torch.nn.functional.cross_entropy


def cross_entropy(input, target, weight=None, size_average=None,
                  ignore_index=-100, reduce=None, reduction='mean', **kwargs):
    """Wrapper around ``torch.nn.functional.cross_entropy``.

    Accepts an unbatched sample (1-d ``input`` logits with a 0-d ``target``)
    by temporarily adding a batch dimension before delegating to the stock
    implementation.  Any extra keyword arguments (e.g. ``label_smoothing``)
    are forwarded untouched, so newer parameters keep working.
    """
    unbatched = input.dim() == 1 and target.dim() == 0
    if unbatched:
        input, target = input.unsqueeze(0), target.unsqueeze(0)

    out = _old_cross_entropy(input, target, weight, size_average,
                             ignore_index, reduce, reduction, **kwargs)
    # With reduction='none' the per-sample losses keep the batch dimension;
    # squeeze it so an unbatched call returns an unbatched result.
    if reduction == 'none':
        return out.squeeze(0)
    return out


# Monkey-patch the shim in place of the stock implementation.
torch.nn.functional.cross_entropy = cross_entropy
84-
8565
# Monkeypatch .backward() to error out if any transforms are active.
8666
# TODO: remove the monkeypatching and add an extension point into PyTorch core
8767
_old_backward = torch.Tensor.backward

0 commit comments

Comments
 (0)