File tree Expand file tree Collapse file tree 2 files changed +22
-8
lines changed Expand file tree Collapse file tree 2 files changed +22
-8
lines changed Original file line number Diff line number Diff line change 66from torch_tensorrt .dynamo .lowering .passes .pass_utils import (
77 clean_up_graph_after_modifications ,
88)
9- from torch_tensorrt .dynamo .utils import get_metadata , set_metadata
9+ from torch_tensorrt .dynamo .utils import copy_metadata
1010
1111logger = logging .getLogger (__name__ )
1212
@@ -26,14 +26,14 @@ def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
2626 def replacement (input : torch .Tensor , shape : List [torch .SymInt ]) -> torch .Tensor :
2727 return replacement_op (input , shape )
2828
29- # Store metadata of the orig_op
30- metadata = get_metadata ( gm , orig_op )
31-
32- if torch . fx . subgraph_rewriter . replace_pattern ( gm , orig , replacement ) :
29+ match_and_replacements = torch . fx . subgraph_rewriter . _replace_pattern (
30+ gm , orig , replacement
31+ )
32+ if match_and_replacements :
3333 gm = clean_up_graph_after_modifications (gm )
3434 logger .debug (f"Graph after replacing view with reshape:\n { gm .graph } " )
3535
3636 # Copy the orig_op's metadata to the replacement op
37- set_metadata ( gm , replacement_op , metadata )
37+ copy_metadata ( match_and_replacements )
3838
3939 return gm
Original file line number Diff line number Diff line change 1111import tensorrt as trt
1212import torch
1313from torch ._subclasses .fake_tensor import FakeTensor
14-
15- from packaging import version
1614from torch_tensorrt ._Device import Device
1715from torch_tensorrt ._enums import dtype
1816from torch_tensorrt ._features import ENABLED_FEATURES
2220from torch_tensorrt .dynamo ._engine_cache import BaseEngineCache
2321from torch_tensorrt .dynamo ._settings import CompilationSettings
2422
23+ from packaging import version
24+
2525from .types import TRTDataType
2626
2727logger = logging .getLogger (__name__ )
@@ -716,6 +716,20 @@ def set_metadata(
716716 node .meta = metadata [idx ]
717717
718718
719+ def copy_metadata (match_and_replacements : List [Any ]) -> None :
720+ """
721+ Copy the metadata from anchor node to the replacement node. This should be used
722+ if the anchor node is replaced with only a single replacement node i.e one-one replacement.
723+ """
724+ for match_and_replacement in match_and_replacements :
725+ anchor_node = match_and_replacement .nodes_map [match_and_replacement .anchor ]
726+ assert (
727+ len (match_and_replacement .replacements ) == 1
728+ ), "Found more than 1 replacements for the anchor node."
729+ replacement_node = match_and_replacement .replacements [0 ]
730+ replacement_node .meta = anchor_node .meta
731+
732+
719733def flatten_nodes (nodes : Any ) -> List [torch .fx .node .Node ]:
720734 ret = []
721735 if isinstance (nodes , torch .fx .node .Node ):
You can’t perform that action at this time.
0 commit comments