
Commit fe89add

remove unnecessary code blocks and variables
1 parent 75c3e61 commit fe89add

File tree

3 files changed: +4, -69 lines changed


graph_net/torch/naive_graph_decomposer.py

Lines changed: 2 additions & 10 deletions
@@ -92,36 +92,28 @@ def forward(self, *args):
         if not self.extracted:
             if self.need_extract(self.submodule, args):
                 self.builtin_extractor(self.submodule, args)
-                self.get_post_extract_process(self.submodule, args)
+                self.get_post_extract_process()
             self.extracted = True
         return self.submodule(*args)
 
     def need_extract(self, gm, sample_inputs):
-        # print("need_extract")
         if self.filter is None:
             return True
-        # if self.fusionablity_filter is not None:
-        #     print("fusionablity of this model is ", self.fusionablity_filter(gm, sample_inputs))
         return self.filter(gm, sample_inputs)
 
-    def get_post_extract_process(self, gm, sample_inputs):
-        # print("modelname: ",self.modelname)
-        # print("parent_graph_extractor.config: ",self.parent_graph_extractor.config['output_dir'])
-        # print("get_post_extract_process")
+    def get_post_extract_process(self):
         model_path = os.path.join(
             self.parent_graph_extractor.config["output_dir"], self.modelname
         )
         return self.post_extract_process(model_path)
 
     def make_filter(self, config):
-        # print("make_filter")
         if config["filter_path"] is None:
             return None
         module = imp_util.load_module(config["filter_path"])
         return module.GraphFilter(config["filter_config"])
 
     def make_post_extract_process(self, config):
-        # print("make post_extract_process")
         if config["filter_path"] is None:
             return None
         module = imp_util.load_module(config["post_extract_process_path"])
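Note: make_filter and make_post_extract_process load user code dynamically from a path given in the config. imp_util.load_module itself is not shown in this diff; the following is a minimal sketch of the pattern it presumably wraps, mirroring load_class_from_file in post_extract_process.py (names here are illustrative only):

import importlib.util


def load_module(file_path):
    # Load a Python source file as an anonymous module; callers expect
    # the loaded module to define e.g. a GraphFilter class taking a
    # config dict.
    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Usage mirroring make_filter:
# module = load_module(config["filter_path"])
# graph_filter = module.GraphFilter(config["filter_config"])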

graph_net/torch/naive_subgraph_filter.py

Lines changed: 0 additions & 1 deletion
@@ -3,6 +3,5 @@ def __init__(self, config):
         self.config = config
 
     def __call__(self, gm, sample_inputs):
-        print("GraphFilter")
         # print(f"GraphFilter\n{gm.code}")
         return True
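The filter contract implied by need_extract is small: a callable constructed with a config dict and invoked as filter(gm, sample_inputs), returning a bool. A hypothetical custom filter under that contract (min_nodes is an assumed config key, not one used by this repo):

class GraphFilter:
    def __init__(self, config):
        self.config = config

    def __call__(self, gm, sample_inputs):
        # Illustrative criterion: only extract FX graphs with at least
        # min_nodes nodes; real filters can inspect gm and inputs freely.
        min_nodes = (self.config or {}).get("min_nodes", 0)
        return len(list(gm.graph.nodes)) >= min_nodes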

graph_net/torch/post_extract_process.py

Lines changed: 2 additions & 58 deletions
@@ -20,36 +20,8 @@ def __init__(self, config):
         self.config = config
 
     def __call__(self, model_path=None):
-        print("PostExtractProcess")
         if model_path is None:
             return False
-        import json
-        import base64
-        import sys
-        import os
-
-        json_string = json.dumps(self.config)
-        json_bytes = json_string.encode("utf-8")
-        b64_encoded_bytes = base64.b64encode(json_bytes)
-        decorator_config = b64_encoded_bytes.decode("utf-8")
-
-        # args
-        parser = argparse.ArgumentParser(description="load and run model")
-        parser.add_argument(
-            "--model-path",
-            type=str,
-            required=True,
-            help="Path to folder e.g '../../samples/torch/resnet18'",
-        )
-        parser.add_argument(
-            "--decorator-config",
-            type=str,
-            required=False,
-            default=None,
-            help="decorator configuration string",
-        )
-        args = parser.parse_args()
-
         # model
         model_class = load_class_from_file(
             f"{model_path}/model.py", class_name="GraphModule"
@@ -58,46 +30,20 @@ def __call__(self, model_path=None):
         model = model_class()
         print(f"{model_path=}")
 
-        model = _get_decorator(args)(model)
-
         inputs_params = utils.load_converted_from_text(f"{model_path}")
         params = inputs_params["weight_info"]
         state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
 
         compiled_num_of_kernels = compile_and_count_kernels(model, state_dict)
-        print("compiled: nums_of_kernels = ", compiled_num_of_kernels)
         if compiled_num_of_kernels == 1:
-            print("Graph is fully fusionable")
+            print(model_path, "can be fully integrated")
             return True
         else:
-            print(f"Graph is not fully fusionable ({compiled_num_of_kernels} kernels)")
+            print(model_path, "can not be fully integrated")
             shutil.rmtree(model_path)
             return False
 
 
-def _convert_to_dict(config_str):
-    if config_str is None:
-        return {}
-    config_str = base64.b64decode(config_str).decode("utf-8")
-    config = json.loads(config_str)
-    assert isinstance(config, dict), f"config should be a dict. {config_str=}"
-    return config
-
-
-def _get_decorator(args):
-    if args.decorator_config is None:
-        return lambda model: model
-    decorator_config = _convert_to_dict(args.decorator_config)
-    if "decorator_path" not in decorator_config:
-        return lambda model: model
-    class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
-    decorator_class = load_class_from_file(
-        decorator_config["decorator_path"],
-        class_name=class_name,
-    )
-    return decorator_class(decorator_config.get("decorator_config", {}))
-
-
 def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
     spec = importlib.util.spec_from_file_location("unnamed", file_path)
     unnamed = importlib.util.module_from_spec(spec)
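For context, the removed _convert_to_dict/_get_decorator helpers round-tripped a decorator config through base64-encoded JSON so it could travel as a CLI argument. A self-contained sketch of that round trip (the config value is illustrative):

import base64
import json

config = {"decorator_path": "my_decorator.py"}  # illustrative value

# Encode, as the removed __call__ body did:
encoded = base64.b64encode(json.dumps(config).encode("utf-8")).decode("utf-8")

# Decode, as the removed _convert_to_dict did:
decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
assert isinstance(decoded, dict), f"config should be a dict. {decoded=}"
assert decoded == config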
@@ -133,11 +79,9 @@ def compile_and_count_kernels(gm, sample_inputs) -> int:
     ) as prof:
         with record_function("model_inference"):
             output = compiled_gm(**sample_inputs)
-    print(prof.key_averages().table()) # print a table of profiler result
     events = prof.key_averages()
     if_compile_work = any(e.key == "TorchDynamo Cache Lookup" for e in events)
     if not if_compile_work:
-        print("Compile failed")
         return -1
     for e in events:
         if e.key == "cuLaunchKernel":
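The hunk above only shows the tail of compile_and_count_kernels. A minimal self-contained sketch of the same technique, assuming the function compiles the module with torch.compile and counts cuLaunchKernel profiler events to estimate how many GPU kernels the compiled graph launches (a reconstruction under those assumptions, not the file's actual body):

import torch
from torch.profiler import profile, record_function, ProfilerActivity


def count_cuda_kernel_launches(gm, sample_inputs) -> int:
    compiled_gm = torch.compile(gm)
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    ) as prof:
        with record_function("model_inference"):
            compiled_gm(**sample_inputs)
    events = prof.key_averages()
    # torch.compile leaves a "TorchDynamo Cache Lookup" event behind;
    # its absence means compilation did not actually take effect.
    if not any(e.key == "TorchDynamo Cache Lookup" for e in events):
        return -1
    # Each fused region launches through cuLaunchKernel, so the event
    # count approximates the number of distinct GPU kernels; a count of
    # one means the whole graph fused into a single launch.
    return sum(e.count for e in events if e.key == "cuLaunchKernel")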
