Commit 61e5ba4

more dimension generalization pass
1 parent f35ee0a commit 61e5ba4

9 files changed, +835 -33 lines

graph_net/config/empty_cstr_torch_samples_list.txt

Lines changed: 503 additions & 0 deletions
Large diffs are not rendered by default.

graph_net/constraint_util.py

Lines changed: 21 additions & 17 deletions
@@ -25,7 +25,7 @@ def __init__(self, config=None):
         self.model_runnable_predicator = self._make_model_runnable_predicator(
             self.config
         )
-        self.num_successful_handled_models = 0
+        self.num_handled_models = 0
 
     def _make_data_input_predicator(self, config):
         module = load_module(config["data_input_predicator_filepath"])
@@ -51,7 +51,7 @@ def _make_config(
         model_path_prefix="",
         resume=False,
         last_model_log_file=None,
-        limits_successfully_handled_models=None,
+        limits_handled_models=None,
     ):
         if data_input_predicator_config is None:
             data_input_predicator_config = {}
@@ -72,7 +72,7 @@ def _make_config(
             "dimension_generalizer_class_name": dimension_generalizer_class_name,
             "dimension_generalizer_config": dimension_generalizer_config,
             "last_model_log_file": last_model_log_file,
-            "limits_successfully_handled_models": limits_successfully_handled_models,
+            "limits_handled_models": limits_handled_models,
         }
 
     def __call__(self, model_path):
@@ -125,16 +125,15 @@ def is_dyn_dim_cstr_feasible(dyn_dim_cstr):
         )
         self._save_dyn_dim_cstr(dyn_dim_cstr, model_path)
         self._save_dim_gen_pass_names(dim_gen_pass_names, model_path)
-        if len(dyn_dim_cstr.symbols) > 0:
-            self.num_successful_handled_models += 1
-            limits = self.config["limits_successfully_handled_models"]
-            if limits is not None:
-                if self.num_successful_handled_models > limits:
-                    print(
-                        "`num_successful_handled_models` exceeds config `limits_successfully_handled_models`",
-                        file=sys.stderr,
-                    )
-                    sys.exit(0)
+        self.num_handled_models += 1
+        limits = self.config["limits_handled_models"]
+        if limits is not None:
+            if self.num_handled_models >= limits:
+                print(
+                    "`num_handled_models` exceeds config `limits_handled_models`",
+                    file=sys.stderr,
+                )
+                sys.exit(0)
 
     def get_dimension_generalizer(self):
         if hasattr(self, "_dim_generalizer"):
@@ -159,6 +158,7 @@ def get_model(self, model_path):
     def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
         logging.warning(f"enter _try_dimension_generalization")
         if self.config["dimension_generalizer_filepath"] is None:
+            self._save_model_to_log_file(model_path)
             yield model_path, ()
             return
         model = self.get_model(model_path)
@@ -168,6 +168,7 @@ def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
         need_rewrite = dim_gen_pass.need_rewrite(inputs)
         logging.warning(f"after need_rewrite")
         if not need_rewrite:
+            self._save_model_to_log_file(model_path)
             yield model_path, ()
             return
 
@@ -177,11 +178,14 @@ def _try_dimension_generalization(self, dim_axes_pairs, model_path, inputs):
         with tempfile.TemporaryDirectory() as tmp_dir:
             shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
             dim_gen_pass.save_graph_module(graph_module, tmp_dir)
-            if self.config["last_model_log_file"] is not None:
-                log_file = Path(self.config["last_model_log_file"])
-                shutil.copy(Path(tmp_dir) / "model.py", log_file)
+            self._save_model_to_log_file(tmp_dir)
             yield tmp_dir, dim_gen_pass.get_pass_names()
 
+    def _save_model_to_log_file(self, model_path):
+        if self.config["last_model_log_file"] is not None:
+            log_file = Path(self.config["last_model_log_file"])
+            shutil.copy(Path(model_path) / "model.py", log_file)
+
     def _save_dim_gen_pass_names(self, dim_gen_pass_names, model_path):
         from graph_net.graph_net_json_file_util import kDimensionGeneralizationPasses
 
@@ -324,7 +328,7 @@ def append_dim_gen_pass_names(dim_gen_pass_names):
         )
 
         for i, picked_dim in enumerate(unqiue_dims):
-            logging.warning(f"{i=} {picked_dim=}")
+            logging.warning(f"{i=} {picked_dim=} {dim2axes[picked_dim]=}")
             cur_dyn_dim_cstr = copy.deepcopy(dyn_dim_cstr)
 
             def filter_fn(input_name, input_idx, axis, dim):
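
A note on the counting change above, since it is easy to misread in diff form: the handler now counts every handled model rather than only those that yielded a non-empty constraint, and the comparison moved from > to >=, so the process exits after handling exactly limits_handled_models models. A minimal standalone sketch of the new control flow, with a hypothetical model_paths list and the real per-model work reduced to a print:

import sys

def run(model_paths, limits_handled_models=1):
    # Mirror the new logic: count every handled model and stop once the
    # count reaches the limit, regardless of constraint feasibility.
    num_handled_models = 0
    for path in model_paths:
        print(f"handling {path}")  # stand-in for the real per-model work
        num_handled_models += 1
        if limits_handled_models is not None:
            if num_handled_models >= limits_handled_models:
                print(
                    "`num_handled_models` exceeds config `limits_handled_models`",
                    file=sys.stderr,
                )
                sys.exit(0)

run(["model_a", "model_b"])  # with the default limit of 1, exits after model_a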

graph_net/tools/batch_init_input_tensor_constraints.sh

Lines changed: 17 additions & 1 deletion
@@ -19,11 +19,27 @@ config_json_str=$(cat <<EOF
     "model_runnable_predicator_class_name": "$model_runnable_predicator",
     "dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
     "dimension_generalizer_class_name": "StaticToDynamic",
+    "dimension_generalizer_config": {
+      "pass_names": [
+        "batch_call_method_view_pass",
+        "tuple_arg_call_method_view_pass",
+        "naive_call_method_reshape_pass",
+        "naive_call_method_expand_pass",
+        "non_batch_call_method_expand_pass",
+        "non_batch_call_function_arange_pass",
+        "non_batch_call_function_getitem_slice_pass",
+        "non_batch_call_function_full_pass",
+        "non_batch_call_function_full_plus_one_pass",
+        "non_batch_call_function_zeros_pass",
+        "non_batch_call_function_arange_plus_one_pass"
+      ]
+    },
+    "limits_handled_models": 1,
     "last_model_log_file": "/tmp/a.py"
   }
 }
 EOF
 )
 CONFIG=$(echo $config_json_str | base64 -w 0)
 
-python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/torch_samples_list.txt --handler-config=$CONFIG
+python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/empty_cstr_torch_samples_list.txt --handler-config=$CONFIG
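
The handler config travels as a single base64 blob so the nested JSON survives shell quoting (base64 -w 0 emits one unwrapped line). A sketch of the round-trip, under the assumption that model_path_handler decodes symmetrically; the actual decode lives outside this diff:

import base64
import json

config_json_str = '{"limits_handled_models": 1, "last_model_log_file": "/tmp/a.py"}'
blob = base64.b64encode(config_json_str.encode()).decode()  # what `base64 -w 0` produces

# Assumed receiving side of --handler-config
config = json.loads(base64.b64decode(blob))
print(config["limits_handled_models"])  # 1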

graph_net/torch/constraint_util.py

Lines changed: 3 additions & 1 deletion
@@ -8,7 +8,9 @@ def __init__(self, config):
         self.config = config
 
     def __call__(self, model_path, input_var_name: str) -> bool:
-        return not ("_self_" in input_var_name)
+        return not (
+            "_self_" in input_var_name or "_instance_modules_" in input_var_name
+        )
 
 
 class ModelRunnablePredicator:
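
The widened predicate now rejects module-state inputs by name, not only ones captured via _self_. A quick check of the intended behavior (predicate body copied from the diff; the sample variable names are invented):

def is_data_input(input_var_name: str) -> bool:
    return not (
        "_self_" in input_var_name or "_instance_modules_" in input_var_name
    )

print(is_data_input("input_ids"))                    # True: genuine data input
print(is_data_input("l_self_weight"))                # False: captured module parameter
print(is_data_input("l_instance_modules_0_buffer"))  # False: captured module state
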
Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+import torch
+import torch.fx as fx
+from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
+import os
+
+
+class ConcretePass(DimensionGeneralizationPass):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def get_pass_name(cls) -> bool:
+        return os.path.basename(__file__)[:-3]
+
+    def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
+        if 0 not in self.axes:
+            return False
+        return any(self._node_need_rewrite(node) for node in traced_module.graph.nodes)
+
+    def _node_need_rewrite(self, node) -> bool:
+        if not (node.op == "call_method"):
+            return False
+        if not (node.target == "view"):
+            return False
+        if not (len(node.args) >= 2):
+            return False
+        if not (isinstance(node.args[1], int)):
+            return False
+        if not (self.dim == node.args[1]):
+            return False
+        return True
+
+    def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
+        """
+        Fx Pass: Replaces hardcoded constants in 'view' ops that match an input
+        tensor dimension with a dynamic 'size()' call. The primary goal is to
+        dynamicize the batch size (axis 0).
+        """
+        # Create a new graph to hold the rewritten nodes
+        new_graph = fx.Graph()
+
+        # Create a map to link nodes from the old graph to nodes in the new graph
+        val_map = {}
+
+        def get_new_tuple_args(input_tensor_node, view_args):
+            # --- Dependency on ShapeProp Results ---
+            # input_shape is the static shape (e.g., batch_size, C, H, W)
+            input_meta = input_tensor_node.meta.get("tensor_meta")
+            if input_meta is None:
+                raise RuntimeError(
+                    f"Node {input_tensor_node.name} lacks tensor_meta. Did ShapeProp run?"
+                )
+
+            input_shape = input_meta.shape
+
+            # Find the new list of view arguments
+            new_view_args = []
+            for axis_idx, target_dim in enumerate(view_args):
+                if not isinstance(target_dim, int) or target_dim < 1:
+                    new_view_args.append(
+                        val_map[target_dim] if target_dim in val_map else target_dim
+                    )
+                    continue
+
+                if axis_idx == 0 and target_dim == input_shape[axis_idx]:
+                    new_input_node = val_map[input_tensor_node]
+                    size_node = new_graph.call_method(
+                        "size", args=(new_input_node, axis_idx)
+                    )
+                    best_match = size_node
+                else:
+                    best_match = target_dim
+                new_view_args.append(best_match)
+            return tuple(new_view_args)
+
+        for node in traced_module.graph.nodes:
+            if self._node_need_rewrite(node):
+                # Get the input tensor node
+                input_tensor_node = node.args[0]
+                # Get the target shape arguments for view (e.g., 1, -1, 6, 64)
+                view_args = node.args[1:]
+                print(f"{view_args=}")
+                new_view_args = get_new_tuple_args(input_tensor_node, view_args)
+
+                # --- Rebuild the view node ---
+                # 1. Map the input tensor node to the new graph node
+                new_input_node = val_map[input_tensor_node]
+
+                # 2. Insert the new view node into the new graph
+                # with new_graph.inserting_after(new_input_node):
+                new_node = new_graph.call_method(
+                    "view", args=(new_input_node, *new_view_args)
+                )
+
+                # 3. Map the old node to the new node
+                val_map[node] = new_node
+
+            else:
+                # Copy other nodes to the new graph
+                new_node = new_graph.node_copy(node, lambda x: val_map[x])
+                val_map[node] = new_node
+
+        # Replace the old graph with the new graph and return
+        traced_module.graph = new_graph
+        traced_module.recompile()
+        return traced_module
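
The rewrite above only works if ShapeProp has already annotated every node with tensor_meta; otherwise get_new_tuple_args raises. A minimal sketch of that precondition and of the comparison the pass performs (the toy module and shapes are invented; constructing the pass itself needs framework-internal dim/axes wiring not shown in this commit):

import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x.view(8, -1)  # 8 is the hardcoded batch size the pass targets

gm = fx.symbolic_trace(Tiny())
ShapeProp(gm).propagate(torch.randn(8, 3, 4))  # fills node.meta["tensor_meta"]

for node in gm.graph.nodes:
    if node.op == "call_method" and node.target == "view":
        static_batch = node.args[0].meta["tensor_meta"].shape[0]
        print(node.args[1] == static_batch)  # True: the rewrite would emit x.size(0)

After the rewrite, the graph is equivalent to x.view(x.size(0), -1), so the module accepts any batch size instead of only the traced one.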

graph_net/torch/dim_gen_passes/naive_call_method_view_pass.py

Lines changed: 11 additions & 4 deletions
@@ -14,10 +14,17 @@ def get_pass_name(cls) -> bool:
     def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
         if 0 not in self.axes:
             return False
-        for node in traced_module.graph.nodes:
-            if node.op == "call_method" and node.target == "view":
-                return True
-        return False
+        return any(self._node_need_rewrite(node) for node in traced_module.graph.nodes)
+
+    def _node_need_rewrite(self, node) -> bool:
+        if not (node.op == "call_method"):
+            return False
+        if not (node.target == "view"):
+            return False
+        print(f"{self.dim=} {node.args[1:]=}")
+        if not (self.dim in node.args[1:]):
+            return False
+        return True
 
     def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
         """
