Skip to content

Commit 8b97b17

Browse files
authored
fix the concretereifier bug in dim_gen (#501)
* fix the concretereifier bug in dim_gen * fix graph_net/dimension_generalizer.py * fix graph_net/dimension_generalizer.py * fix the bug: The expanded size of the tensor doesn't match * fix the bug: The expanded size of the tensor doesn't match * fix graph_net/torch/dim_gen_passes/non_batch_call_method_expand_pass.py to fix bug about expand size * fix expand bug
1 parent 16ac11d commit 8b97b17

File tree

5 files changed

+16
-27
lines changed

5 files changed

+16
-27
lines changed

graph_net/config/small_sample_list_for_get_fusible_subgraph.txt

Lines changed: 0 additions & 10 deletions
This file was deleted.

graph_net/dimension_generalizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,5 +242,7 @@ def update_tensor_metas_by_dyn_dim_cstr(
242242
if tensor_meta.data is not None:
243243
assert isinstance(tensor_meta.data, (list, tuple))
244244
size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1)
245-
doubled_data = [*tensor_meta.data, *tensor_meta.data]
246-
tensor_meta.data = doubled_data[:size]
245+
extended_tensor_data = list(tensor_meta.data)
246+
while len(extended_tensor_data) < size:
247+
extended_tensor_data.extend(extended_tensor_data)
248+
tensor_meta.data = extended_tensor_data[:size]

graph_net/tools/apply_dim_gen_passes.sh

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,19 @@
33
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
44
os.path.dirname(graph_net.__file__))")
55

6-
python3 -m graph_net.model_path_handler \
6+
python3 -m graph_net.apply_sample_pass \
77
--model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt \
8-
--handler-config=$(base64 -w 0 <<EOF
8+
--sample-pass-file-path "$GRAPH_NET_ROOT/dimension_generalizer.py" \
9+
--sample-pass-class-name "ApplyDimGenPasses" \
10+
--sample-pass-config $(base64 -w 0 <<EOF
911
{
10-
"handler_path": "$GRAPH_NET_ROOT/dimension_generalizer.py",
11-
"handler_class_name": "ApplyDimGenPasses",
12-
"handler_config": {
13-
"resume": false,
14-
"output_dir": "/tmp/dimension_generalized_samples",
15-
"model_path_prefix": "$GRAPH_NET_ROOT/../",
16-
"dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
17-
"dimension_generalizer_class_name": "StaticToDynamic",
18-
"limits_handled_models": 10,
19-
"last_model_log_file": "/tmp/a.py"
20-
}
12+
"resume": false,
13+
"output_dir": "/tmp/dimension_generalized_samples",
14+
"model_path_prefix": "$GRAPH_NET_ROOT/../",
15+
"dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
16+
"dimension_generalizer_class_name": "StaticToDynamic",
17+
"limits_handled_models": 40,
18+
"last_model_log_file": "/tmp/a.py"
2119
}
2220
EOF
2321
)

graph_net/torch/dim_gen_passes/non_batch_call_method_expand_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import torch
21
import torch.fx as fx
32
from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
43
from collections import namedtuple

graph_net/torch/sym_dim_reifiers/naive_nlp_sym_dim_reifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def match(self) -> bool:
1919
return sym_shapes_str in self._get_map_nlp_sym_shapes_str2reifier()
2020

2121
def reify(self):
22-
assert self.need_reify()
22+
assert self.match()
2323
sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str()
2424
reifier = self._get_map_nlp_sym_shapes_str2reifier()[sym_shapes_str]
2525
return reifier(self)

0 commit comments

Comments
 (0)