Skip to content

Commit 9188178

Browse files
authored
Fix the model path in graph_variable_renamer_validator_backend.py (#449)
1 parent 231824f commit 9188178

File tree

2 files changed

+44
-40
lines changed

2 files changed

+44
-40
lines changed
Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,46 @@
11
#!/bin/bash
22

3-
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
4-
os.path.dirname(graph_net.__file__))")
5-
WORKSPACE=/tmp/graph_variable_rename_workspace
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4+
RENAMED_PATH=/tmp/graph_variable_rename_workspace
65

7-
# input model path
8-
MODEL_NAME=resnet18
9-
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
10-
config_json_str=$(cat <<EOF
6+
mkdir -p "$RENAMED_PATH"
7+
model_list="$GRAPH_NET_ROOT/graph_net/config/small100_torch_samples_list.txt"
8+
9+
python3 -m graph_net.model_path_handler \
10+
--model-path-list $model_list \
11+
--handler-config=$(base64 -w 0 <<EOF
1112
{
12-
"handler_path": "$GRAPH_NET_ROOT/torch/graph_variable_renamer.py",
13+
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/graph_variable_renamer.py",
1314
"handler_class_name": "GraphVariableRenamer",
1415
"handler_config": {
15-
"model_path_prefix": "$GRAPH_NET_ROOT/../",
16-
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
16+
"device": "cuda",
17+
"resume": true,
18+
"model_path_prefix": "$GRAPH_NET_ROOT/",
19+
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
1720
"data_input_predicator_class_name": "NaiveDataInputPredicator",
18-
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
21+
"model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
1922
"model_runnable_predicator_class_name": "ModelRunnablePredicator",
20-
"output_dir": "$WORKSPACE"
23+
"output_dir": "$RENAMED_PATH"
2124
}
2225
}
2326
EOF
24-
)
25-
CONFIG=$(echo $config_json_str | base64 -w 0)
26-
27-
python3 -m graph_net.model_path_handler --model-path samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
28-
# python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/decomposition_error_tmp_torch_samples_list.txt --handler-config=$CONFIG
27+
) \
28+
2>&1 | tee "$RENAMED_PATH/graph_rename.log"
2929

30-
test_compiler_config_json_str=$(cat <<EOF
30+
python3 -m graph_net.torch.test_compiler \
31+
--model-path-prefix $GRAPH_NET_ROOT \
32+
--allow-list $model_list \
33+
--compiler graph_variable_renamer_validator \
34+
--device cuda \
35+
--config $(base64 -w 0 <<EOF
3136
{
3237
"model_path_prefix": "$GRAPH_NET_ROOT",
33-
"renamed_root": "$WORKSPACE"
38+
"renamed_root": "$RENAMED_PATH"
3439
}
3540
EOF
36-
)
37-
TEST_COMPILER_CONFIG=$(echo $test_compiler_config_json_str | base64 -w 0)
38-
39-
python3 -m graph_net.torch.test_compiler \
40-
--model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES \
41-
--compiler graph_variable_renamer_validator \
42-
--device cuda \
43-
--config $TEST_COMPILER_CONFIG \
44-
> "$WORKSPACE/validation.log" 2>&1
41+
) \
42+
2>&1 | tee "$RENAMED_PATH/validation.log"
4543

4644
python3 -m graph_net.plot_ESt \
47-
--benchmark-path "$WORKSPACE/validation.log" \
48-
--output-dir "$WORKSPACE"
45+
--benchmark-path "$RENAMED_PATH/validation.log" \
46+
--output-dir "$RENAMED_PATH"

graph_net/torch/backend/graph_variable_renamer_validator_backend.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,32 +56,38 @@ def _make_config(
5656
self,
5757
model_path_prefix: str,
5858
renamed_root: str,
59-
renamed_dentry: str = "_renamed",
6059
):
6160
return {
6261
"model_path_prefix": model_path_prefix,
6362
"renamed_root": renamed_root,
64-
"renamed_dentry": renamed_dentry,
6563
}
6664

65+
def _get_rel_model_path(self, model_path) -> str:
66+
model_path = os.path.realpath(model_path)
67+
model_path_prefix = os.path.realpath(self.config["model_path_prefix"])
68+
assert model_path.startswith(model_path_prefix)
69+
rel_model_path = model_path[len(model_path_prefix) :]
70+
if rel_model_path.startswith("/"):
71+
rel_model_path = rel_model_path[1:]
72+
assert not rel_model_path.startswith("/")
73+
return rel_model_path
74+
6775
def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
6876
config = self._make_config(**self.config)
6977
model_path = os.path.dirname(model.__class__.__graph_net_file_path__)
7078
model_name = os.path.basename(model_path)
71-
renamed_dir_name = f"{model_name}_renamed"
72-
renamed_model_dir = os.path.join(config["renamed_root"], renamed_dir_name)
79+
rel_model_path = self._get_rel_model_path(model_path)
80+
renamed_parent_dir = os.path.join(config["renamed_root"], rel_model_path)
7381

7482
print(f"[GraphVariableRenamerValidatorBackend] Processing: {model_name}")
7583
print(
76-
f"[GraphVariableRenamerValidatorBackend] Loading from: {renamed_model_dir}"
84+
f"[GraphVariableRenamerValidatorBackend] Loading from: {renamed_parent_dir}"
7785
)
7886

7987
device = model.__class__.__graph_net_device__
80-
renamed_model = self._load_model_instance(renamed_model_dir, device)
81-
mapping = self._get_rename_mapping(Path(renamed_model_dir))
82-
assert (
83-
mapping
84-
), f"Mapping is empty for {renamed_dir_name} at {renamed_model_dir}"
88+
renamed_model = self._load_model_instance(renamed_parent_dir, device)
89+
mapping = self._get_rename_mapping(Path(renamed_parent_dir))
90+
assert mapping, f"Mapping is empty for {model_name} at {renamed_parent_dir}"
8591
adapter = RenamedModelAdapter(renamed_model, mapping)
8692
return adapter.eval()
8793

0 commit comments

Comments
 (0)