|
6 | 6 | import argparse |
7 | 7 | import importlib |
8 | 8 | import inspect |
| 9 | +import ast |
9 | 10 | import paddle |
10 | 11 |
|
11 | 12 |
|
@@ -112,15 +113,16 @@ def load_converted_list_from_text(file_path): |
112 | 113 | input_info = [ |
113 | 114 | data for data in convert_meta_classes_to_tensors(f"{file_path}/input_meta.py") |
114 | 115 | ] |
| 116 | + # print(f"-- input_info: {input_info}") |
115 | 117 | weight_info = [ |
116 | 118 | data for data in convert_meta_classes_to_tensors(f"{file_path}/weight_meta.py") |
117 | 119 | ] |
118 | | - |
119 | | - return [*input_info, *weight_info] |
| 120 | + return [*weight_info, *input_info] |
120 | 121 |
|
121 | 122 |
|
122 | 123 | def convert_meta_classes_to_tensors(file_path): |
123 | 124 | for name, cls in _get_classes(file_path): |
| 125 | + # print(f"-- name: {name}") |
124 | 126 | attrs = { |
125 | 127 | k: v |
126 | 128 | for k, v in cls.__dict__.items() |
@@ -152,10 +154,18 @@ def convert_meta_classes_to_tensors(file_path): |
152 | 154 |
|
153 | 155 |
|
154 | 156 | def _get_classes(file_path): |
| 157 | + with open(file_path, "r", encoding="utf-8") as f: |
| 158 | + tree = ast.parse(f.read(), filename=file_path) |
| 159 | + |
| 160 | + class_names = [node.name for node in tree.body if isinstance(node, ast.ClassDef)] |
| 161 | + |
155 | 162 | spec = importlib.util.spec_from_file_location("unnamed", file_path) |
156 | 163 | unnamed = importlib.util.module_from_spec(spec) |
157 | 164 | spec.loader.exec_module(unnamed) |
158 | | - yield from inspect.getmembers(unnamed, inspect.isclass) |
| 165 | + # yield from inspect.getmembers(unnamed, inspect.isclass) |
| 166 | + |
| 167 | + classes = [(name, getattr(unnamed, name)) for name in class_names] |
| 168 | + return classes |
159 | 169 |
|
160 | 170 |
|
161 | 171 | def extract_dynamic_shapes(example_inputs): |
|
0 commit comments