
Commit ed9673d

Embedding batch rule (#351)
1 parent 863602a · commit ed9673d

File tree

functorch/csrc/BatchRulesModules.cpp
test/common_utils.py
test/functorch_additional_op_db.py
test/test_vmap.py

4 files changed: +100 -1 lines changed


functorch/csrc/BatchRulesModules.cpp

Lines changed: 37 additions & 0 deletions
@@ -217,6 +217,41 @@ std::tuple<Tensor,Tensor> cudnn_convolution_backward_plumbing(const Tensor & sel
   return slow_fallback<Tensor,Tensor>(op, { self, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, output_mask });
 }
 
+std::tuple<Tensor,optional<int64_t>> embedding_batch_rule(
+    const Tensor& weight, optional<int64_t> weight_bdim,
+    const Tensor& indices, optional<int64_t> indices_bdim,
+    int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
+  if (!weight_bdim && indices_bdim) {
+    // B*, ED -> B*D
+    const auto result = at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
+    return std::make_tuple(result, indices_bdim);
+  } else if (weight_bdim && !indices_bdim) {
+    // *, BED -> *, E(BD) -> *(BD) -> *BD
+    const auto batch_size = weight.size(*weight_bdim);
+    const auto weight_ = reshape_dim_into(*weight_bdim, /*embedding_dim*/1, weight);
+    auto result = at::embedding(weight_, indices, padding_idx, scale_grad_by_freq, sparse);
+    result = reshape_dim_outof(-1, batch_size, result);
+    return std::make_tuple(result, result.dim() - 2);
+  }
+  TORCH_INTERNAL_ASSERT(weight_bdim && indices_bdim);
+  // B*, BED -> B*, (BE)D -> B*D
+  // We'll need to do something extra: add (0, E, 2*E, ...) to the indices.
+  const auto batch_size = weight.size(*weight_bdim);
+  const auto num_embeddings = weight.size((*weight_bdim == 0) ? 1 : 0);
+  const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight);
+  auto indices_ = moveBatchDimToFront(indices, indices_bdim);
+
+  // [batch_size, 1, 1, 1, ..., 1]
+  DimVector view_shape(indices_.dim(), 1);
+  view_shape[0] = batch_size;
+
+  auto range = at::arange(0, batch_size * num_embeddings, num_embeddings, indices_.options());
+  range = range.view(view_shape);
+
+  indices_ = indices_ + range;
+  const auto result = at::embedding(weight_, indices_, padding_idx, scale_grad_by_freq, sparse);
+  return std::make_tuple(result, 0);
+}
 
 /**
  * grid sample batch rule breaks down into 3 cases:
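
The trickiest case above is when both the weight and the indices carry a batch dimension: the batch of embedding tables is flattened into a single (B*E, D) table and each example's indices are shifted by b * E so they address the right slice. A minimal Python sketch of that equivalence (B, E, D and the index shape are illustrative, not taken from this commit):

import torch
import torch.nn.functional as F

B, E, D = 4, 10, 3                         # batch size, num_embeddings, embedding_dim
weight = torch.randn(B, E, D)              # a batch of embedding tables
indices = torch.randint(0, E, (B, 7))      # a batch of index tensors

# Reference: loop over the batch dimension.
expected = torch.stack([F.embedding(indices[b], weight[b]) for b in range(B)])

# Batch-rule trick: flatten the tables to (B*E, D) and offset each example's
# indices by b * E, mirroring the `range` tensor built in embedding_batch_rule.
flat_weight = weight.reshape(B * E, D)
offsets = torch.arange(0, B * E, E).view(B, 1)   # [0, E, 2*E, ...]
result = F.embedding(indices + offsets, flat_weight)

assert torch.equal(result, expected)
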
@@ -535,6 +570,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   EXISTING_BDIM(im2col);
   EXISTING_BDIM(im2col_backward);
 
+  VMAP_SUPPORT("embedding", embedding_batch_rule);
+
   VMAP_SUPPORT("grid_sampler_2d", GRID_SAMPLE_BATCH_RULE(grid_sampler));
   VMAP_SUPPORT("grid_sampler_2d_backward", GRID_SAMPLE_BW_BATCH_RULE(grid_sampler_2d_backward));
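
With the rule registered via VMAP_SUPPORT, vmap over embedding with a batched weight should now hit embedding_batch_rule rather than the slow fallback. A small usage sketch (shapes are illustrative; uses the functorch-era `from functorch import vmap` import):

import torch
import torch.nn.functional as F
from functorch import vmap

B, E, D = 4, 10, 3
weight = torch.randn(B, E, D)          # per-example embedding tables
indices = torch.randint(0, E, (B, 7))  # per-example indices

# Both arguments are batched over dim 0 (vmap's default in_dims).
out = vmap(F.embedding)(indices, weight)
print(out.shape)                       # torch.Size([4, 7, 3])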

test/common_utils.py

Lines changed: 1 addition & 0 deletions
@@ -90,6 +90,7 @@ def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, compute_loop_
         # def f(a):
         #     return op(a)
         # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
+        # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
         batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **kwarg_values)
         yield (loop_out, batched_out)
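
For context, the helper this debug comment was added to compares a per-example loop against the vmapped call for every sampled input. A simplified sketch of that loop-vs-vmap check (not the helper's actual implementation):

import torch
import torch.nn.functional as F
from functorch import vmap

weight = torch.randn(20, 5)
idx = torch.randint(0, 20, (4, 7))

# Per-example loop vs. one vmapped call; the exhaustive test asserts these match.
loop_out = torch.stack([F.embedding(i, weight) for i in idx])
batched_out = vmap(F.embedding, in_dims=(0, None))(idx, weight)
assert torch.allclose(loop_out, batched_out)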

test/functorch_additional_op_db.py

Lines changed: 61 additions & 0 deletions
@@ -224,3 +224,64 @@ def sample_inputs_cross_entropy(self, device, dtype, requires_grad, reduction):
         dtypes=floating_types(),
         dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
         supports_out=True))
+
+# TODO: split embedding in pytorch core
+def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs):
+    def make_input(shape):
+        return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    def make_long_input(shape, *, low, high):
+        return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high)
+
+    M = 20
+    S = 5
+
+    def generator():
+        # 0-D index tensor
+        idx = make_long_input((), low=0, high=M)
+        yield SampleInput(make_input((M, S)), args=(idx,),)
+
+        # 1-D index tensor
+        idx = make_long_input((S,), low=0, high=M)
+        yield SampleInput(make_input((M, S)), args=(idx,),)
+
+        # 2-D index tensor
+        idx = make_long_input((S, S), low=0, high=M)
+        yield SampleInput(make_input((M, S)), args=(idx,),)
+
+        if not requires_grad:
+            # Following inputs return different gradient from the numerical gradient.
+            # This is expected and relevant tests are present in `test_nn.py`.
+
+            # The gradient vector at `padding_idx` is not updated.
+            idx = make_long_input((2, 2), low=0, high=S)
+            idx[0, 0] = 2
+            idx[1, 1] = 2
+            yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': 2},)
+
+            idx = make_long_input((2, 2), low=0, high=S)
+            idx[0, 0] = 4
+            idx[1, 1] = 4
+            yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': -1},)
+
+            # Scale the gradient based on the inverse frequency of a particular index.
+            idx = make_long_input((2, 2), low=0, high=S)
+            idx[0, 0] = 1
+            idx[0, 1] = 1
+            weights = make_input((S, S))
+            yield SampleInput(weights, args=(idx,), kwargs={'scale_grad_by_freq': True},)
+
+    return list(generator())
+
+additional_op_db.append(
+    OpInfo(
+        "nn.functional.embedding",
+        variant_test_name="functorch",
+        # We use lambda to reshuffle the positional arguments.
+        # This is because currently only the `input` field of SampleInput
+        # is tested in gradient tests.
+        op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(idx, weight, **kwargs),
+        dtypes=floating_types_and(torch.bfloat16, torch.float16),
+        sample_inputs_func=sample_inputs_embedding,
+        supports_out=False,
+    ))
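
The lambda in the OpInfo matters because gradient tests differentiate only with respect to the SampleInput's `input` field, so the differentiable weight has to come first. A hypothetical standalone illustration of that argument order (not part of the test suite):

import torch
import torch.nn.functional as F

weight = torch.randn(20, 5, requires_grad=True)
idx = torch.randint(0, 20, (3,))

out = F.embedding(idx, weight)   # the functional's own order is (input, weight)
out.sum().backward()             # gradients flow into `weight`, not `idx`
print(weight.grad.shape)         # torch.Size([20, 5])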

test/test_vmap.py

Lines changed: 1 addition & 1 deletion
@@ -3109,7 +3109,7 @@ class TestVmapOperatorsOpInfo(TestCase):
         xfail('nn.functional.batch_norm'),
         xfail('lu_unpack'),
         xfail('histogramdd'),
-        xfail('nn.functional.embedding'),
+        xfail('nn.functional.embedding', ''),
         xfail('randn_like'),
         xfail('allclose'),
         xfail('bfloat16', 'channels_last'),
