10
10
11
11
namespace at { namespace functorch {
12
12
13
+ static Tensor getStepTensor (const Tensor& indices, int64_t bdim_size, int64_t num_embeddings) {
14
+ // [batch_size, 1, 1, 1, ..., 1]
15
+ DimVector view_shape (indices.dim (), 1 );
16
+ view_shape[0 ] = bdim_size;
17
+ auto range = at::arange (0 , bdim_size * num_embeddings, num_embeddings, indices.options ());
18
+ return range.view (view_shape);
19
+ }
20
+
13
21
std::tuple<Tensor,optional<int64_t >> embedding_batch_rule (
14
22
const Tensor& weight, optional<int64_t > weight_bdim,
15
23
const Tensor& indices, optional<int64_t > indices_bdim,
@@ -34,18 +42,43 @@ std::tuple<Tensor,optional<int64_t>> embedding_batch_rule(
34
42
const auto weight_ = reshape_dim_into (*weight_bdim, 0 , weight);
35
43
auto indices_ = moveBatchDimToFront (indices, indices_bdim);
36
44
37
- // [batch_size, 1, 1, 1, ..., 1]
38
- DimVector view_shape (indices_.dim (), 1 );
39
- view_shape[0 ] = batch_size;
40
-
41
- auto range = at::arange (0 , batch_size * num_embeddings, num_embeddings, indices_.options ());
42
- range = range.view (view_shape);
43
-
45
+ const auto range = getStepTensor (indices, batch_size, num_embeddings);
44
46
indices_ = indices_ + range;
45
47
const auto result = at::embedding (weight_, indices_, padding_idx, scale_grad_by_freq, sparse);
46
48
return std::make_tuple (result, 0 );
47
49
}
48
50
51
+ std::tuple<Tensor,optional<int64_t >>
52
+ embedding_dense_backward_batch_rule (
53
+ const Tensor& grad_, optional<int64_t > grad_bdim,
54
+ const Tensor& indices_, optional<int64_t > indices_bdim,
55
+ int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) {
56
+ Tensor grad = grad_;
57
+ Tensor indices = indices_;
58
+ if (!indices_bdim && grad_bdim) {
59
+ const auto bdim_size = grad.size (*grad_bdim);
60
+ grad = reshape_dim_into (*grad_bdim, -1 , grad);
61
+ auto result = at::embedding_dense_backward (
62
+ grad, indices, num_weights, padding_idx, scale_grad_by_freq);
63
+ result = reshape_dim_outof (1 , bdim_size, result);
64
+ return std::make_tuple (result, 1 );
65
+ }
66
+ const auto bdim_size = indices.size (*indices_bdim);
67
+ indices = moveBatchDimToFront (indices, indices_bdim);
68
+ grad = moveBatchDimToFront (grad, grad_bdim);
69
+ grad = ensure_has_bdim (grad, grad_bdim.has_value (), bdim_size);
70
+ const auto range = getStepTensor (indices, bdim_size, num_weights);
71
+ auto result = at::embedding_dense_backward (
72
+ grad, indices + range, num_weights * bdim_size, -1 , scale_grad_by_freq);
73
+ result = reshape_dim_outof (0 , bdim_size, result);
74
+ // Fill in the padding. We can't do it in the embedding_dense_backward call
75
+ // because we need to fill in multiple rows!
76
+ if (padding_idx >= 0 ) {
77
+ result.select (1 , padding_idx).fill_ (0 );
78
+ }
79
+ return std::make_tuple (result, 0 );
80
+ }
81
+
49
82
/* *
50
83
* grid sample batch rule breaks down into 3 cases:
51
84
* case 1 (input is batched, grid is not):
@@ -358,6 +391,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
358
391
EXISTING_BDIM (im2col_backward);
359
392
360
393
VMAP_SUPPORT (" embedding" , embedding_batch_rule);
394
+ VMAP_SUPPORT (" embedding_dense_backward" , embedding_dense_backward_batch_rule);
361
395
362
396
VMAP_SUPPORT (" grid_sampler_2d" , GRID_SAMPLE_BATCH_RULE (grid_sampler));
363
397
VMAP_SUPPORT (" grid_sampler_2d_backward" , GRID_SAMPLE_BW_BATCH_RULE (grid_sampler_2d_backward));
0 commit comments