1717
1818if TYPE_CHECKING :
1919 import torch .fx
20+ from operator import getitem
21+
2022import torch
2123from torch .export import ExportedProgram
2224
2628from tico .utils .graph import create_node
2729from tico .utils .passes import PassBase , PassResult
2830from tico .utils .trace_decorators import trace_graph_diff_on_pass
29- from tico .utils .utils import is_target_node
31+ from tico .utils .utils import is_target_node , set_new_meta_val
3032from tico .utils .validate_args_kwargs import (
3133 AvgPool2dArgs ,
3234 Conv2DArgs ,
3537 DequantizePerTensorArgs ,
3638 InstanceNormArgs ,
3739 MaxPool2dWithIndicesArgs ,
40+ TopKArgs ,
3841)
3942
4043
@@ -434,6 +437,52 @@ def legalize_avg_pool2d(self, exported_program, node) -> bool:
434437 modified = True
435438 return modified
436439
440+ def legalize_top_k (self , exported_program , node ) -> bool :
441+ logger = logging .getLogger (__name__ )
442+ modified = False
443+
444+ graph_module = exported_program .graph_module
445+ graph = graph_module .graph
446+
447+ args = TopKArgs (* node .args , ** node .kwargs ) # type: ignore[arg-type]
448+ input_ = args .input
449+ k = args .k
450+ dim = args .dim
451+ # TODO: Check dim == -1
452+ with graph .inserting_after (input_ ):
453+ circle_topk = create_node (
454+ graph ,
455+ torch .ops .circle_custom .top_k ,
456+ args = (input_ , k ),
457+ origin = input_ ,
458+ )
459+ set_new_meta_val (circle_topk )
460+
461+ with graph .inserting_after (circle_topk ):
462+ topk_values = create_node (
463+ graph , getitem , args = (circle_topk , 0 ), origin = circle_topk
464+ )
465+ set_new_meta_val (topk_values )
466+ topk_indices = create_node (
467+ graph , getitem , args = (circle_topk , 1 ), origin = circle_topk
468+ )
469+ set_new_meta_val (topk_indices )
470+ with graph .inserting_after (topk_indices ):
471+ topk_indices_int32 = create_node (
472+ graph ,
473+ torch .ops .aten .to .dtype ,
474+ args = (topk_indices , torch .int32 ),
475+ origin = node ,
476+ )
477+ set_new_meta_val (topk_indices_int32 )
478+ get_item , get_item_1 = node .users .keys ()
479+ get_item .replace_all_uses_with (topk_values , propagate_meta = False )
480+ get_item_1 .replace_all_uses_with (topk_indices_int32 , propagate_meta = False )
481+
482+ logger .debug (f"{ node .name } is replaced with { circle_topk .name } " )
483+ modified = True
484+ return modified
485+
437486 def call (self , exported_program : ExportedProgram ) -> PassResult :
438487 target_to_legalize_func = {
439488 torch .ops .aten .conv2d .default : self .legalize_conv2d ,
@@ -442,6 +491,7 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
442491 torch .ops .aten .max_pool2d_with_indices .default : self .legalize_max_pool2d_with_indices ,
443492 torch .ops .aten .avg_pool2d .default : self .legalize_avg_pool2d ,
444493 torch .ops .aten .instance_norm .default : self .legalize_instance_norm ,
494+ torch .ops .aten .topk .default : self .legalize_top_k ,
445495 }
446496
447497 graph_module = exported_program .graph_module
0 commit comments