Skip to content

Commit 6cf5a0a

Browse files
committed
batch initial input_tensor_constraints.py
1 parent 8b68398 commit 6cf5a0a

File tree

7 files changed

+3633
-4
lines changed

7 files changed

+3633
-4
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
samples/timm/resnetaa50d.d_in12k
2+
samples/timm/regnetx_016.pycls_in1k
3+
samples/timm/repghostnet_130.in1k

graph_net/constraint_util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def _make_config(
3535
data_input_predicator_config=None,
3636
model_runnable_predicator_class_name="ModelRunner",
3737
model_runnable_predicator_config=None,
38+
model_path_prefix="",
3839
):
3940
if data_input_predicator_config is None:
4041
data_input_predicator_config = {}
@@ -47,9 +48,11 @@ def _make_config(
4748
"model_runnable_predicator_filepath": model_runnable_predicator_filepath,
4849
"model_runnable_predicator_class_name": model_runnable_predicator_class_name,
4950
"model_runnable_predicator_config": model_runnable_predicator_config,
51+
"model_path_prefix": model_path_prefix,
5052
}
5153

5254
def __call__(self, model_path):
55+
model_path = os.path.join(self.config["model_path_prefix"], model_path)
5356
tensor_metas = self._get_tensor_metas(model_path)
5457
dyn_dim_cstr = make_dyn_dim_cstr_from_tensor_metas(tensor_metas)
5558

graph_net/model_path_handler.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,42 @@ def _get_handler(args):
3636

3737

3838
def main(args):
39-
model_path = args.model_path
40-
print(f"{model_path=}")
39+
handler = _get_handler(args)
40+
for model_path in _get_model_paths(args):
41+
print(f"{model_path=}")
42+
handler(model_path)
4143

42-
_get_handler(args)(model_path)
44+
45+
def _get_model_paths(args):
46+
assert args.model_path is not None or args.model_path_list is not None
47+
if args.model_path is not None:
48+
yield args.model_path
49+
if args.model_path_list is not None:
50+
with open(args.model_path_list) as f:
51+
yield from (
52+
clean_line
53+
for line in f
54+
for clean_line in [line.strip()]
55+
if len(clean_line) > 0
56+
)
4357

4458

4559
if __name__ == "__main__":
4660
parser = argparse.ArgumentParser(description="model path handler entry")
4761
parser.add_argument(
4862
"--model-path",
4963
type=str,
50-
required=True,
64+
required=False,
65+
default=None,
5166
help="Path to folder e.g '../../samples/torch/resnet18'",
5267
)
68+
parser.add_argument(
69+
"--model-path-list",
70+
type=str,
71+
required=False,
72+
default=None,
73+
help="Path of file containing model paths.",
74+
)
5375
parser.add_argument(
5476
"--handler-config",
5577
type=str,
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
5+
6+
# input model path
7+
MODEL_NAME=resnet18
8+
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
9+
config_json_str=$(cat <<EOF
10+
{
11+
"handler_path": "$GRAPH_NET_ROOT/constraint_util.py",
12+
"handler_class_name": "UpdateInputTensorConstraints",
13+
"handler_config": {
14+
"model_path_prefix": "$GRAPH_NET_ROOT/../",
15+
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
16+
"data_input_predicator_class_name": "NaiveDataInputPredicator",
17+
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
18+
"model_runnable_predicator_class_name": "ModelRunnablePredicator"
19+
}
20+
}
21+
EOF
22+
)
23+
CONFIG=$(echo $config_json_str | base64 -w 0)
24+
25+
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/small_torch_samples_list.txt --handler-config=$CONFIG

0 commit comments

Comments
 (0)