
Commit 5dde9b7

Embedding backward batch rule (#355)
Test Plan:
- run tests
Parent: f7a3576

File tree

5 files changed: +61 -30 lines

functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 1 addition & 0 deletions
@@ -236,6 +236,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   OP_DECOMPOSE(_convolution_mode);
   OP_DECOMPOSE(frobenius_norm);
   OP_DECOMPOSE(type_as);
+  OP_DECOMPOSE(embedding_backward);
   DECOMPOSE_FUNCTIONAL(diag_embed);
   DECOMPOSE_FUNCTIONAL(block_diag);
 }
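
Registering OP_DECOMPOSE(embedding_backward) means vmap never sees embedding_backward directly: the existing composite decomposition runs instead, and only the kernels it forwards to need batch rules. A minimal Python sketch of that dispatch, assuming the usual ATen behavior of picking the sparse or dense backward kernel based on the sparse flag:

import torch

# Hedged sketch, not the ATen source: embedding_backward is assumed to simply
# forward to the sparse or dense backward kernel, which is why decomposing it
# under vmap is sufficient.
def embedding_backward_decomp(grad, indices, num_weights, padding_idx,
                              scale_grad_by_freq, sparse):
    if sparse:
        return torch.ops.aten.embedding_sparse_backward(
            grad, indices, num_weights, padding_idx, scale_grad_by_freq)
    return torch.ops.aten.embedding_dense_backward(
        grad, indices, num_weights, padding_idx, scale_grad_by_freq)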

functorch/csrc/BatchRulesModules.cpp

Lines changed: 41 additions & 7 deletions
@@ -10,6 +10,14 @@
 
 namespace at { namespace functorch {
 
+static Tensor getStepTensor(const Tensor& indices, int64_t bdim_size, int64_t num_embeddings) {
+  // [batch_size, 1, 1, 1, ..., 1]
+  DimVector view_shape(indices.dim(), 1);
+  view_shape[0] = bdim_size;
+  auto range = at::arange(0, bdim_size * num_embeddings, num_embeddings, indices.options());
+  return range.view(view_shape);
+}
+
 std::tuple<Tensor,optional<int64_t>> embedding_batch_rule(
     const Tensor& weight, optional<int64_t> weight_bdim,
     const Tensor& indices, optional<int64_t> indices_bdim,
@@ -34,18 +42,43 @@ std::tuple<Tensor,optional<int64_t>> embedding_batch_rule(
   const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
   auto indices_ = moveBatchDimToFront(indices, indices_bdim);
 
-  // [batch_size, 1, 1, 1, ..., 1]
-  DimVector view_shape(indices_.dim(), 1);
-  view_shape[0] = batch_size;
-
-  auto range = at::arange(0, batch_size * num_embeddings, num_embeddings, indices_.options());
-  range = range.view(view_shape);
-
+  const auto range = getStepTensor(indices, batch_size, num_embeddings);
   indices_ = indices_ + range;
   const auto result = at::embedding(weight_, indices_, padding_idx, scale_grad_by_freq, sparse);
   return std::make_tuple(result, 0);
 }
 
+std::tuple<Tensor,optional<int64_t>>
+embedding_dense_backward_batch_rule(
+    const Tensor& grad_, optional<int64_t> grad_bdim,
+    const Tensor& indices_, optional<int64_t> indices_bdim,
+    int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) {
+  Tensor grad = grad_;
+  Tensor indices = indices_;
+  if (!indices_bdim && grad_bdim) {
+    const auto bdim_size = grad.size(*grad_bdim);
+    grad = reshape_dim_into(*grad_bdim, -1, grad);
+    auto result = at::embedding_dense_backward(
+        grad, indices, num_weights, padding_idx, scale_grad_by_freq);
+    result = reshape_dim_outof(1, bdim_size, result);
+    return std::make_tuple(result, 1);
+  }
+  const auto bdim_size = indices.size(*indices_bdim);
+  indices = moveBatchDimToFront(indices, indices_bdim);
+  grad = moveBatchDimToFront(grad, grad_bdim);
+  grad = ensure_has_bdim(grad, grad_bdim.has_value(), bdim_size);
+  const auto range = getStepTensor(indices, bdim_size, num_weights);
+  auto result = at::embedding_dense_backward(
+      grad, indices + range, num_weights * bdim_size, -1, scale_grad_by_freq);
+  result = reshape_dim_outof(0, bdim_size, result);
+  // Fill in the padding. We can't do it in the embedding_dense_backward call
+  // because we need to fill in multiple rows!
+  if (padding_idx >= 0) {
+    result.select(1, padding_idx).fill_(0);
+  }
+  return std::make_tuple(result, 0);
+}
+
 /**
  * grid sample batch rule breaks down into 3 cases:
  * case 1 (input is batched, grid is not):
@@ -358,6 +391,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(im2col_backward);
 
   VMAP_SUPPORT("embedding", embedding_batch_rule);
+  VMAP_SUPPORT("embedding_dense_backward", embedding_dense_backward_batch_rule);
 
   VMAP_SUPPORT("grid_sampler_2d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
   VMAP_SUPPORT("grid_sampler_2d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_2d_backward));

test/discover_coverage.py

Lines changed: 1 addition & 0 deletions
@@ -407,6 +407,7 @@ def print_coverage_info(th=100, nn=25):
        'torch.prod', # dynamic (backward)
        'torch.norm', # norm with nuc is not commonly used; we support the other cases.
        'torch.svd', # There isn't a bug, it is just nondeterministic so we can't test it.
+        'torch.nn.functional.embedding', # We support everything except the sparse option.
    }
    remove_from_set(statuses['test_vmap_exhaustive'], vmap_exemptions)
    remove_from_set(statuses['test_vmapvjp'], vmap_exemptions)

test/functorch_additional_op_db.py

Lines changed: 18 additions & 22 deletions
@@ -201,7 +201,8 @@ def sample_inputs_cross_entropy(self, device, dtype, requires_grad, reduction):
                  supports_out=True))


-# TODO: split embedding in pytorch core
+# TODO: PyTorch core has a check for if requires_grad=True or not.
+# We actually want to test more things for backward here which is why we have our own
 def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs):
     def make_input(shape):
         return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -225,27 +226,22 @@ def generator():
         idx = make_long_input((S, S), low=0, high=M)
         yield SampleInput(make_input((M, S)), args=(idx,),)

-        if not requires_grad:
-            # Following inputs return different gradient from the numerical gradient.
-            # This is expected and relevant tests are present in `test_nn.py`.
-
-            # The gradient vector at `padding_idx` is not updated.
-            idx = make_long_input((2, 2), low=0, high=S)
-            idx[0, 0] = 2
-            idx[1, 1] = 2
-            yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': 2},)
-
-            idx = make_long_input((2, 2), low=0, high=S)
-            idx[0, 0] = 4
-            idx[1, 1] = 4
-            yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': -1},)
-
-            # Scale the gradient based on the inverse frequency of a particular index.
-            idx = make_long_input((2, 2), low=0, high=S)
-            idx[0, 0] = 1
-            idx[0, 1] = 1
-            weights = make_input((S, S))
-            yield SampleInput(weights, args=(idx,), kwargs={'scale_grad_by_freq': True},)
+        idx = make_long_input((2, 2), low=0, high=S)
+        idx[0, 0] = 2
+        idx[1, 1] = 2
+        yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': 2},)
+
+        idx = make_long_input((2, 2), low=0, high=S)
+        idx[0, 0] = 4
+        idx[1, 1] = 4
+        yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': -1},)
+
+        # Scale the gradient based on the inverse frequency of a particular index.
+        idx = make_long_input((2, 2), low=0, high=S)
+        idx[0, 0] = 1
+        idx[0, 1] = 1
+        weights = make_input((S, S))
+        yield SampleInput(weights, args=(idx,), kwargs={'scale_grad_by_freq': True},)

     return list(generator())

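
The padding_idx and scale_grad_by_freq samples are now generated regardless of requires_grad, so the vmap tests also hit the backward path added in this commit. A hedged sketch of the kind of case they cover, taking per-example gradients of F.embedding with a padding index through functorch; shapes and names below are illustrative, not taken from the test suite:

import torch
from functorch import vmap, grad

weight = torch.randn(5, 3)
idx = torch.randint(0, 5, (4, 2, 2))  # a batch of 4 index tensors
idx[:, 0, 0] = 2                      # make sure the padding index appears

def emb_sum(w, i):
    # F.embedding takes (indices, weight); with padding_idx=2, row 2 of the
    # weight gradient must stay zero.
    return torch.nn.functional.embedding(i, w, padding_idx=2).sum()

# Per-example gradients w.r.t. the shared weight: shape [4, 5, 3].
per_example_grads = vmap(grad(emb_sum), in_dims=(None, 0))(weight, idx)
assert torch.all(per_example_grads[:, 2] == 0)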

test/test_ops.py

Lines changed: 0 additions & 1 deletion
@@ -818,7 +818,6 @@ def test_vmapjvpall(self, device, dtype, op):
        xfail('fft.ihfft2'),
        xfail('fft.ihfftn'),
        xfail('fft.rfft2'),
-        xfail('nn.functional.embedding'),
        xfail('cross'),
        xfail('double', 'channels_last'),
        xfail('linalg.cross'),
