Commit f4b5647

[serialize] Support TopK operator
Let's compile torch.topk to the Circle TOP_K_V2 operator.

TICO-DCO-1.0-Signed-off-by: Dayoung Lee <[email protected]>
1 parent f848310 commit f4b5647

9 files changed: +284 -26 lines
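For reference, the mapping this commit implements pairs the two outputs of `torch.topk` with Circle's `TOP_K_V2`. A quick illustration of the PyTorch side (plain PyTorch, nothing TICO-specific):

import torch

x = torch.tensor([[0.1, 0.9, 0.4, 0.7]])
values, indices = torch.topk(x, 2)  # k largest along the last dim, sorted descending
print(values)   # tensor([[0.9000, 0.7000]])
print(indices)  # tensor([[1, 3]]) -- int64 in PyTorch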

test/modules/op/to.py

Lines changed: 11 additions & 0 deletions
@@ -85,3 +85,14 @@ def forward(self, x):
 
     def get_example_inputs(self):
         return (torch.randn(1, 3),), {}
+
+
+class SimpleToForCast(TestModuleBase):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.to(torch.int32)
+
+    def get_example_inputs(self):
+        return (torch.randn(1, 3),), {}
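This new test pins down the dtype-cast path that the `ToCopyVisitor` changes below rely on. A hedged sketch of what it exercises (the standalone module and the `torch.export` call are illustrative, not part of the test harness):

import torch

class SimpleToForCast(torch.nn.Module):  # stand-in for the test module above
    def forward(self, x):
        return x.to(torch.int32)

ep = torch.export.export(SimpleToForCast(), (torch.randn(1, 3),))
# Depending on decompositions, the cast appears as aten.to.dtype or
# aten._to_copy.default; both now reach the same CAST serialization.
print(ep.graph)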

test/modules/op/top_k.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from test.modules.base import TestModuleBase
+from test.utils.tag import use_onert
+
+# luci-interpreter doesn't support TopK operator yet
+@use_onert
+class SimpleTopK(TestModuleBase):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        values, indices = torch.topk(x, 2)
+        return values, indices
+
+    def get_example_inputs(self):
+        batch_size = 1
+        seq_len = 63
+        num_experts = 8
+        return (torch.randn(batch_size * seq_len, num_experts),), {}
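To exercise the new path end to end, a minimal sketch assuming TICO's usual `tico.convert` entry point (the `convert` and `save` calls are assumptions; the exact API may differ by version):

import tico

model = SimpleTopK().eval()                # the test module defined above
args, _kwargs = model.get_example_inputs()

circle_model = tico.convert(model, args)   # assumed entry point
circle_model.save("simple_topk.circle")    # assumed save API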

tico/passes/legalize_predefined_layout_operators.py

Lines changed: 48 additions & 1 deletion
@@ -17,6 +17,8 @@
 
 if TYPE_CHECKING:
     import torch.fx
+from operator import getitem
+
 import torch
 from torch.export import ExportedProgram
 
@@ -26,7 +28,7 @@
 from tico.utils.graph import create_node
 from tico.utils.passes import PassBase, PassResult
 from tico.utils.trace_decorators import trace_graph_diff_on_pass
-from tico.utils.utils import is_target_node
+from tico.utils.utils import is_target_node, set_new_meta_val
 from tico.utils.validate_args_kwargs import (
     AvgPool2dArgs,
     Conv2DArgs,
@@ -35,6 +37,7 @@
     DequantizePerTensorArgs,
     InstanceNormArgs,
     MaxPool2dWithIndicesArgs,
+    TopKArgs,
 )
 
 
@@ -434,6 +437,49 @@ def legalize_avg_pool2d(self, exported_program, node) -> bool:
             modified = True
         return modified
 
+    def legalize_top_k(self, exported_program, node) -> bool:
+        logger = logging.getLogger(__name__)
+        modified = False
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        args = TopKArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input_ = args.input
+        k = args.k
+        dim = args.dim
+
+        if not (dim == -1 or dim == len(extract_shape(input_)) - 1):
+            raise NotYetSupportedError("Only support dim = -1 (last dimension)")
+
+        with graph.inserting_after(input_):
+            circle_topk = create_node(
+                graph,
+                torch.ops.circle_custom.top_k,
+                args=(input_, k),
+                origin=input_,
+            )
+
+        with graph.inserting_after(circle_topk):
+            topk_values = create_node(graph, getitem, args=(circle_topk, 0))
+            topk_indices = create_node(graph, getitem, args=(circle_topk, 1))
+        with graph.inserting_after(topk_indices):
+            topk_indices_int64 = create_node(
+                graph,
+                torch.ops.aten._to_copy.default,
+                args=(topk_indices,),
+                kwargs={"dtype": torch.int64},
+            )
+
+        get_item, get_item_1 = node.users.keys()
+        get_item.replace_all_uses_with(topk_values, propagate_meta=True)
+        get_item_1.replace_all_uses_with(topk_indices_int64, propagate_meta=True)
+
+        logger.debug(f"{node.name} is replaced with {circle_topk.name}")
+        modified = True
+
+        return modified
+
     def call(self, exported_program: ExportedProgram) -> PassResult:
         target_to_legalize_func = {
             torch.ops.aten.conv2d.default: self.legalize_conv2d,
@@ -442,6 +488,7 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
             torch.ops.aten.max_pool2d_with_indices.default: self.legalize_max_pool2d_with_indices,
             torch.ops.aten.avg_pool2d.default: self.legalize_avg_pool2d,
             torch.ops.aten.instance_norm.default: self.legalize_instance_norm,
+            torch.ops.aten.topk.default: self.legalize_top_k,
         }
 
         graph_module = exported_program.graph_module
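To see the exact pattern `legalize_top_k` matches, exporting a toy module shows the `aten.topk` node and its two `getitem` users (plain `torch.export`, independent of TICO):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.topk(x, 2)

ep = torch.export.export(M(), (torch.randn(4, 8),))
print(ep.graph)
# The graph contains roughly:
#   topk      = torch.ops.aten.topk.default(x, 2)
#   getitem   = topk[0]   # values
#   getitem_1 = topk[1]   # int64 indices
# The pass swaps in torch.ops.circle_custom.top_k and appends an
# aten._to_copy cast so downstream users still see int64 indices.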

tico/serialize/circle_serializer.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,8 @@
 multiple_output_ops = [
     torch.ops.aten.split_with_sizes.default,
     torch.ops.aten.max.dim,
+    torch.ops.aten.topk.default,
+    torch.ops.circle_custom.top_k,
 ]
 
 
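Listing the ops in `multiple_output_ops` matters because `topk` produces two result tensors, so the serializer presumably has to register one Circle tensor per `getitem` user instead of a single tensor for the op node itself. A tiny reminder of the two-output shape contract:

import torch

values, indices = torch.topk(torch.randn(3, 5), 2)
print(values.shape, indices.shape)  # torch.Size([3, 2]) torch.Size([3, 2])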

tico/serialize/operators/op_to_copy.py

Lines changed: 37 additions & 22 deletions
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List, TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
     import torch._ops
     import torch.fx
 import torch
 from circle_schema import circle
 
+from tico.passes import ops
+
 from tico.serialize.circle_mapping import (
     extract_circle_dtype,
     extract_torch_dtype,
@@ -29,12 +31,12 @@
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
 from tico.utils.errors import NotYetSupportedError
-from tico.utils.validate_args_kwargs import ToCopyArgs
+from tico.utils.validate_args_kwargs import ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs
 
 
 @register_node_visitor
 class ToCopyVisitor(NodeVisitor):
-    target: List[torch._ops.OpOverload] = [torch.ops.aten._to_copy.default]
+    target: List[torch._ops.OpOverload] = ops.aten.to_copy
 
     def __init__(self, op_codes: Dict[OpCode, int], graph):
         super().__init__(op_codes, graph)
@@ -60,42 +62,55 @@ def define_cast_node(
 
         return operator
 
+    def parse_args(self, op: torch._ops.OpOverload, args, kwargs):
+        ret: Union[ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs]
+        if op is torch.ops.aten._to_copy.default:
+            ret = ToCopyArgs(*args, **kwargs)
+        elif op is torch.ops.aten.to.dtype:
+            ret = ToDtypeArgs(*args, **kwargs)
+        elif op is torch.ops.aten.to.dtype_layout:
+            ret = ToDtypeLayoutArgs(*args, **kwargs)
+        else:
+            raise NotImplementedError(f"Unsupported to_copy/to operator: {op}")
+
+        return ret
+
     def define_node(
         self,
         node: torch.fx.Node,
     ) -> circle.Operator.OperatorT:
-        supported_kwargs = ["dtype", "device", "layout"]
-        if not all(k in supported_kwargs for k in node.kwargs):
-            unsupported_node_kargs = list(node.kwargs.keys())
-            for supported_key in supported_kwargs:
-                if supported_key in node.kwargs:
-                    unsupported_node_kargs.remove(supported_key)
-            raise NotYetSupportedError(
-                f"Support only {supported_kwargs} kwargs now. Do not support {unsupported_node_kargs}"
-            )
-
-        args = ToCopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type, call-arg]
+        args = ToCopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
        dtype = args.dtype
+        layout = args.layout
+        # device is meaningless in circle
+
+        pin_memory = args.pin_memory
+        non_blocking = args.non_blocking
+        memory_format = args.memory_format
+
+        if pin_memory is not None:
+            raise NotYetSupportedError("Do not support pin_memory yet")
+        if non_blocking is True:
+            raise NotYetSupportedError("Do not support non_blocking yet")
+        if memory_format is not None:
+            raise NotYetSupportedError("Do not support memory_format yet")
 
         input_meta = input.meta["val"]
         # https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout
         # layout is two types: torch.strided(dense Tensors), torch.sparse_coo(sparse COO Tensors)
         if "layout" in input.kwargs and input.kwargs["layout"] != input_meta:
             raise NotYetSupportedError(
-                f"Only support when node and its input have same layout: (input layout: {input_meta}), (node layout: {node.kwargs['layout']})."
+                f"Only support when node and its input have same layout: (input layout: {input_meta}), (node layout: {layout})."
             )
 
-        if dtype is not None:
-            target_type = node.kwargs["dtype"]
-        else:
-            # device and layout are meaningless
-            target_type = extract_torch_dtype(node)
-        assert isinstance(target_type, torch.dtype), type(target_type)
+        if dtype is None:
+            dtype = extract_torch_dtype(node)
+        assert isinstance(dtype, torch.dtype), type(dtype)
 
         # define cast node
         in_type: int = extract_circle_dtype(input)
-        out_type: int = to_circle_dtype(target_type)
+        out_type: int = to_circle_dtype(dtype)
         inputs = [input]
         outputs = [node]
         operator = self.define_cast_node(inputs, outputs, in_type, out_type)
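The widened visitor now covers the `aten` variants that different tracing and decomposition settings produce for `Tensor.to`. A quick check of the three overloads, assuming standard aten signatures:

import torch

x = torch.randn(1, 3)
y1 = torch.ops.aten._to_copy.default(x, dtype=torch.int32)
y2 = torch.ops.aten.to.dtype(x, torch.int32)
y3 = torch.ops.aten.to.dtype_layout(x, dtype=torch.int32, layout=torch.strided)
assert y1.dtype == y2.dtype == y3.dtype == torch.int32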
tico/serialize/operators/op_topk.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.circle_mapping import (
+    circle_legalize_dtype_to,
+    extract_circle_shape,
+    extract_shape,
+    extract_torch_dtype,
+)
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import TopKArgs
+
+
+@register_node_visitor
+class TopkVisitor(NodeVisitor):
+    """ """
+
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.circle_custom.top_k,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_topk_node(
+        self, inputs: List, outputs: List
+    ) -> circle.Operator.OperatorT:
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.TOPK_V2, self._op_codes
+        )
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.TopKV2Options
+        option = circle.TopKV2Options.TopKV2OptionsT()
+        operator.builtinOptions = option
+
+        return operator
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        args = TopKArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        k = args.k
+
+        input_shape = extract_circle_shape(input)
+        k_i32 = circle_legalize_dtype_to(k, dtype=torch.int32)
+        assert args.dim == -1 or args.dim == len(input_shape) - 1
+
+        inputs = [input, k_i32]
+
+        outputs = [i for i in node.users.keys()]
+
+        topk_node: circle.Operator.OperatorT = self.define_topk_node(inputs, outputs)
+
+        return topk_node
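Note the `int32` legalization of `k`: in the TFLite/Circle schema, `TOPK_V2` takes `[input, k]` with a scalar `int32` `k` and emits `[values, indices]` with `int32` indices, which is why the legalization pass above casts the indices back to `int64`. A sketch of that dtype boundary (the cast mirrors what the inserted `aten._to_copy` node does):

import torch

indices_i32 = torch.tensor([[1, 3]], dtype=torch.int32)  # what TOPK_V2 yields
indices_i64 = torch.ops.aten._to_copy.default(indices_i32, dtype=torch.int64)
assert indices_i64.dtype == torch.int64                   # what aten.topk promised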

tico/utils/graph.py

Lines changed: 9 additions & 3 deletions
@@ -16,10 +16,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
+from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING
 
 if TYPE_CHECKING:
     import torch.fx
+from operator import getitem
+
 import torch
 from torch.export import ExportedProgram
 from torch.export.exported_program import InputKind, InputSpec, TensorArgument
@@ -238,7 +240,7 @@ def get_module_name_chain(node: Optional[torch.fx.Node]) -> str:
 
 def create_node(
     graph: torch.fx.Graph,
-    target: torch._ops.OpOverload,
+    target: Callable,
     args: Optional[Tuple[Any, ...]] = None,
     kwargs: Optional[Dict[str, Any]] = None,
     *,
@@ -252,7 +254,7 @@
     graph : torch.fx.Graph
         The graph that will own the newly-created node.
 
-    target : torch._ops.OpOverload
+    target : Callable
         The op to call (e.g. `torch.add` or "call_function" target).
 
     args : Tuple[Any, ...], optional
@@ -271,6 +273,10 @@
     torch.fx.Node
         The freshly inserted node with fully-populated `.meta`.
     """
+    assert isinstance(target, torch._ops.OpOverload) or (
+        target is getitem
+    ), f"Invalid target {target}"
+
     new_node = graph.call_function(target, args=args, kwargs=kwargs)
     if origin:
         assert isinstance(origin, torch.fx.Node), type(origin)
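Widening `target` to `Callable` (while asserting it is an `OpOverload` or `getitem`) is what lets `legalize_top_k` build its unpacking nodes. A minimal FX illustration using the raw graph API (not `create_node` itself, which additionally populates `.meta`):

import torch
from operator import getitem

g = torch.fx.Graph()
x = g.placeholder("x")
topk = g.call_function(torch.ops.aten.topk.default, args=(x, 2))
values = g.call_function(getitem, args=(topk, 0))  # the newly-permitted target
g.output(values)
print(g)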
