|
1 | 1 | import os |
2 | 2 | import json |
3 | | - |
4 | | -os.environ["ENABLE_CINN_IN_DY2ST"] = "0" |
5 | | -# os.environ["FLAGS_logging_trunc_pir_py_code"] = "1" |
6 | | -# os.environ["FLAGS_logging_pir_py_code_int_tensor_element_limit"] = "64" |
7 | | -os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump" |
| 3 | +import importlib.util |
8 | 4 |
|
9 | 5 | import paddle |
10 | 6 | from athena.module_op_unittests_for_graphnet import GraphnetSample, generate_samples |
11 | 7 | from graph_net.paddle import utils |
12 | 8 |
|
13 | 9 |
|
def load_class_from_file(file_path: str, class_name: str):
    """Import *file_path* as an anonymous module and return its *class_name*.

    Args:
        file_path: path to a Python source file to execute as a module.
        class_name: attribute name of the class to look up in that module.

    Returns:
        The class object, or ``None`` when the module defines no attribute
        named *class_name* (mirrors the original ``getattr`` default).

    Raises:
        ImportError: if no module spec can be built for *file_path* (e.g.
            the path does not exist or has no recognized suffix).
    """
    print(f"Load {class_name} from {file_path}")
    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    if spec is None or spec.loader is None:
        # Fail loudly here instead of the opaque AttributeError that
        # ``spec.loader`` would raise when ``spec`` is None.
        raise ImportError(f"cannot create a module spec for {file_path}")
    unnamed = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(unnamed)
    return getattr(unnamed, class_name, None)
| 18 | + |
def write_to_file(filepath, content):
    """Write *content* to *filepath* (truncating any existing file),
    announcing the destination path first."""
    print(f"Write to {filepath}")
    handle = open(filepath, "w")
    try:
        handle.write(content)
    finally:
        handle.close()
| 24 | + |
def generate_model_wrapper_class(model_dump_path, data_arg_names):
    """Generate, dump, and load a ``paddle.nn.Layer`` wrapper class.

    The generated ``GraphModuleWrapper.forward`` accepts exactly the data
    arguments in *data_arg_names* and forwards them — plus every parameter
    registered via ``set_parameters`` — to the wrapped graph module as
    keyword arguments.

    Args:
        model_dump_path: directory in which ``graph_module_wrapper.py`` is
            written (must already exist).
        data_arg_names: iterable of non-parameter input names.

    Returns:
        The ``GraphModuleWrapper`` class loaded from the generated file,
        or ``None`` if loading fails to expose it.
    """
    graph_module_wrapper_class_template = """
import paddle

class GraphModuleWrapper(paddle.nn.Layer):
    def __init__(self, graph_module):
        super().__init__()
        self.graph_module = graph_module

    def set_parameters(self, **kwargs):
        for name, value in kwargs.items():
            if isinstance(value, paddle.nn.parameter.Parameter):
                setattr(self, name, value)

    def forward(${FORWARD_PARAMS}):
        param_dict = { name: param for name, param in self.named_parameters() }
        outputs = self.graph_module(${CALL_ARGS})
        return outputs
"""

    names = list(data_arg_names)
    # forward(self, a, b, ...); an empty name list degrades to forward(self).
    forward_params = ", ".join(["self"] + names)
    # graph_module(a=a, b=b, ..., **param_dict); joining the whole argument
    # list at once avoids emitting the invalid "(, **param_dict)" form when
    # *names* is empty.
    call_args = ", ".join([f"{name}={name}" for name in names] + ["**param_dict"])
    graph_module_wrapper_class_code_str = graph_module_wrapper_class_template.replace(
        "${FORWARD_PARAMS}", forward_params
    ).replace("${CALL_ARGS}", call_args)
    print(graph_module_wrapper_class_code_str)

    file_path = os.path.join(model_dump_path, "graph_module_wrapper.py")
    write_to_file(file_path, graph_module_wrapper_class_code_str)
    model_class = load_class_from_file(
        file_path=file_path, class_name="GraphModuleWrapper"
    )
    return model_class
| 57 | + |
| 58 | + |
14 | 59 | # used as configuration of python -m graph_net.paddle.run_model |
15 | 60 | class RunModelDecorator: |
16 | 61 | def __init__(self, config): |
@@ -89,18 +134,47 @@ def run_model_with_dump_enabled(self, model_dump_path, **input_dict): |
89 | 134 | # Get model dump path |
90 | 135 | old_flags = self.prepare_to_extract(model_dump_path) |
91 | 136 |
|
| 137 | + param_dict = { |
| 138 | + k: v |
| 139 | + for k, v in input_dict.items() |
| 140 | + if isinstance(v, paddle.nn.parameter.Parameter) |
| 141 | + } |
| 142 | + data_dict = { |
| 143 | + k: v |
| 144 | + for k, v in input_dict.items() |
| 145 | + if not isinstance(v, paddle.nn.parameter.Parameter) |
| 146 | + } |
| 147 | + |
| 148 | + input_spec = self.input_spec |
92 | 149 | if self.input_spec is None: |
93 | | - self.input_spec = [ |
| 150 | + input_spec = [ |
94 | 151 | paddle.static.InputSpec(value.shape, value.dtype, name=name) |
95 | | - for name, value in input_dict.items() |
| 152 | + for name, value in data_dict.items() |
96 | 153 | if isinstance(value, paddle.Tensor) |
97 | 154 | ] |
| 155 | + else: |
| 156 | + assert len(input_spec) == len(data_dict) |
| 157 | + |
| 158 | + if param_dict: |
| 159 | + model_wrapper_class = generate_model_wrapper_class( |
| 160 | + model_dump_path, data_dict.keys() |
| 161 | + ) |
| 162 | + wrapped_model = model_wrapper_class(self.model) |
| 163 | + wrapped_model.set_parameters(**param_dict) |
| 164 | + else: |
| 165 | + wrapped_model = self.model |
98 | 166 |
|
99 | 167 | # Run the static model |
100 | 168 | static_model = paddle.jit.to_static( |
101 | | - self.model, input_spec=self.input_spec, full_graph=True |
| 169 | + wrapped_model, |
| 170 | + input_spec=input_spec, |
| 171 | + full_graph=True, |
| 172 | + backend=None, |
102 | 173 | ) |
103 | | - static_model(**input_dict) |
| 174 | + static_model.eval() |
| 175 | + program = static_model.forward.concrete_program.main_program |
| 176 | + # print(program) |
| 177 | + static_model(**data_dict) |
104 | 178 |
|
105 | 179 | # Restore the environment |
106 | 180 | paddle.set_flags(old_flags) |
@@ -147,11 +221,6 @@ def translate_pir_program_to_sample_codes( |
147 | 221 | return self.subgraph_idx2samples |
148 | 222 |
|
149 | 223 | def write_sample_to_file(self, subgraph_path, sample): |
150 | | - def write_to_file(filepath, content): |
151 | | - print(f"Write to {filepath}") |
152 | | - with open(filepath, "w") as f: |
153 | | - f.write(content) |
154 | | - |
155 | 224 | if not os.path.exists(subgraph_path): |
156 | 225 | os.makedirs(subgraph_path, exist_ok=True) |
157 | 226 | write_to_file(f"{subgraph_path}/model.py", sample.model) |
|
0 commit comments