Skip to content

Commit e21af2b

Browse files
committed
check_graph_module_parsable
1 parent babdde5 commit e21af2b

File tree

3 files changed

+105
-0
lines changed

3 files changed

+105
-0
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
samples/timm/aimv2_huge_patch14_224.apple_pt
2+
samples/timm/vit_so400m_patch14_siglip_gap_224.pali2_3b_pt
3+
samples/timm/eca_resnet33ts.ra2_in1k
4+
samples/timm/poolformer_s12.sail_in1k
5+
samples/timm/eva_giant_patch14_224.clip_ft_in1k
6+
samples/timm/eva02_base_patch14_224.mim_in22k
7+
samples/timm/eva02_small_patch14_224.mim_in22k
8+
samples/timm/poolformer_m36.sail_in1k
9+
samples/timm/poolformerv2_m36.sail_in1k
10+
samples/timm/resnest50d_1s4x24d.in1k
11+
samples/timm/aimv2_1b_patch14_224.apple_pt
12+
samples/timm/poolformer_m48.sail_in1k
13+
samples/timm/poolformer_s36.sail_in1k
14+
samples/timm/vit_base_patch16_siglip_gap_224.v2_webli
15+
samples/timm/vit_giant_patch16_gap_224.in22k_ijepa
16+
samples/timm/ecaresnet101d.miil_in1k
17+
samples/timm/vit_huge_patch14_gap_224.in1k_ijepa
18+
samples/timm/aimv2_3b_patch14_224.apple_pt
19+
samples/timm/skresnet18.ra_in1k
20+
samples/timm/ecaresnetlight.miil_in1k
21+
samples/timm/eca_resnext26ts.ch_in1k
22+
samples/timm/skresnext50_32x4d.ra_in1k
23+
samples/timm/poolformerv2_m48.sail_in1k
24+
samples/timm/poolformer_s24.sail_in1k
25+
samples/timm/ecaresnet26t.ra2_in1k
26+
samples/timm/eva02_large_patch14_224.mim_in22k
27+
samples/timm/ecaresnet50d.miil_in1k
28+
samples/timm/poolformerv2_s36.sail_in1k
29+
samples/timm/aimv2_large_patch14_224.apple_pt
30+
samples/timm/skresnet34.ra_in1k
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4+
os.path.dirname(graph_net.__file__))")
5+
6+
# input model path
7+
# model_runnable_predicator=ShapePropagatablePredicator
8+
model_runnable_predicator=ModelRunnablePredicator
9+
config_json_str=$(cat <<EOF
10+
{
11+
"handler_path": "$GRAPH_NET_ROOT/torch/check_graph_module_parsable.py",
12+
"handler_class_name": "CheckGraphModuleParsable",
13+
"handler_config": {
14+
"model_path_prefix": "$GRAPH_NET_ROOT/../",
15+
"resume": true,
16+
"limits_handled_models": 999999,
17+
"output_dir": "/tmp/check_graph_module_parsable"
18+
}
19+
}
20+
EOF
21+
)
22+
CONFIG=$(echo $config_json_str | base64 -w 0)
23+
24+
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/test/dev_model_list/graph_module_parse_error_torch_sample_list.txt --handler-config=$CONFIG --use-subprocess
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from pathlib import Path
2+
from graph_net.torch.fx_graph_cache_util import (
3+
parse_immutable_model_path_into_sole_graph_module,
4+
)
5+
import os
6+
import sys
7+
8+
9+
class CheckGraphModuleParsable:
10+
def __init__(self, config=None):
11+
if config is None:
12+
config = {}
13+
self.config = self._make_config(**config)
14+
self.num_handled_models = 0
15+
16+
def _make_config(
17+
self,
18+
model_path_prefix,
19+
output_dir,
20+
resume=False,
21+
limits_handled_models=None,
22+
):
23+
return {
24+
"model_path_prefix": model_path_prefix,
25+
"output_dir": output_dir,
26+
"resume": resume,
27+
"limits_handled_models": limits_handled_models,
28+
}
29+
30+
def __call__(self, rel_model_path):
31+
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
32+
if self.config["resume"] and self._is_model_path_handled(rel_model_path):
33+
return
34+
parse_immutable_model_path_into_sole_graph_module(model_path)
35+
output_dir = Path(self.config["output_dir"]) / rel_model_path
36+
output_dir.mkdir(parents=True, exist_ok=True)
37+
self._inc_num_handled_models()
38+
39+
def _is_model_path_handled(self, rel_model_path):
40+
return (Path(self.config["output_dir"]) / rel_model_path).exists()
41+
42+
def _inc_num_handled_models(self):
43+
self.num_handled_models += 1
44+
limits = self.config["limits_handled_models"]
45+
if limits is not None:
46+
if self.num_handled_models >= limits:
47+
print(
48+
"`num_handled_models` exceeds config `limits_handled_models`",
49+
file=sys.stderr,
50+
)
51+
sys.exit(0)

0 commit comments

Comments
 (0)