Skip to content

Commit 6fa0237

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Add util to find functional and/or out variant of an operator (#250)
Summary: Pull Request resolved: #250 This is a useful API to find functional or out variant from a given operator. It looks for all the overload of an operator name and see if the target variant exists in all the overloads. Reviewed By: manuelcandales Differential Revision: D49070398 fbshipit-source-id: db06e7d3be3214adc83227b3606329043e8fea08
1 parent 7e6b2b1 commit 6fa0237

File tree

6 files changed

+135
-59
lines changed

6 files changed

+135
-59
lines changed

exir/dialects/edge/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ python_library(
1717
"//caffe2:torch",
1818
"//caffe2/torchgen:torchgen",
1919
"//executorch/exir/dialects/edge/dtype:lib",
20+
"//executorch/exir/dialects/edge/op:lib",
2021
"//executorch/exir/dialects/edge/spec:lib",
21-
"//executorch/exir/operator:convert",
2222
],
2323
)
2424

exir/dialects/edge/_ops.py

Lines changed: 6 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,19 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import dataclasses
8-
import logging
97
from typing import Any, Dict, List, Optional, Set, Union
108

119
import pkg_resources
1210

1311
import torch
1412

1513
from executorch.exir.dialects.edge.dtype.supported import regular_tensor_str_to_dtypes
16-
from executorch.exir.dialects.edge.spec.utils import (
17-
get_tensor_variable_names,
18-
get_torch_op_overload,
19-
)
20-
from executorch.exir.operator.convert import _pybind_schema_to_native_schema
14+
from executorch.exir.dialects.edge.op.api import to_variant
15+
from executorch.exir.dialects.edge.spec.utils import get_tensor_variable_names
2116

2217
# pyre-ignore
2318
from ruamel.yaml import YAML
24-
from torchgen.model import FunctionSchema
19+
from torchgen.model import SchemaKind
2520

2621

2722
class AllowedDtypeSet:
@@ -324,53 +319,9 @@ def to_out_variant(self) -> torch._ops.OpOverload:
324319
# return if already found
325320
if "_out_variant" in self.__dict__ and self._out_variant:
326321
return self._out_variant
327-
# first check if the current operator is an out-variant
328-
native_schema: Optional[FunctionSchema] = _pybind_schema_to_native_schema(
329-
self._schema.schema
330-
)
331-
assert (
332-
native_schema is not None
333-
), f"Schema: {self._schema} cannot be converted to torch.FunctionSchema"
334-
if native_schema.is_out_fn():
335-
out = get_torch_op_overload(
336-
self.namespace,
337-
self._schema.name.split("::")[1],
338-
self._schema.overload_name,
339-
)
340-
self._out_variant = out
341-
return out
342-
# get all overloads
343-
torch_packet = getattr(
344-
getattr(torch.ops, self.namespace), self._schema.name.split("::")[1]
345-
)
346-
schemas: List[torch._C.FunctionSchema] = [
347-
getattr(torch_packet, o)._schema
348-
for o in torch._C._jit_get_operation(self._schema.name)[1]
349-
]
350-
# compare the signature of out variant overload with the signature of the original overload
351-
signature = dataclasses.replace(native_schema.signature(), returns=())
352-
for schema in schemas:
353-
# ignore self
354-
if str(schema) == str(self._schema):
355-
continue
356-
native_s: Optional[FunctionSchema] = _pybind_schema_to_native_schema(schema)
357-
if native_s is None:
358-
logging.warning(
359-
f"Schema: {schema} cannot be converted to torch.FunctionSchema"
360-
)
361-
continue
362-
if (
363-
native_s.is_out_fn()
364-
and dataclasses.replace(native_s.signature(), returns=()) == signature
365-
):
366-
out = get_torch_op_overload(
367-
self.namespace, schema.name.split("::")[1], schema.overload_name
368-
)
369-
self._out_variant = out
370-
return out
371-
raise RuntimeError(
372-
f"Out variant of operator {self.name()} can't be found. We've found the schemas of all the overloads: {[str(s) for s in schemas]}"
373-
)
322+
out_variant = to_variant(self._op, SchemaKind.out)
323+
self._out_variant = out_variant
324+
return out_variant
374325

375326
def __getattr__(self, name):
376327
if name == "_schema":

exir/dialects/edge/op/TARGETS

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
23

34
oncall("ai_infra_mobile_platform")
45

@@ -9,5 +10,19 @@ python_library(
910
],
1011
deps = [
1112
"//caffe2:torch",
13+
"//caffe2/torchgen:torchgen",
14+
"//executorch/exir/operator:convert",
15+
],
16+
)
17+
18+
python_unittest(
19+
name = "test_api",
20+
srcs = [
21+
"test/test_api.py",
22+
],
23+
deps = [
24+
":lib",
25+
"//caffe2:torch",
26+
"//caffe2/torchgen:torchgen",
1227
],
1328
)

exir/dialects/edge/op/api.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,15 @@
77
"""
88
APIs to help lowering edge dialect ops to other dialects.
99
"""
10-
from typing import Optional
10+
import dataclasses
11+
import logging
12+
from typing import List, Optional
1113

1214
import torch
13-
from torch._ops import OpOverloadPacket
15+
16+
from executorch.exir.operator.convert import _pybind_schema_to_native_schema
17+
from torch._ops import OpOverload, OpOverloadPacket
18+
from torchgen.model import FunctionSchema, SchemaKind
1419

1520

1621
def get_torch_op_overload(
@@ -26,3 +31,56 @@ def get_torch_op_overload(
2631
def get_callable(name):
2732
main, suffix = name.split(".")
2833
return get_torch_op_overload("aten", main, suffix)
34+
35+
36+
def to_variant(op: OpOverload, variant: SchemaKind) -> OpOverload:
37+
"""Given an operator overload, return its corresponding variant. Currently
38+
only supports functional variant and out variant.
39+
Argument:
40+
op (OpOverload): operator overload instance.
41+
variant (SchemaKind): the variant we are looking for.
42+
Returns:
43+
OpOverload: The matched variant operator.
44+
Example:
45+
torch.ops.aten.add.Tensor, SchemaKind.out -> torch.ops.aten.add.out
46+
torch.ops.aten.add.out, SchemaKind.functional -> torch.ops.aten.add.Tensor
47+
"""
48+
assert (
49+
variant == SchemaKind.functional or variant == SchemaKind.out
50+
), f"Only support out variant and functional variant, got {variant}"
51+
# first check if the current operator is the target variant
52+
native_schema: Optional[FunctionSchema] = _pybind_schema_to_native_schema(
53+
op._schema
54+
)
55+
assert (
56+
native_schema is not None
57+
), f"Schema: {op._schema} cannot be converted to torch.FunctionSchema"
58+
59+
# get all overloads
60+
torch_packet = getattr(
61+
getattr(torch.ops, op.namespace), op._schema.name.split("::")[1]
62+
)
63+
schemas: List[torch._C.FunctionSchema] = [
64+
getattr(torch_packet, o)._schema
65+
for o in torch._C._jit_get_operation(op._schema.name)[1]
66+
]
67+
# compare the signature of out variant overload with the signature of the original overload
68+
signature = dataclasses.replace(native_schema.signature(), returns=())
69+
for schema in schemas:
70+
native_s: Optional[FunctionSchema] = _pybind_schema_to_native_schema(schema)
71+
if native_s is None:
72+
logging.warning(
73+
f"Schema: {schema} cannot be converted to torch.FunctionSchema"
74+
)
75+
continue
76+
if (
77+
native_s.kind() == variant
78+
and dataclasses.replace(native_s.signature(), returns=()) == signature
79+
):
80+
op_variant = get_torch_op_overload(
81+
op.namespace, schema.name.split("::")[1], schema.overload_name
82+
)
83+
return op_variant
84+
raise RuntimeError(
85+
f"{variant} variant of operator {op.name()} can't be found. We've found the schemas of all the overloads: {[str(s) for s in schemas]}"
86+
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env fbpython
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
10+
import torch
11+
12+
from executorch.exir.dialects.edge.op.api import to_variant
13+
from torchgen.model import SchemaKind
14+
15+
aten = torch.ops.aten
16+
17+
OPS_TO_FUNCTIONAL = {
18+
aten.add.out: aten.add.Tensor,
19+
aten._native_batch_norm_legit_no_training.out: aten._native_batch_norm_legit_no_training.default,
20+
aten.addmm.out: aten.addmm.default,
21+
aten.view_copy.out: aten.view_copy.default,
22+
}
23+
24+
25+
class TestApi(unittest.TestCase):
26+
"""Test api.py"""
27+
28+
def test_to_out_variant_returns_self_when_given_out_variant(self) -> None:
29+
op = aten.add.out
30+
variant = to_variant(op, SchemaKind.out)
31+
self.assertEqual(variant, op)
32+
33+
def test_to_functional_variant_returns_self_when_given_functional(self) -> None:
34+
op = aten.leaky_relu.default
35+
variant = to_variant(op, SchemaKind.functional)
36+
self.assertEqual(variant, op)
37+
38+
def test_to_functional_variant_returns_correct_op(
39+
self,
40+
) -> None:
41+
for op in OPS_TO_FUNCTIONAL:
42+
variant = to_variant(op, SchemaKind.functional)
43+
self.assertEqual(variant, OPS_TO_FUNCTIONAL[op])
44+
45+
def test_to_out_variant_returns_correct_op(
46+
self,
47+
) -> None:
48+
inv_map = {v: k for k, v in OPS_TO_FUNCTIONAL.items()}
49+
for op in inv_map:
50+
variant = to_variant(op, SchemaKind.out)
51+
self.assertEqual(variant, inv_map[op])

exir/dialects/edge/test/test_edge_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,8 @@ def test_to_out_variant_returns_correct_op(self) -> None:
417417
def test_to_out_variant_raises_exception_when_no_out_variant(self) -> None:
418418
view_op = ops.edge.aten.view.default
419419
with self.assertRaisesRegex(
420-
RuntimeError, "Out variant of operator aten::view can't be found"
420+
RuntimeError,
421+
"SchemaKind.out variant of operator aten::view can't be found.",
421422
):
422423
view_op.to_out_variant()
423424

0 commit comments

Comments
 (0)