Skip to content

Commit f8524bb

Browse files
committed
backup code for multi_dim_size
1 parent 6be9482 commit f8524bb

File tree

9 files changed

+320
-25
lines changed

9 files changed

+320
-25
lines changed

graph_net/torch/constraint_util.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import copy
2+
3+
4+
def get_all_symbol_names(constraint_attrs_list):
5+
unique_symbol_names = []
6+
for constraint_attrs in constraint_attrs_list:
7+
for dim in constraint_attrs["shape"]:
8+
if isinstance(dim, int):
9+
continue
10+
assert isinstance(dim, dict)
11+
if dim["symbol_name"] in unique_symbol_names:
12+
continue
13+
unique_symbol_names.append(dim["symbol_name"])
14+
15+
return unique_symbol_names
16+
17+
18+
def reify_symboli_dims(constraint_attrs_list, symbol_names):
19+
def try_reify_dim(dim):
20+
if isinstance(dim, int):
21+
return dim
22+
assert isinstance(dim, dict)
23+
if dim["symbol_name"] not in symbol_names:
24+
return dim
25+
return dim["example_value"]
26+
27+
constraint_attrs_list = copy.deepcopy(constraint_attrs_list)
28+
for constraint_attrs in constraint_attrs_list:
29+
constraint_attrs["shape"] = [
30+
try_reify_dim(dim) for dim in constraint_attrs["shape"]
31+
]
32+
return constraint_attrs_list
33+
34+
35+
def modify_dim_example_value(constraint_attrs_list, symbol_name, modifier):
36+
def modify_dim(dim):
37+
if isinstance(dim, int):
38+
return
39+
assert isinstance(dim, dict)
40+
dim["example_value"] = modifier(dim["example_value"])
41+
42+
constraint_attrs_list = copy.deepcopy(constraint_attrs_list)
43+
for constraint_attrs in constraint_attrs_list:
44+
for dim in constraint_attrs["shape"]:
45+
modify_dim(dim)
46+
return constraint_attrs_list
47+
48+
49+
def symbolic_dims_all_reified(constraint_attrs_list):
50+
for constraint_attrs in constraint_attrs_list:
51+
for dim in constraint_attrs["shape"]:
52+
if not isinstance(dim, int):
53+
return False
54+
return True
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from graph_net.torch import utils
2+
import argparse
3+
import torch
4+
import logging
5+
from pathlib import Path
6+
from typing import Type, Any
7+
import sys
8+
from graph_net.torch.imp_util import load_class_from_file
9+
import hashlib
10+
from contextlib import contextmanager
11+
import json
12+
import inspect
13+
import imp_util
14+
import record_util
15+
import copy
16+
17+
18+
def main(args):
19+
model_path = args.model_path
20+
name2input_param_attrs = _get_name2input_param_attrs(model_path)
21+
name_and_annotation_types = _get_name_and_annotation_types(model_path)
22+
input_name_and_meta_attrs = _get_input_name_and_meta_attrs(
23+
name2input_param_attrs, name_and_annotation_types
24+
)
25+
input_name_and_constraint_attrs = _get_input_name_and_constraint_attrs(
26+
input_name_and_meta_attrs
27+
)
28+
_dump_input_name_and_constraint_attrs(
29+
input_name_and_constraint_attrs, args.output_path
30+
)
31+
32+
33+
def _dump_input_name_and_constraint_attrs(input_name_and_constraint_attrs, output_path):
34+
py_code = record_util.serialize_to_py_code(
35+
[attr for _, attr in input_name_and_constraint_attrs],
36+
class_prefix="ProgramInputConstraint",
37+
)
38+
print(f"{output_path=}")
39+
with open(output_path, "w") as f:
40+
f.write(py_code)
41+
42+
43+
def _get_input_name_and_constraint_attrs(input_name_and_meta_attrs):
44+
seq_no = 0
45+
dim2seq = {}
46+
47+
def find_or_new_seq(dim):
48+
nonlocal seq_no
49+
nonlocal dim2seq
50+
if dim in dim2seq:
51+
return dim2seq[dim]
52+
ret = seq_no
53+
dim2seq[dim] = ret
54+
seq_no += 1
55+
return ret
56+
57+
def make_symoblic_shape(shape):
58+
return type(shape)(
59+
[
60+
symbolic_dim_desc
61+
for dim in shape
62+
for dim_seq_no in [find_or_new_seq(dim)]
63+
for symbolic_dim_desc in [
64+
{"symbol_name": f"s{dim_seq_no}", "example_value": dim}
65+
]
66+
]
67+
)
68+
69+
def make_constraint_attrs(attrs):
70+
attrs = copy.deepcopy(attrs)
71+
attrs["shape"] = make_symoblic_shape(attrs["shape"])
72+
return attrs
73+
74+
return [
75+
(name, symbolic_attrs)
76+
for name, attrs in input_name_and_meta_attrs
77+
for symbolic_attrs in [make_constraint_attrs(attrs)]
78+
]
79+
80+
81+
def _get_input_name_and_meta_attrs(name2input_param_attrs, name_and_annotation_types):
82+
def constructed_from_self(name):
83+
return name.find("self_") != -1
84+
85+
def is_tensor_type(annotation_type):
86+
return annotation_type is torch.Tensor
87+
88+
ret = [
89+
(name, meta_attr)
90+
for name, annotation_type in name_and_annotation_types
91+
if is_tensor_type(annotation_type)
92+
if not constructed_from_self(name)
93+
for meta_attr in [name2input_param_attrs[name]]
94+
]
95+
assert len(ret) > 0
96+
return ret
97+
98+
99+
def _get_name_and_annotation_types(model_path):
100+
model_class = load_class_from_file(
101+
f"{model_path}/model.py", class_name="GraphModule"
102+
)
103+
annotations = inspect.getfullargspec(model_class.forward).annotations
104+
return [(k, v) for k, v in annotations.items()]
105+
106+
107+
def _get_name2input_param_attrs(model_path):
108+
def get_classes():
109+
input_meta_file = f"{model_path}/input_meta.py"
110+
for _, cls in imp_util.load_name_and_classes_from_file(input_meta_file):
111+
yield cls
112+
113+
weight_meta_file = f"{model_path}/weight_meta.py"
114+
for _, cls in imp_util.load_name_and_classes_from_file(weight_meta_file):
115+
yield cls
116+
117+
return {
118+
name: attr
119+
for cls in get_classes()
120+
for attr in [record_util.make_attrs_from_class(cls)]
121+
for name in [attr["name"]]
122+
}
123+
124+
125+
if __name__ == "__main__":
126+
parser = argparse.ArgumentParser(description="generate constraint proposal file")
127+
parser.add_argument(
128+
"--model-path",
129+
type=str,
130+
required=True,
131+
help="Path to folder e.g '../../samples/torch/resnet18'",
132+
)
133+
parser.add_argument(
134+
"--output-path",
135+
type=str,
136+
required=True,
137+
help="output file path",
138+
)
139+
args = parser.parse_args()
140+
main(args=args)

graph_net/torch/hash_util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import hashlib
2+
3+
4+
def get_sha_hash(content):
5+
m = hashlib.sha256()
6+
m.update(content.encode())
7+
return m.hexdigest()

graph_net/torch/imp_util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import importlib.util
2+
import inspect
3+
4+
5+
def load_class_from_file(file_path: str, class_name: str):
6+
spec = importlib.util.spec_from_file_location("unnamed", file_path)
7+
unnamed = importlib.util.module_from_spec(spec)
8+
spec.loader.exec_module(unnamed)
9+
model_class = getattr(unnamed, class_name, None)
10+
return model_class
11+
12+
13+
def load_name_and_classes_from_file(file_path):
14+
spec = importlib.util.spec_from_file_location("unnamed", file_path)
15+
unnamed = importlib.util.module_from_spec(spec)
16+
spec.loader.exec_module(unnamed)
17+
yield from inspect.getmembers(unnamed, inspect.isclass)

graph_net/torch/record_util.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import hash_util
2+
3+
4+
def make_attrs_from_class(cls):
5+
return {
6+
k: v
7+
for k, v in cls.__dict__.items()
8+
if not k.startswith("__") and not callable(v)
9+
}
10+
11+
12+
def serialize_to_py_code(attrs, class_prefix):
13+
assert isinstance(attrs, (tuple, list))
14+
15+
ret = "\n".join(
16+
_serialize_one_attr_to_py_code(attr, class_prefix) for attr in attrs
17+
)
18+
return ret
19+
20+
21+
def _serialize_one_attr_to_py_code(attr, class_prefix):
22+
hash_str = hash_util.get_sha_hash(str(attr))
23+
hash_str = hash_str[:32]
24+
indent = " " * 4
25+
ret = "\n".join(
26+
[
27+
f"class {class_prefix}{hash_str}:",
28+
*[f"{indent}{name} = {repr(value)}" for name, value in attr.items()],
29+
]
30+
)
31+
return f"{ret}\n\n"

graph_net/torch/single_device_runner.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import utils
1+
from graph_net.torch import utils
22
import argparse
33
import importlib.util
44
import inspect
@@ -10,6 +10,10 @@
1010
from graph_net.torch.extractor import extract
1111
import hashlib
1212
from contextlib import contextmanager
13+
import json
14+
import record_util
15+
import imp_util
16+
import os
1317

1418

1519
def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
@@ -62,9 +66,14 @@ def main(args):
6266
kwargs = dict(name=args.extract_name, dynamic=False, **dump_graph_options)
6367
model = extract(**kwargs)(model)
6468

65-
inputs_params = utils.load_converted_from_text(f"{model_path}")
69+
inputs_params = utils.make_input_and_param_tensors_from_model_path(
70+
f"{model_path}"
71+
)
6672
params = inputs_params["weight_info"]
67-
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
73+
shape_modifier = _get_shape_modifier(args)
74+
state_dict = {
75+
k: utils.replay_tensor(v, shape_modifier) for k, v in params.items()
76+
}
6877

6978
explain = torch._dynamo.explain(model)(**state_dict)
7079
if explain.graph_count != 1 or len(explain.break_reasons) != 0:
@@ -76,10 +85,31 @@ def main(args):
7685
f"Graph extraction failed. The resulting graph is incomplete, broken into {explain.graph_count} subgraphs."
7786
)
7887

79-
y = model(**state_dict)[0]
88+
model(**state_dict)
89+
90+
91+
def _get_shape_modifier(cli_args):
92+
"""
93+
yield shape modifier from shape_modifiers.json in directory cli_args.model_path
94+
"""
95+
if not cli_args.enable_shape_patch:
96+
return lambda name, shape: shape
97+
shape_patch_file_path = f"{cli_args.model_path}/shape_patch.py"
98+
if not os.path.exists(shape_patch_file_path):
99+
return lambda name, shape: shape
100+
shape_modifier_data = [
101+
attrs
102+
for name, cls in imp_util.load_name_and_classes_from_file(shape_patch_file_path)
103+
for attrs in [record_util.make_attrs_from_class(cls)]
104+
]
105+
assert isinstance(shape_modifier_data, list)
106+
return _make_shape_modifier_impl(shape_modifier_data)
80107

81-
print(torch.argmin(y), torch.argmax(y))
82-
print(y.shape)
108+
109+
def _make_shape_modifier_impl(shape_modifier_data):
110+
name2new_shape = {attrs["name"]: attrs["shape"] for attrs in shape_modifier_data}
111+
print(f"{name2new_shape=}")
112+
return lambda name, shape: name2new_shape[name] if name in name2new_shape else shape
83113

84114

85115
if __name__ == "__main__":
@@ -110,5 +140,12 @@ def main(args):
110140
default=None,
111141
help="Extracted graph's name",
112142
)
143+
parser.add_argument(
144+
"--enable-shape-patch",
145+
type=bool,
146+
required=False,
147+
default=False,
148+
help="Enable extra inputs",
149+
)
113150
args = parser.parse_args()
114151
main(args=args)

graph_net/torch/test_compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ def get_model(args):
8686

8787

8888
def get_input_dict(args):
89-
inputs_params = utils.load_converted_from_text(f"{args.model_path}")
89+
inputs_params = utils.make_input_and_param_tensors_from_model_path(
90+
f"{args.model_path}"
91+
)
9092
params = inputs_params["weight_info"]
9193
return {
9294
k: utils.replay_tensor(v).to(torch.device(args.device))

0 commit comments

Comments
 (0)