Skip to content

Commit 2cb20e6

Browse files
committed
Support TopK operator
1 parent 0ba9f5b commit 2cb20e6

File tree

6 files changed

+231
-1
lines changed

6 files changed

+231
-1
lines changed

test/modules/op/top_k.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
17+
from test.modules.base import TestModuleBase
18+
from test.utils.tag import use_onert
19+
20+
# CircleInterpreter doesn't support TopK operator
21+
@use_onert
22+
class SimpleTopK(TestModuleBase):
23+
def __init__(self):
24+
super().__init__()
25+
26+
def forward(self, x):
27+
values, indices = torch.topk(x, 2)
28+
return values, indices
29+
30+
def get_example_inputs(self):
31+
batch_size = 1
32+
seq_len = 63
33+
num_experts = 8
34+
return (torch.randn(batch_size * seq_len, num_experts),), {}

tico/passes/legalize_predefined_layout_operators.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
if TYPE_CHECKING:
1919
import torch.fx
20+
from operator import getitem
21+
2022
import torch
2123
from torch.export import ExportedProgram
2224

@@ -26,7 +28,7 @@
2628
from tico.utils.graph import create_node
2729
from tico.utils.passes import PassBase, PassResult
2830
from tico.utils.trace_decorators import trace_graph_diff_on_pass
29-
from tico.utils.utils import is_target_node
31+
from tico.utils.utils import is_target_node, set_new_meta_val
3032
from tico.utils.validate_args_kwargs import (
3133
AvgPool2dArgs,
3234
Conv2DArgs,
@@ -35,6 +37,7 @@
3537
DequantizePerTensorArgs,
3638
InstanceNormArgs,
3739
MaxPool2dWithIndicesArgs,
40+
TopKArgs,
3841
)
3942

4043

@@ -434,6 +437,52 @@ def legalize_avg_pool2d(self, exported_program, node) -> bool:
434437
modified = True
435438
return modified
436439

440+
def legalize_top_k(self, exported_program, node) -> bool:
441+
logger = logging.getLogger(__name__)
442+
modified = False
443+
444+
graph_module = exported_program.graph_module
445+
graph = graph_module.graph
446+
447+
args = TopKArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
448+
input_ = args.input
449+
k = args.k
450+
dim = args.dim
451+
# TODO: Check dim == -1
452+
with graph.inserting_after(input_):
453+
circle_topk = create_node(
454+
graph,
455+
torch.ops.circle_custom.top_k,
456+
args=(input_, k),
457+
origin=input_,
458+
)
459+
set_new_meta_val(circle_topk)
460+
461+
with graph.inserting_after(circle_topk):
462+
topk_values = create_node(
463+
graph, getitem, args=(circle_topk, 0), origin=circle_topk
464+
)
465+
set_new_meta_val(topk_values)
466+
topk_indices = create_node(
467+
graph, getitem, args=(circle_topk, 1), origin=circle_topk
468+
)
469+
set_new_meta_val(topk_indices)
470+
with graph.inserting_after(topk_indices):
471+
topk_indices_int32 = create_node(
472+
graph,
473+
torch.ops.aten.to.dtype,
474+
args=(topk_indices, torch.int32),
475+
origin=node,
476+
)
477+
set_new_meta_val(topk_indices_int32)
478+
get_item, get_item_1 = node.users.keys()
479+
get_item.replace_all_uses_with(topk_values, propagate_meta=False)
480+
get_item_1.replace_all_uses_with(topk_indices_int32, propagate_meta=False)
481+
482+
logger.debug(f"{node.name} is replaced with {circle_topk.name}")
483+
modified = True
484+
return modified
485+
437486
def call(self, exported_program: ExportedProgram) -> PassResult:
438487
target_to_legalize_func = {
439488
torch.ops.aten.conv2d.default: self.legalize_conv2d,
@@ -442,6 +491,7 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
442491
torch.ops.aten.max_pool2d_with_indices.default: self.legalize_max_pool2d_with_indices,
443492
torch.ops.aten.avg_pool2d.default: self.legalize_avg_pool2d,
444493
torch.ops.aten.instance_norm.default: self.legalize_instance_norm,
494+
torch.ops.aten.topk.default: self.legalize_top_k,
445495
}
446496

447497
graph_module = exported_program.graph_module

tico/serialize/circle_serializer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
multiple_output_ops = [
3333
torch.ops.aten.split_with_sizes.default,
3434
torch.ops.aten.max.dim,
35+
torch.ops.aten.topk.default,
36+
torch.ops.circle_custom.top_k,
3537
]
3638

3739

@@ -142,6 +144,8 @@ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None:
142144
if node.target in multiple_output_ops:
143145
continue
144146
node_val = node.meta["val"]
147+
if not hasattr(node_val, "layout"):
148+
breakpoint()
145149
if node_val.layout != torch.strided:
146150
raise RuntimeError(
147151
f"Only support dense tensors (node layout: {node_val.layout})"
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict, List, TYPE_CHECKING
16+
17+
if TYPE_CHECKING:
18+
import torch.fx
19+
import torch
20+
from circle_schema import circle
21+
22+
from tico.serialize.circle_graph import CircleSubgraph
23+
from tico.serialize.circle_mapping import (
24+
circle_legalize_dtype_to,
25+
extract_circle_shape,
26+
extract_shape,
27+
extract_torch_dtype,
28+
)
29+
from tico.serialize.operators.hashable_opcode import OpCode
30+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
31+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
32+
from tico.utils.validate_args_kwargs import TopKArgs
33+
34+
35+
@register_node_visitor
36+
class TopkVisitor(NodeVisitor):
37+
""" """
38+
39+
target: List[torch._ops.OpOverload] = [
40+
torch.ops.circle_custom.top_k,
41+
]
42+
43+
def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
44+
super().__init__(op_codes, graph)
45+
46+
def define_topk_node(
47+
self, inputs: List, outputs: List
48+
) -> circle.Operator.OperatorT:
49+
op_index = get_op_index(
50+
circle.BuiltinOperator.BuiltinOperator.TOPK_V2, self._op_codes
51+
)
52+
53+
operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
54+
55+
operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.TopKV2Options
56+
option = circle.TopKV2Options.TopKV2OptionsT()
57+
operator.builtinOptions = option
58+
59+
return operator
60+
61+
def define_node(
62+
self,
63+
node: torch.fx.Node,
64+
) -> circle.Operator.OperatorT:
65+
args = TopKArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
66+
input = args.input
67+
k = args.k
68+
69+
input_shape = extract_circle_shape(input)
70+
k_i32 = circle_legalize_dtype_to(k, dtype=torch.int32)
71+
assert args.dim == -1 or args.dim == len(input_shape) - 1
72+
73+
inputs = [input, k_i32]
74+
75+
outputs = [i for i in node.users.keys()]
76+
77+
topk_node: circle.Operator.OperatorT = self.define_topk_node(inputs, outputs)
78+
79+
return topk_node

tico/utils/register_custom_op.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torch.library import custom_op, register_fake
2020

2121
from tico.utils.mx.mx_ops import _quantize_mx
22+
from tico.utils.validate_args_kwargs import TopKArgs
2223

2324
# Note that an operator assumes input tensor has NHWC format.
2425
def CircleResizeNearestNeighbor():
@@ -662,6 +663,48 @@ def _(
662663
return input.new_empty(input.size())
663664

664665

666+
def CircleTopK():
667+
@custom_op(
668+
"circle_custom::top_k",
669+
mutates_args=(),
670+
schema="(Tensor input, int k) -> (Tensor, Tensor)",
671+
)
672+
def top_k(
673+
input: torch.Tensor,
674+
k: int,
675+
dim: int = -1,
676+
largest: bool = True,
677+
sorted: bool = True,
678+
) -> tuple[torch.Tensor]:
679+
args = TopKArgs(input, k, dim, largest, sorted)
680+
topk_out_0, topk_out_1 = torch.ops.aten.topk.default(*args)
681+
topk_out_1_int32 = torch.ops.aten.to.dtype(topk_out_1, dtype=torch.int32)
682+
683+
return (
684+
topk_out_0,
685+
topk_out_1_int32,
686+
)
687+
688+
@register_fake("circle_custom::top_k")
689+
def _(
690+
input: FakeTensor,
691+
k: int,
692+
dim: int = -1,
693+
largest: bool = True,
694+
sorted: bool = True,
695+
) -> tuple[FakeTensor]:
696+
assert dim == -1
697+
assert largest is True
698+
assert sorted is True
699+
topk_out0, topk_out1 = torch.ops.aten.topk.default(input, k, dim)
700+
# topk_out_1_int32 = torch.ops.aten.to.dtype(topk_out_1, dtype=torch.int32)
701+
702+
return (
703+
topk_out0,
704+
topk_out1.new_empty(size=topk_out1.size(), dtype=torch.int32),
705+
)
706+
707+
665708
def CircleQuantizeMX():
666709
# This operator conducts fake-quantization of microscaling
667710
# NOTE Why using "quantize"_mx not "fake_quantize"_mx?
@@ -715,3 +758,4 @@ def RegisterOps():
715758
CircleAvgPool2D()
716759
CircleInstanceNorm()
717760
CircleQuantizeMX()
761+
CircleTopK()

tico/utils/validate_args_kwargs.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,25 @@ class ToDtypeLayoutArgs:
11481148
memory_format: Optional[torch.memory_format] = None
11491149

11501150

1151+
@enforce_type
1152+
@dataclass
1153+
class TopKArgs:
1154+
"""
1155+
topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
1156+
"""
1157+
1158+
input: torch.fx.Node
1159+
k: int
1160+
dim: int = -1
1161+
largest: bool = True
1162+
sorted: bool = True
1163+
1164+
def __post_init__(self):
1165+
1166+
assert self.largest is True, "Only support largest=True"
1167+
assert self.sorted is True, "Only support sorted=True"
1168+
1169+
11511170
@enforce_type
11521171
@dataclass
11531172
class UnSqueezeArgs:

0 commit comments

Comments
 (0)