File tree Expand file tree Collapse file tree 2 files changed +22
-8
lines changed Expand file tree Collapse file tree 2 files changed +22
-8
lines changed Original file line number Diff line number Diff line change 6
6
from torch_tensorrt .dynamo .lowering .passes .pass_utils import (
7
7
clean_up_graph_after_modifications ,
8
8
)
9
- from torch_tensorrt .dynamo .utils import get_metadata , set_metadata
9
+ from torch_tensorrt .dynamo .utils import copy_metadata
10
10
11
11
logger = logging .getLogger (__name__ )
12
12
@@ -26,14 +26,14 @@ def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
26
26
def replacement (input : torch .Tensor , shape : List [torch .SymInt ]) -> torch .Tensor :
27
27
return replacement_op (input , shape )
28
28
29
- # Store metadata of the orig_op
30
- metadata = get_metadata ( gm , orig_op )
31
-
32
- if torch . fx . subgraph_rewriter . replace_pattern ( gm , orig , replacement ) :
29
+ match_and_replacements = torch . fx . subgraph_rewriter . _replace_pattern (
30
+ gm , orig , replacement
31
+ )
32
+ if match_and_replacements :
33
33
gm = clean_up_graph_after_modifications (gm )
34
34
logger .debug (f"Graph after replacing view with reshape:\n { gm .graph } " )
35
35
36
36
# Copy the orig_op's metadata to the replacement op
37
- set_metadata ( gm , replacement_op , metadata )
37
+ copy_metadata ( match_and_replacements )
38
38
39
39
return gm
Original file line number Diff line number Diff line change 11
11
import tensorrt as trt
12
12
import torch
13
13
from torch ._subclasses .fake_tensor import FakeTensor
14
-
15
- from packaging import version
16
14
from torch_tensorrt ._Device import Device
17
15
from torch_tensorrt ._enums import dtype
18
16
from torch_tensorrt ._features import ENABLED_FEATURES
22
20
from torch_tensorrt .dynamo ._engine_cache import BaseEngineCache
23
21
from torch_tensorrt .dynamo ._settings import CompilationSettings
24
22
23
+ from packaging import version
24
+
25
25
from .types import TRTDataType
26
26
27
27
logger = logging .getLogger (__name__ )
@@ -716,6 +716,20 @@ def set_metadata(
716
716
node .meta = metadata [idx ]
717
717
718
718
719
+ def copy_metadata (match_and_replacements : List [Any ]) -> None :
720
+ """
721
+ Copy the metadata from anchor node to the replacement node. This should be used
722
+ if the anchor node is replaced with only a single replacement node i.e one-one replacement.
723
+ """
724
+ for match_and_replacement in match_and_replacements :
725
+ anchor_node = match_and_replacement .nodes_map [match_and_replacement .anchor ]
726
+ assert (
727
+ len (match_and_replacement .replacements ) == 1
728
+ ), "Found more than 1 replacements for the anchor node."
729
+ replacement_node = match_and_replacement .replacements [0 ]
730
+ replacement_node .meta = anchor_node .meta
731
+
732
+
719
733
def flatten_nodes (nodes : Any ) -> List [torch .fx .node .Node ]:
720
734
ret = []
721
735
if isinstance (nodes , torch .fx .node .Node ):
You can’t perform that action at this time.
0 commit comments