Skip to content

Commit 21b2394

Browse files
authored
Replace addr_decomp with the one in PyTorch Core, issue #833 (#836)
* Replace addr_decomp with the one in PyTorch Core, issue #833 * only allow OpOverload to be passed to _register_jit_decomposition_bypass_script * change variable name to vmap_decompositions_lib * change function name to _register_python_decomposition_vmap
1 parent 76976db commit 21b2394

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

functorch/_src/eager_transforms.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,6 +1305,20 @@ def get_function_def(sig):
13051305
torch.jit._register_decomposition(decomp, graph)
13061306

13071307

1308+
# use an alternate way to register an operator into the decomposition table
1309+
# _register_jit_decomposition doesn't work for some operators, e.g. addr,
1310+
# because the Tensor types generated cannot be unioned by torchscript
1311+
# decomp should be type OpOverload
1312+
vmap_decompositions_lib = torch.library.Library("aten", "IMPL", "FuncTorchBatched")
1313+
1314+
1315+
def _register_python_decomposition_vmap(decomp):
1316+
if decomp in decomposition_table:
1317+
vmap_decompositions_lib.impl(decomp, decomposition_table[decomp])
1318+
else:
1319+
raise RuntimeError(f"could not find decomposition for {decomp}")
1320+
1321+
13081322
_register_jit_decomposition(torch.ops.aten.trace.default)
13091323
_register_jit_decomposition(torch.ops.aten.nll_loss_backward.default)
13101324
_register_jit_decomposition(torch.ops.aten.nll_loss2d_backward.default)
@@ -1316,3 +1330,4 @@ def get_function_def(sig):
13161330
_register_jit_decomposition(torch.ops.aten.binary_cross_entropy_backward.default)
13171331
_register_jit_decomposition(torch.ops.aten.binary_cross_entropy.default)
13181332
_register_jit_decomposition(torch.ops.aten.native_layer_norm_backward.default)
1333+
_register_python_decomposition_vmap(torch.ops.aten.addr.default)

functorch/csrc/BatchRulesBinaryOps.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,6 @@ std::tuple<Tensor,optional<int64_t>> masked_select_backward_batch_rule(
264264
return std::make_tuple(result, 0);
265265
}
266266

267-
Tensor addr_decomposition(
268-
const Tensor& self, const Tensor& vec1, const Tensor& vec2,
269-
const Scalar& beta, const Scalar& alpha) {
270-
271-
auto outer = alpha * vec1.unsqueeze(-1) * vec2.unsqueeze(-2);
272-
return self * beta + outer;
273-
}
274-
275267
std::tuple<Tensor,optional<int64_t>> cdist_backward_batch_rule(
276268
const Tensor& grad, optional<int64_t> grad_bdim,
277269
const Tensor& x1, optional<int64_t> x1_bdim,
@@ -369,7 +361,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
369361
BINARY_SCALAR_2(add, Tensor, Scalar);
370362
POINTWISE_BOXED(addcdiv);
371363
POINTWISE_BOXED(addcmul);
372-
m.impl("addr", addr_decomposition);
373364
BINARY_POINTWISE(atan2);
374365
BINARY_SCALAR_2(bitwise_and, Tensor, Scalar);
375366
BINARY_POINTWISE2(bitwise_or, Tensor);

0 commit comments

Comments
 (0)