2323from core .testcasecontroller .generation_assistant import get_full_combinations
2424
2525
26+ # pylint: disable=too-few-public-methods
2627class Module :
2728 """
2829 Algorithm Module:
@@ -52,8 +53,8 @@ def __init__(self, config):
5253 self .type : str = ""
5354 self .name : str = ""
5455 self .url : str = ""
55- self .hyperparameters = None
56- self .hyperparameters_list = None
56+ self .hyperparameters = {}
57+ self .hyperparameters_list = []
5758 self ._parse_config (config )
5859
5960 def _check_fields (self ):
@@ -71,75 +72,59 @@ def _check_fields(self):
7172 if not isinstance (self .url , str ):
7273 raise ValueError (f"module url({ self .url } ) must be string type." )
7374
74- def basemodel_func (self ):
75+ def get_module_instance (self , module_type ):
7576 """
76- get basemodel module function of the module.
77+ get function of algorithm module by using module type
78+
79+ Parameters
80+ ---------
81+ module_type: string
82+ module type, e.g.: basemodel, hard_example_mining, etc.
7783
7884 Returns
79- --------
85+ ------
8086 function
8187
8288 """
89+ class_factory_type = ClassType .GENERAL
90+ if module_type in [ModuleType .HARD_EXAMPLE_MINING .value ]:
91+ class_factory_type = ClassType .HEM
8392
84- if not self .url :
85- raise ValueError (f"url({ self .url } ) of basemodel module must be provided." )
93+ elif module_type in [ModuleType .TASK_DEFINITION .value ,
94+ ModuleType .TASK_RELATIONSHIP_DISCOVERY .value ,
95+ ModuleType .TASK_REMODELING .value ,
96+ ModuleType .TASK_ALLOCATION .value ,
97+ ModuleType .INFERENCE_INTEGRATE .value ]:
98+ class_factory_type = ClassType .STP
8699
87- try :
88- utils .load_module (self .url )
89- # pylint: disable=E1134
90- basemodel = ClassFactory .get_cls (type_name = ClassType .GENERAL ,
91- t_cls_name = self .name )(** self .hyperparameters )
92- except Exception as err :
93- raise RuntimeError (f"basemodel module loads class(name={ self .name } ) failed, "
94- f"error: { err } ." ) from err
100+ elif module_type in [ModuleType .TASK_UPDATE_DECISION .value ]:
101+ class_factory_type = ClassType .KM
95102
96- return basemodel
103+ elif module_type in [ModuleType .UNSEEN_TASK_ALLOCATION .value ]:
104+ class_factory_type = ClassType .UTP
97105
98- def hard_example_mining_func (self ):
99- """
100- get hard example mining function of the module.
101-
102- Returns:
103- --------
104- function
105-
106- """
106+ elif module_type in [ModuleType .UNSEEN_SAMPLE_RECOGNITION .value ,
107+ ModuleType .UNSEEN_SAMPLE_RE_RECOGNITION .value ]:
108+ class_factory_type = ClassType .UTD
107109
108110 if self .url :
109111 try :
110112 utils .load_module (self .url )
111113 # pylint: disable=E1134
112114 func = ClassFactory .get_cls (
113- type_name = ClassType . HEM , t_cls_name = self .name )(** self .hyperparameters )
115+ type_name = class_factory_type , t_cls_name = self .name )(** self .hyperparameters )
114116
115117 return func
116118 except Exception as err :
117- raise RuntimeError (f"hard_example_mining module loads class"
118- f"(name= { self . name } ) failed, error: { err } ." ) from err
119+ raise RuntimeError (f"module(type= { module_type } loads class(name= { self . name } ) "
120+ f"failed, error: { err } ." ) from err
119121
120- # call built-in hard example mining function
121- hard_example_mining = {"method" : self .name }
122+ # call lib built-in module function
123+ module_func = {"method" : self .name }
122124 if self .hyperparameters :
123- hard_example_mining ["param" ] = self .hyperparameters
124-
125- return hard_example_mining
125+ module_func ["param" ] = self .hyperparameters
126126
127- def get_module_func (self , module_type ):
128- """
129- get function of algorithm module by using module type
130-
131- Parameters
132- ---------
133- module_type: string
134- module type, e.g.: basemodel, hard_example_mining, etc.
135-
136- Returns
137- ------
138- function
139-
140- """
141- func_name = f"{ module_type } _func"
142- return getattr (self , func_name )
127+ return module_func
143128
144129 def _parse_config (self , config ):
145130 # pylint: disable=C0103
0 commit comments