Skip to content

Commit 3b68942

Browse files
author
samdow
committed
embedding decomp
1 parent 825f439 commit 3b68942

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

functorch/_src/eager_transforms.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,7 +1339,6 @@ def _register_python_decomposition_vmap(decomp):
13391339
_register_jit_decomposition(torch.ops.aten._softmax_backward_data.default)
13401340
_register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default)
13411341
_register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default)
1342-
_register_jit_decomposition(torch.ops.aten.native_batch_norm_backward.default)
1343-
_register_jit_decomposition(torch.ops.aten.cudnn_batch_norm_backward.default)
1342+
_register_jit_decomposition(torch.ops.aten.embedding_dense_backward.default)
13441343
_register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
13451344
_register_python_decomposition_vmap(torch.ops.aten.addr.default)

functorch/csrc/DynamicLayer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
504504
JVP_DECOMP(native_layer_norm_backward);
505505
JVP_DECOMP(native_batch_norm_backward);
506506
JVP_DECOMP(cudnn_batch_norm_backward);
507+
JVP_DECOMP(embedding_dense_backward);
507508
}
508509

509510

0 commit comments

Comments
 (0)