|
| 1 | +import abc |
| 2 | +import copy |
| 3 | +import inspect |
| 4 | + |
| 5 | + |
| 6 | +class SamplePass(abc.ABC): |
| 7 | + def __init__(self, config=None): |
| 8 | + if config is None: |
| 9 | + config = {} |
| 10 | + |
| 11 | + self._check_config_declaration_valid() |
| 12 | + self.config = self._make_config_by_config_declare(config) |
| 13 | + |
| 14 | + @abc.abstractmethod |
| 15 | + def declare_config(self): |
| 16 | + raise NotImplementedError() |
| 17 | + |
| 18 | + @abc.abstractmethod |
| 19 | + def __call__(self, rel_model_path: str): |
| 20 | + raise NotImplementedError() |
| 21 | + |
| 22 | + def _recursively_check_mixin_declare_config(self, base_class): |
| 23 | + from graph_net.sample_pass.sample_pass_mixin import SamplePassMixin |
| 24 | + |
| 25 | + if issubclass(base_class, (SamplePass, SamplePassMixin)): |
| 26 | + check_is_base_signature( |
| 27 | + base_class=base_class, |
| 28 | + derived_class=type(self), |
| 29 | + method_name="declare_config", |
| 30 | + ) |
| 31 | + for sub_class in base_class.__bases__: |
| 32 | + self._recursively_check_mixin_declare_config(sub_class) |
| 33 | + |
| 34 | + def _check_config_declaration_parameters(self): |
| 35 | + sig = inspect.signature(self.declare_config) |
| 36 | + for name, param in sig.parameters.items(): |
| 37 | + assert param.annotation in { |
| 38 | + int, |
| 39 | + bool, |
| 40 | + float, |
| 41 | + str, |
| 42 | + list, |
| 43 | + dict, |
| 44 | + }, f"{name=} {param.annotation}" |
| 45 | + assert param.kind in { |
| 46 | + inspect.Parameter.POSITIONAL_OR_KEYWORD, |
| 47 | + inspect.Parameter.VAR_KEYWORD, |
| 48 | + }, f"{name=} {param.kind=}" |
| 49 | + |
| 50 | + def _check_config_declaration_valid(self): |
| 51 | + self._recursively_check_mixin_declare_config(type(self)) |
| 52 | + self._check_config_declaration_parameters() |
| 53 | + |
| 54 | + def _make_config_by_config_declare(self, config): |
| 55 | + sig = inspect.signature(self.declare_config) |
| 56 | + mut_config = copy.deepcopy(config) |
| 57 | + for name, param in sig.parameters.items(): |
| 58 | + self._complete_default(name, param, mut_config) |
| 59 | + class_name = type(self).__name__ |
| 60 | + assert name in mut_config, f"{name=} {class_name=}" |
| 61 | + |
| 62 | + def get_extra_config_fields(): |
| 63 | + return set(name for name, _ in mut_config.items()) - set( |
| 64 | + name for name, _ in sig.parameters.items() |
| 65 | + ) |
| 66 | + |
| 67 | + no_varadic_keyword = all( |
| 68 | + param.kind != inspect.Parameter.VAR_KEYWORD |
| 69 | + for _, param in sig.parameters.items() |
| 70 | + ) |
| 71 | + if no_varadic_keyword: |
| 72 | + no_extra_config_fields = all( |
| 73 | + name in sig.parameters for name, _ in mut_config.items() |
| 74 | + ) |
| 75 | + assert no_extra_config_fields, f"{get_extra_config_fields()=}" |
| 76 | + return mut_config |
| 77 | + |
| 78 | + def _complete_default(self, name, param, mut_config): |
| 79 | + if param.default is inspect.Parameter.empty: |
| 80 | + return |
| 81 | + mut_config[name] = copy.deepcopy(param.default) |
| 82 | + |
| 83 | + |
| 84 | +def check_is_base_signature(base_class, derived_class, method_name): |
| 85 | + base = getattr(base_class, method_name) |
| 86 | + derived = getattr(derived_class, method_name) |
| 87 | + base_parameters = inspect.signature(base).parameters |
| 88 | + derived_parameters = inspect.signature(derived).parameters |
| 89 | + assert len(derived_parameters) >= len(base_parameters) |
| 90 | + for name, param in base_parameters.items(): |
| 91 | + assert name in base_parameters, f"{name=}" |
| 92 | + assert param == base_parameters[name] |
0 commit comments