Skip to content

Commit 184f601

Browse files
authored
fix: Fix copying metadata during lowering (#3320)
1 parent de39fa3 commit 184f601

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
77
clean_up_graph_after_modifications,
88
)
9-
from torch_tensorrt.dynamo.utils import get_metadata, set_metadata
9+
from torch_tensorrt.dynamo.utils import copy_metadata
1010

1111
logger = logging.getLogger(__name__)
1212

@@ -26,14 +26,14 @@ def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
2626
def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
2727
return replacement_op(input, shape)
2828

29-
# Store metadata of the orig_op
30-
metadata = get_metadata(gm, orig_op)
31-
32-
if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
29+
match_and_replacements = torch.fx.subgraph_rewriter._replace_pattern(
30+
gm, orig, replacement
31+
)
32+
if match_and_replacements:
3333
gm = clean_up_graph_after_modifications(gm)
3434
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
3535

3636
# Copy the orig_op's metadata to the replacement op
37-
set_metadata(gm, replacement_op, metadata)
37+
copy_metadata(match_and_replacements)
3838

3939
return gm

py/torch_tensorrt/dynamo/utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
import tensorrt as trt
1212
import torch
1313
from torch._subclasses.fake_tensor import FakeTensor
14-
15-
from packaging import version
1614
from torch_tensorrt._Device import Device
1715
from torch_tensorrt._enums import dtype
1816
from torch_tensorrt._features import ENABLED_FEATURES
@@ -22,6 +20,8 @@
2220
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
2321
from torch_tensorrt.dynamo._settings import CompilationSettings
2422

23+
from packaging import version
24+
2525
from .types import TRTDataType
2626

2727
logger = logging.getLogger(__name__)
@@ -716,6 +716,20 @@ def set_metadata(
716716
node.meta = metadata[idx]
717717

718718

719+
def copy_metadata(match_and_replacements: List[Any]) -> None:
720+
"""
721+
Copy the metadata from anchor node to the replacement node. This should be used
722+
if the anchor node is replaced with only a single replacement node i.e one-one replacement.
723+
"""
724+
for match_and_replacement in match_and_replacements:
725+
anchor_node = match_and_replacement.nodes_map[match_and_replacement.anchor]
726+
assert (
727+
len(match_and_replacement.replacements) == 1
728+
), "Found more than 1 replacements for the anchor node."
729+
replacement_node = match_and_replacement.replacements[0]
730+
replacement_node.meta = anchor_node.meta
731+
732+
719733
def flatten_nodes(nodes: Any) -> List[torch.fx.node.Node]:
720734
ret = []
721735
if isinstance(nodes, torch.fx.node.Node):

0 commit comments

Comments
 (0)