@@ -22,11 +22,9 @@ static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& sec
   }
 }
 
-template <typename F, F Func, typename... ExtraArgs>
-std::tuple<Tensor,optional<int64_t>> _binary_pointwise_batch_rule(
+std::tuple<Tensor, Tensor> _binary_pointwise_helper(
     const Tensor& tensor, optional<int64_t> tensor_batch_dim,
-    const Tensor& other, optional<int64_t> other_batch_dim,
-    ExtraArgs... extra_args) {
+    const Tensor& other, optional<int64_t> other_batch_dim) {
   // compute max logical rank
   auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
   auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
@@ -52,8 +50,22 @@ std::tuple<Tensor,optional<int64_t>> _binary_pointwise_batch_rule(
   tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
   other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);
 
+  return std::make_tuple(tensor_, other_);
+}
+
+template <typename F, F Func, typename... ExtraArgs>
+std::tuple<Tensor,optional<int64_t>> _binary_pointwise_batch_rule(
+    const Tensor& tensor, optional<int64_t> tensor_batch_dim,
+    const Tensor& other, optional<int64_t> other_batch_dim,
+    ExtraArgs... extra_args) {
+
+  auto tensor_other = _binary_pointwise_helper(
+      tensor, tensor_batch_dim, other, other_batch_dim);
+  auto tensor_ = std::get<0>(tensor_other);
+  auto other_ = std::get<1>(tensor_other);
+
   auto result = Func(tensor_, other_, std::forward<ExtraArgs>(extra_args)...);
-  return std::make_tuple(std::move(result), 0);
+  return std::make_tuple(result, 0);
 }
 
 template <typename A, A a, typename C>
@@ -163,6 +175,52 @@ Tensor addr_decomposition(
   return self * beta + outer;
 }
 
+std::tuple<Tensor,optional<int64_t>> cdist_backward_batch_rule(
+    const Tensor& grad, optional<int64_t> grad_bdim,
+    const Tensor& x1, optional<int64_t> x1_bdim,
+    const Tensor& x2, optional<int64_t> x2_bdim,
+    const double p,
+    const Tensor& cdist, optional<int64_t> cdist_bdim) {
+
+  auto x1_ = x1;
+  if (cdist_bdim && !x1_bdim) {
+    // We need to make sure that x1 has a batch dim if cdist has one;
+    // otherwise we get
+    // RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5]
+    // but expected shape compatible with [4, 5]
+    auto bs = cdist.size(*cdist_bdim);
+    x1_ = ensure_has_bdim(x1, false, bs);
+    x1_ = x1_.contiguous();
+    x1_bdim = 0;
+  }
+
+  // We need to apply the same preprocessing to x1 and x2 as in the forward pass
+  // (_binary_pointwise_batch_rule)
+  auto x12 = _binary_pointwise_helper(x1_, x1_bdim, x2, x2_bdim);
+  x1_ = std::get<0>(x12);
+  auto x2_ = std::get<1>(x12);
+
+  auto grad_ = moveBatchDimToFront(grad, grad_bdim);
+  if ((x1_bdim || x2_bdim) && !grad_bdim) {
+    // We need to make sure that grad has a batch dim if x1 or x2 has one.
+    // There is probably an assumption on the strides;
+    // otherwise the grad input contains garbage values, e.g. -7.0816e+29, 7.0816e+29
+    auto bs = get_bdim_size2(x1_, 0, x2_, 0);
+    grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bs);
+    grad_ = grad_.contiguous();
+  }
+
+  auto out = at::_cdist_backward(grad_, x1_, x2_, p, cdist);
+
+  optional<int64_t> out_bdim = nullopt;
+  if (x1_bdim || x2_bdim) {
+    out_bdim = 0;
+  }
+
+  return std::make_tuple(out, out_bdim);
+}
+
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
 #define BINARY_POINTWISE2(op, overload) \
   VMAP_SUPPORT(#op"."#overload, BINARY_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
@@ -218,6 +276,12 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   UNARY_POINTWISE(clamp_max);
   POINTWISE_BOXED(clamp_max_);
 
+  VARIADIC_BDIMS_BOXED(_euclidean_dist);
+  // Implementation note: _binary_pointwise_helper performs a dtype promotion if the args are
+  // scalars, but cdist can't work with scalars; it requires at least 2d tensors.
+  BINARY_POINTWISE(_cdist_forward);
+  VMAP_SUPPORT("_cdist_backward", cdist_backward_batch_rule);
+
   // Commented out so we have a test op
   // BINARY_SCALAR_2(copysign, Tensor, Scalar);
   BINARY_SCALAR_2(div, Tensor, Scalar);
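
For context, a minimal sketch (not part of this diff) of the behavior these batch rules enable, assuming the functorch-era Python API (functorch.vmap / functorch.grad); which kernels are actually hit depends on cdist's compute mode:

    # sketch: vmapping torch.cdist over a batch of point sets, forward and backward
    import torch
    from functorch import vmap, grad

    x1 = torch.randn(4, 5, 3)  # 4 point sets, 5 points each, in R^3
    x2 = torch.randn(4, 7, 3)  # 4 point sets, 7 points each, in R^3

    # per-sample pairwise distances, shape (4, 5, 7)
    dists = vmap(torch.cdist)(x1, x2)

    # per-sample gradients w.r.t. x1, shape (4, 5, 3), exercising the backward rule
    per_sample_grads = vmap(grad(lambda a, b: torch.cdist(a, b).sum()))(x1, x2)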