66import pickletools
77import shutil
88from abc import ABC , abstractmethod
9- from typing import Any , Dict , List , Optional , Sequence , Tuple , cast
9+ from typing import Any , Dict , List , Optional , Sequence , Tuple
1010
1111import torch
12- from torch ._inductor .codecache import FxGraphCachePickler , sha256_hash
13- from torch .fx .experimental .proxy_tensor import unset_fake_temporarily
12+ from torch ._inductor .codecache import sha256_hash
1413from torch_tensorrt ._Input import Input
1514from torch_tensorrt .dynamo ._settings import (
1615 _SETTINGS_TO_BE_ENGINE_INVARIANT ,
@@ -49,17 +48,38 @@ def get_hash(
4948
5049 Args:
5150 gm (torch.fx.GraphModule): GraphModule to hash
51+ input_specs (Sequence[Input]): input specs for the GraphModule
52+ settings (CompilationSettings): compilation settings for the GraphModule
5253
5354 Returns:
5455 str: hash value of the GraphModule
5556 """
56- # parameters are set to 0
57- with unset_fake_temporarily ():
58- new_gm = copy .deepcopy (gm )
59- for name , param in new_gm .named_parameters ():
60- param .data .zero_ ()
6157
62- graph_hash_val = cast (str , FxGraphCachePickler .get_hash (new_gm ))
58+ def canonicalize_graph (graph : torch .fx .Graph ) -> str :
59+ """Canonicalize the graph to a string for isomorphic graph comparison
60+
61+ Args:
62+ graph (torch.fx.Graph): graph to canonicalize
63+
64+ Returns:
65+ str: canonicalized graph string
66+ """
67+ canonical_nodes = []
68+ input_counter = 0
69+
70+ for node in graph .nodes :
71+ if node .op == "placeholder" :
72+ canonical_nodes .append (f"placeholder_input_{ input_counter } " )
73+ input_counter += 1
74+ else :
75+ canonical_nodes .append (f"{ node .op } _{ node .target } " )
76+
77+ return " " .join (canonical_nodes )
78+
79+ graph_str = canonicalize_graph (gm .graph )
80+ _LOGGER .debug (f"graph_str:\n { graph_str } " )
81+
82+ graph_hash = sha256_hash (graph_str .encode ())
6383
6484 input_spec_strs = [str (i ) for i in input_specs ]
6585 with io .BytesIO () as stream :
@@ -75,7 +95,7 @@ def get_hash(
7595 engine_specs_data = pickletools .optimize (engine_specs_data )
7696 engine_specs_hash = sha256_hash (engine_specs_data )
7797
78- hash_val : str = graph_hash_val + input_specs_hash + engine_specs_hash
98+ hash_val : str = graph_hash + input_specs_hash + engine_specs_hash
7999
80100 return hash_val
81101
0 commit comments