Skip to content

Commit 9a55a74

Browse files
authored
Fix the replay of tensor, not use dummy in decomposer. (#454)
1 parent 5885e0c commit 9a55a74

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

graph_net/torch/fx_graph_module_util.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from dataclasses import asdict
66

77

8-
def get_torch_module_and_inputs(model_path):
8+
def get_torch_module_and_inputs(model_path, use_dummy_inputs=True):
99
module = _get_torch_module(model_path)
1010
tensor_metas = _get_tensor_metas(model_path)
11-
inputs = _create_inputs_by_metas(module, tensor_metas)
11+
inputs = _create_inputs_by_metas(module, tensor_metas, use_dummy_inputs)
1212
return module, inputs
1313

1414

@@ -27,11 +27,11 @@ def _get_tensor_metas(model_path):
2727
]
2828

2929

30-
def _create_inputs_by_metas(module, tensor_metas):
30+
def _create_inputs_by_metas(module, tensor_metas, use_dummy_inputs):
3131
tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
32-
from graph_net.torch.utils import get_dummy_named_tensors
32+
from graph_net.torch.utils import get_named_tensors
3333

34-
named_tensors = get_dummy_named_tensors(tensor_meta_attrs_list)
34+
named_tensors = get_named_tensors(tensor_meta_attrs_list, use_dummy_inputs)
3535
name2tensor = {k: v for k, v in named_tensors}
3636
return tuple(
3737
name2tensor[name] for name in inspect.signature(module.forward).parameters

graph_net/torch/graph_decomposer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __call__(self, rel_model_path):
130130
for k, v in self.config.items()
131131
if k in {"split_positions", "group_head_and_tail", "chain_style"}
132132
}
133-
module, inputs = get_torch_module_and_inputs(model_path)
133+
module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False)
134134
gm = parse_immutable_model_path_into_sole_graph_module(model_path)
135135
try:
136136
# logger.warning("convert_to_submodules_graph-call-begin")
@@ -227,7 +227,7 @@ def __call__(self, rel_model_path):
227227
"group_head_and_tail": self.config.get("group_head_and_tail", False),
228228
"chain_style": self.config.get("chain_style", False),
229229
}
230-
module, inputs = get_torch_module_and_inputs(model_path)
230+
module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False)
231231
gm = parse_sole_graph_module(module, inputs)
232232
rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
233233
gm,

graph_net/torch/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,15 +236,18 @@ def convert_tensor_meta_attrs_list_to_named_tensors(tensor_meta_attrs_list):
236236
return ret
237237

238238

239-
def get_dummy_named_tensors(tensor_meta_attrs_list):
239+
def get_named_tensors(tensor_meta_attrs_list, use_dummy_inputs):
240240
tensors_wrappers = convert_tensor_meta_attrs_list_to_tensors_wrappers(
241241
tensor_meta_attrs_list
242242
)
243243
ret = []
244244
for i, tensors_wrapper in enumerate(tensors_wrappers):
245245
name = tensors_wrapper["name"]
246246
# shape = tensors_wrapper["info"]['shape']
247-
tensor = get_dummy_tensor(tensors_wrapper)
247+
if use_dummy_inputs:
248+
tensor = get_dummy_tensor(tensors_wrapper)
249+
else:
250+
tensor = replay_tensor(tensors_wrapper)
248251
ret.append((name, tensor))
249252
return ret
250253

0 commit comments

Comments
 (0)