Skip to content

Commit b8d9208

Browse files
[ONNX] Update decomposition logic to loop over onnx registry (pytorch#153168)
* [ONNX] Update decomposition logic to loop over onnx registry (pytorch#151826) Fixes pytorch#150367 This PR builds the decomposition table from the ONNX registry, which includes all registered ops, not only ATen and prim ones. This helps keep the custom ops that are specified in the custom_translation table from being decomposed during ONNX export. Pull Request resolved: pytorch#151826 Approved by: https://github.com/justinchuby (cherry picked from commit 6cd1741) * [ONNX] Add test for decomp_table update (pytorch#153671) Added a test to strengthen the case for cherry-picking pytorch#153168. The original PR didn't include this test since the fix for decomp_table and the registry was already covered by existing tests. However, it's reasonable to include a dedicated test for the specific issue (pytorch#150367) when considering the cherry-pick. Pull Request resolved: pytorch#153671 Approved by: https://github.com/justinchuby --------- Co-authored-by: titaiwangms <[email protected]>
1 parent 8af995f commit b8d9208

File tree

2 files changed

+56
-16
lines changed

2 files changed

+56
-16
lines changed

test/onnx/exporter/test_api.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,51 @@ def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT:
359359
self.assertIn("Sub", all_nodes)
360360
self.assertNotIn("Add", all_nodes)
361361

362+
def test_custom_translation_table_supports_custom_op_with_its_decomp(self):
363+
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
364+
torch.library.define(
365+
"mylib::foo",
366+
"(Tensor a, Tensor b) -> Tensor",
367+
tags=torch.Tag.pt2_compliant_tag,
368+
lib=lib,
369+
)
370+
371+
@torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
372+
@torch.library.register_fake("mylib::foo")
373+
def foo_impl(a, b):
374+
return a + b
375+
376+
class M(torch.nn.Module):
377+
def forward(self, x, y):
378+
return torch.ops.mylib.foo(x, y)
379+
380+
def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT:
381+
# Replace add with Sub
382+
return op.Sub(self, other)
383+
384+
# With the custom op defined, we can use it in the model
385+
# and replace it with a custom translation table
386+
custom_translation_table = {
387+
torch.ops.mylib.foo.default: onnx_add,
388+
}
389+
onnx_program = torch.onnx.export(
390+
M(),
391+
(torch.ones(3, 3), torch.ones(3, 3)),
392+
custom_translation_table=custom_translation_table,
393+
dynamo=True,
394+
)
395+
all_nodes = [n.op_type for n in onnx_program.model.graph]
396+
self.assertIn("Sub", all_nodes)
397+
self.assertNotIn("Add", all_nodes)
398+
399+
# Without the custom op defined, it's going to be decomposed
400+
onnx_program_decomp = torch.onnx.export(
401+
M(), (torch.ones(3, 3), torch.ones(3, 3)), dynamo=True
402+
)
403+
all_nodes_decomp = [n.op_type for n in onnx_program_decomp.model.graph]
404+
self.assertIn("Add", all_nodes_decomp)
405+
self.assertNotIn("Sub", all_nodes_decomp)
406+
362407

363408
class TestFakeTensorExport(common_utils.TestCase):
364409
"""Test exporting in fake mode."""

torch/onnx/_internal/exporter/_decomp.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def get_onnx_implemented_overloads(
1616
registry: _registration.ONNXRegistry,
17-
) -> list[torch._ops.OperatorBase]:
17+
) -> list[_registration.TorchOp]:
1818
"""
1919
Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations.
2020
@@ -24,24 +24,19 @@ def get_onnx_implemented_overloads(
2424
Returns:
2525
A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations.
2626
"""
27-
registered_ops: list[torch._ops.OperatorBase] = []
28-
for op_namespace in (torch.ops.aten, torch.ops.prims):
29-
op_names = dir(op_namespace)
30-
for op_name in op_names:
31-
op_overload_packet = getattr(op_namespace, op_name)
32-
if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket):
33-
continue
34-
35-
for overload_name in op_overload_packet.overloads():
36-
op_overload = getattr(op_overload_packet, overload_name)
37-
if registry.is_registered(op_overload):
38-
registered_ops.append(op_overload)
27+
registered_ops: list[_registration.TorchOp] = []
28+
for onnx_decomp_meta in registry.functions.values():
29+
assert len(onnx_decomp_meta) > 0
30+
# Different OnnxDecompMeta for the same TorchOp should
31+
# have the same fx_target.
32+
fx_target = onnx_decomp_meta[0].fx_target
33+
registered_ops.append(fx_target)
3934
return registered_ops
4035

4136

4237
def create_onnx_friendly_decomposition_table(
43-
onnx_registered_ops: set[torch._ops.OperatorBase],
44-
) -> dict[torch._ops.OperatorBase, Callable]:
38+
onnx_registered_ops: set[_registration.TorchOp],
39+
) -> dict[_registration.TorchOp, Callable]:
4540
"""
4641
This function creates a dictionary of op overloads and their decomposition functions
4742
for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function,
@@ -55,7 +50,7 @@ def create_onnx_friendly_decomposition_table(
5550
Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding
5651
decomposition functions.
5752
"""
58-
decomposition_table: dict[torch._ops.OperatorBase, Callable] = {}
53+
decomposition_table: dict[_registration.TorchOp, Callable] = {}
5954

6055
for op_overload, decomp_fn in itertools.chain(
6156
torch.export.default_decompositions().items(), # type: ignore[attr-defined]

0 commit comments

Comments
 (0)