Skip to content

Commit de54e88

Browse files
committed
modify script, rename files and variables
1 parent c21717f commit de54e88

File tree

3 files changed

+10
-14
lines changed

3 files changed

+10
-14
lines changed

graph_net/test/naive_decomposer_and_post_extract_process_test.sh

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
#!/bin/bash
2-
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
3-
GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR")
4-
PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")
5-
6-
# 将项目根目录加入Python路径
7-
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
2+
# bash graph_net/test/naive_decomposer_and_post_extract_process_test.sh
83

94
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
105
os.path.dirname(graph_net.__file__))")
@@ -19,12 +14,13 @@ decorator_config_json_str=$(cat <<EOF
1914
"name": "$MODEL_NAME",
2015
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
2116
"custom_extractor_config": {
22-
"output_dir": "/work/.BCloud/countkernels/",
17+
"output_dir": "/tmp/naive_decompose_workspace",
2318
"split_positions": [8, 16, 32],
2419
"group_head_and_tail": true,
2520
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
2621
"filter_config": {},
27-
"post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process.py"
22+
"post_extract_process_path":"$GRAPH_NET_ROOT/torch/post_extract_process_count_kernels.py",
23+
"post_extract_process_class_name": "PostExtractProcess"
2824
}
2925
}
3026
}

graph_net/torch/naive_graph_decomposer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def make_config(
3333
filter_path=None,
3434
filter_config=None,
3535
post_extract_process_path=None,
36-
post_extract_process_config=None,
36+
post_extract_process_class_name=None,
3737
):
3838
for pos in split_positions:
3939
assert isinstance(
@@ -47,7 +47,7 @@ def make_config(
4747
"filter_path": filter_path,
4848
"filter_config": filter_config if filter_config is not None else {},
4949
"post_extract_process_path": post_extract_process_path,
50-
"post_extract_process_config": post_extract_process_config,
50+
"post_extract_process_class_name": post_extract_process_class_name,
5151
}
5252

5353
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
@@ -75,7 +75,7 @@ def __init__(self, parent_graph_extractor, submodule, seq_no):
7575
self.seq_no = seq_no
7676
self.extracted = False
7777
name = f"{parent_graph_extractor.name}_{self.seq_no}"
78-
self.modelname = name
78+
self.model_name = name
7979
self.builtin_extractor = BuiltinGraphExtractor(
8080
name=name,
8181
dynamic=False,
@@ -103,7 +103,7 @@ def need_extract(self, gm, sample_inputs):
103103

104104
def _post_extract_process(self):
105105
model_path = os.path.join(
106-
self.parent_graph_extractor.config["output_dir"], self.modelname
106+
self.parent_graph_extractor.config["output_dir"], self.model_name
107107
)
108108
return self.post_extract_process(model_path)
109109

@@ -117,4 +117,4 @@ def make_post_extract_process(self, config):
117117
if config["post_extract_process_path"] is None:
118118
return None
119119
module = imp_util.load_module(config["post_extract_process_path"])
120-
return module.PostExtractProcess(config["post_extract_process_config"])
120+
return module.PostExtractProcess(config["post_extract_process_path"])

graph_net/torch/post_extract_process.py renamed to graph_net/torch/post_extract_process_count_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __call__(self, model_path=None):
2727

2828
model(**state_dict)
2929
compiled_model = torch.compile(model)
30-
compiled_num_of_kernels = count_kernels(model, state_dict)
30+
compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
3131
if compiled_num_of_kernels == 1:
3232
print(model_path, "can be fully integrated")
3333
return True

0 commit comments

Comments
 (0)