Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 4038cc4

Browse files
authored
add cosine_similarity batching rule (#171)
* add cosine_similarity batching rule * update test file * update comment * add rule for clamp_min_ and clamp_max_ * update test * update xfail in test_ops * undo line change in BatchRulesLoss
1 parent d1ec060 commit 4038cc4

File tree

4 files changed

+38
-4
lines changed

4 files changed

+38
-4
lines changed

functorch/csrc/BatchRulesBinaryOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,10 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
178178
POINTWISE_BOXED(clamp.Tensor);
179179
BINARY_POINTWISE2(clamp_min, Tensor);
180180
UNARY_POINTWISE(clamp_min);
181+
POINTWISE_BOXED(clamp_min_);
181182
BINARY_POINTWISE2(clamp_max, Tensor);
182183
UNARY_POINTWISE(clamp_max);
184+
POINTWISE_BOXED(clamp_max_);
183185

184186
// Commented out so we have a test op
185187
// BINARY_SCALAR_2(copysign, Tensor, Scalar);
@@ -263,6 +265,10 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
263265
m.impl("div_.Scalar", inplacePlumbing1<
264266
DECLTYPE_AUTO(&unary_inplace_batch_rule<ScalarInplaceT, &Tensor::div_, const Scalar&>),
265267
const Scalar&>);
268+
m.impl("clamp_min_.Tensor", inplacePlumbing2<
269+
DECLTYPE_AUTO(&binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_min_>)>);
270+
m.impl("clamp_max_.Tensor", inplacePlumbing2<
271+
DECLTYPE_AUTO(&binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_max_>)>);
266272

267273
m.impl("masked_fill_.Scalar", inplacePlumbing2<
268274
DECLTYPE_AUTO(&binary_pointwise_inplace_batch_rule<TensorScalarInplaceT, &Tensor::masked_fill_, const Scalar&>), const Scalar&>);

functorch/csrc/BatchRulesStopDecomposition.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
182182
STOP_DECOMPOSE(conv_transpose1d);
183183
STOP_DECOMPOSE(conv_transpose3d.input);
184184
STOP_DECOMPOSE(cosine_embedding_loss);
185-
STOP_DECOMPOSE(cosine_similarity);
186185
STOP_DECOMPOSE(ctc_loss.IntList);
187186
STOP_DECOMPOSE(ctc_loss.Tensor);
188187
STOP_DECOMPOSE(cudnn_is_acceptable);

test/test_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,6 @@ def test_vmapvjp(self, device, dtype, op):
471471
xfail('vdot'),
472472
xfail('view_as_complex'),
473473
xfail('nanmean'),
474-
xfail('nn.functional.cosine_similarity'),
475474
xfail('nn.functional.layer_norm'),
476475
xfail('nn.functional.nll_loss'),
477476
xfail('block_diag'),

test/test_vmap.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,6 +1234,37 @@ def clone_contiguous(x):
12341234
with self.assertRaisesRegex(RuntimeError, msg):
12351235
vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(torch.randn(B0))
12361236

1237+
@parametrize("case",
1238+
(
1239+
(torch.clamp_min_, TensorFactory.randn),
1240+
(torch.clamp_max_, TensorFactory.randn),
1241+
), name_fn=lambda x: x[0].__name__)
1242+
def test_clamp_inplace_variant(self, case):
1243+
test = self._vmap_test
1244+
1245+
def get_number(getter):
1246+
return getter([]).item()
1247+
1248+
op, getter = case
1249+
device = 'cpu'
1250+
B0, B1 = 7, 11
1251+
1252+
# Single vmap: op(Tensor, Tensor)
1253+
test(op, (getter([B0, 3], device), getter([B0, 3], device)), check_propagates_grad=False)
1254+
test(op, (getter([B0], device), getter([B0], device)), check_propagates_grad=False)
1255+
test(op, (getter([2, B0, 3], device), getter([2, B0, 3], device)), in_dims=(1, 1), check_propagates_grad=False)
1256+
test(op, (getter([B0, 2, 3], device), getter([2, B0, 3], device)),
1257+
in_dims=(0, 1), out_dims=1, check_propagates_grad=False)
1258+
test(op, (getter([B0, 2, 3], device), getter([1, 1], device)), in_dims=(0, None), check_propagates_grad=False)
1259+
test(op, (getter([B0, 3], device), getter([B0, 3], device)), in_dims=(0, 0), check_propagates_grad=False)
1260+
1261+
# Nested vmap: op(Tensor, Tensor)
1262+
test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 1, 3], device)), check_propagates_grad=False)
1263+
1264+
# Python number overload: op(Tensor, Number)
1265+
number = get_number(getter)
1266+
self._test_unary(lambda t: op(t, number), getter, device, check_propagates_grad=False)
1267+
12371268
@parametrize('case', [
12381269
subtest(_make_case(torch.clamp_min), name='clamp_min'),
12391270
subtest(_make_case(torch.clamp_max), name='clamp_max'),
@@ -1255,7 +1286,7 @@ def get_number(getter):
12551286
test(op, (getter([B0], device), getter([2, B0, 3], device)),
12561287
in_dims=(0, 1), out_dims=1)
12571288
test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
1258-
test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))
1289+
test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(None, 0))
12591290

12601291
# Nested vmap: op(Tensor, Tensor)
12611292
test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
@@ -3069,7 +3100,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
30693100
xfail('hstack'),
30703101
xfail('linalg.multi_dot'),
30713102
xfail('nanmean'),
3072-
xfail('nn.functional.cosine_similarity'),
30733103
xfail('nn.functional.layer_norm'),
30743104
xfail('nn.functional.nll_loss'),
30753105
xfail('vstack'),

0 commit comments

Comments (0)