Skip to content

Commit 3ba93d3

Browse files
authored
Added cdist forward/backward batching rules (#306)
* WIP on adding cdist batching rules
* Updated cdist forward / backward batch rules
* Fixed code according to the review:
  - rewrote the forward pass, reusing BINARY_POINTWISE with an update
  - rewrote the backward pass and added comments
* Restored previous code, as the cdist issue has been fixed
* Added a comment about type promotion for cdist
1 parent 7fa79f9 commit 3ba93d3

File tree

4 files changed

+70
-8
lines changed

4 files changed

+70
-8
lines changed

functorch/csrc/BatchRulesBinaryOps.cpp

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,9 @@ static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& sec
2222
}
2323
}
2424

25-
template <typename F, F Func, typename... ExtraArgs>
26-
std::tuple<Tensor,optional<int64_t>> _binary_pointwise_batch_rule(
25+
std::tuple<Tensor, Tensor> _binary_pointwise_helper(
2726
const Tensor& tensor, optional<int64_t> tensor_batch_dim,
28-
const Tensor& other, optional<int64_t> other_batch_dim,
29-
ExtraArgs... extra_args) {
27+
const Tensor& other, optional<int64_t> other_batch_dim) {
3028
// compute max logical rank
3129
auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
3230
auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
@@ -52,8 +50,22 @@ std::tuple<Tensor,optional<int64_t>> _binary_pointwise_batch_rule(
5250
tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
5351
other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);
5452

53+
return std::make_tuple(tensor_, other_);
54+
}
55+
56+
template <typename F, F Func, typename... ExtraArgs>
57+
std::tuple<Tensor,optional<int64_t>> _binary_pointwise_batch_rule(
58+
const Tensor& tensor, optional<int64_t> tensor_batch_dim,
59+
const Tensor& other, optional<int64_t> other_batch_dim,
60+
ExtraArgs... extra_args) {
61+
62+
auto tensor_other = _binary_pointwise_helper(
63+
tensor, tensor_batch_dim, other, other_batch_dim);
64+
auto tensor_ = std::get<0>(tensor_other);
65+
auto other_ = std::get<1>(tensor_other);
66+
5567
auto result = Func(tensor_, other_, std::forward<ExtraArgs>(extra_args)...);
56-
return std::make_tuple( std::move(result), 0 );
68+
return std::make_tuple(result, 0);
5769
}
5870

5971
template <typename A, A a, typename C>
@@ -163,6 +175,52 @@ Tensor addr_decomposition(
163175
return self * beta + outer;
164176
}
165177

178+
std::tuple<Tensor,optional<int64_t>> cdist_backward_batch_rule(
179+
const Tensor& grad, optional<int64_t> grad_bdim,
180+
const Tensor& x1, optional<int64_t> x1_bdim,
181+
const Tensor& x2, optional<int64_t> x2_bdim,
182+
const double p,
183+
const Tensor& cdist, optional<int64_t> cdist_bdim) {
184+
185+
auto x1_ = x1;
186+
if (cdist_bdim && !x1_bdim) {
187+
// We need to make sure that x1 has batch dim if cdist has one
188+
// otherwise, we get
189+
// RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5]
190+
// but expected shape compatible with [4, 5]
191+
auto bs = cdist.size(*cdist_bdim);
192+
x1_ = ensure_has_bdim(x1, false, bs);
193+
x1_ = x1_.contiguous();
194+
x1_bdim = 0;
195+
}
196+
197+
// We need to apply the same preprocessing on x1 and x2 as in the forward pass
198+
// _binary_pointwise_batch_rule
199+
auto x12 = _binary_pointwise_helper(x1_, x1_bdim, x2, x2_bdim);
200+
x1_ = std::get<0>(x12);
201+
auto x2_ = std::get<1>(x12);
202+
203+
auto grad_ = moveBatchDimToFront(grad, grad_bdim);
204+
if ((x1_bdim || x2_bdim) && !grad_bdim) {
205+
// We need to make sure that grad has batch dim if x1 or x2 have one
206+
// Probably, there is an assumption on the strides.
207+
// Otherwise grad input contains garbage values, e.g. -7.0816e+29, 7.0816e+29
208+
auto bs = get_bdim_size2(x1_, 0, x2_, 0);
209+
grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bs);
210+
grad_ = grad_.contiguous();
211+
}
212+
213+
auto out = at::_cdist_backward(grad_, x1_, x2_, p, cdist);
214+
215+
optional<int64_t> out_bdim = nullopt;
216+
if (x1_bdim || x2_bdim) {
217+
out_bdim = 0;
218+
}
219+
220+
return std::make_tuple(out, out_bdim);
221+
}
222+
223+
166224
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
167225
#define BINARY_POINTWISE2(op, overload) \
168226
VMAP_SUPPORT(#op"."#overload, BINARY_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
@@ -218,6 +276,12 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
218276
UNARY_POINTWISE(clamp_max);
219277
POINTWISE_BOXED(clamp_max_);
220278

279+
VARIADIC_BDIMS_BOXED(_euclidean_dist);
280+
// Implementation note: _binary_pointwise_helper performs a dtype promotion if args are scalars,
281+
// but cdist can't work with scalars; it requires at least 2d tensors.
282+
BINARY_POINTWISE(_cdist_forward);
283+
VMAP_SUPPORT("_cdist_backward", cdist_backward_batch_rule);
284+
221285
// Commented out so we have a test op
222286
// BINARY_SCALAR_2(copysign, Tensor, Scalar);
223287
BINARY_SCALAR_2(div, Tensor, Scalar);

functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
127127
OP_DECOMPOSE2(bitwise_xor, Scalar);
128128
OP_DECOMPOSE(broadcast_tensors);
129129
OP_DECOMPOSE(broadcast_to);
130+
OP_DECOMPOSE(cdist);
130131
OP_DECOMPOSE(clip);
131132
OP_DECOMPOSE2(clip, Tensor );
132133
OP_DECOMPOSE(concat);

test/test_ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,6 @@ def test_vmapjvpall(self, device, dtype, op):
745745
@skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
746746
xfail('view_as_complex'),
747747
xfail('__getitem__'),
748-
xfail('cdist'),
749748
xfail('cholesky'),
750749
xfail('complex'),
751750
xfail('copysign'),
@@ -757,7 +756,6 @@ def test_vmapjvpall(self, device, dtype, op):
757756
xfail('fft.ihfft'),
758757
xfail('fft.rfft'),
759758
xfail('fft.rfftn'),
760-
xfail('cdist'),
761759
xfail('fill_'),
762760
xfail('fmax'),
763761
xfail('fmin'),

test/test_vmap.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3170,7 +3170,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
31703170

31713171
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
31723172
@skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({
3173-
xfail('cdist'),
31743173
xfail('complex'),
31753174
xfail('copysign'),
31763175
xfail('dsplit'),

0 commit comments

Comments
 (0)