11import argparse
2+ from inspect import getmodule , getmembers , isfunction
23import os
34from pathlib import Path
45import platform
1011from niaarm import NiaARM , Dataset , Stats
1112from niapy .task import OptimizationType , Task
1213from niapy .util .factory import get_algorithm
14+ from niapy .util import distances , repair
15+ from niapy .algorithms .other import mts
16+ from niapy .algorithms .basic import de
1317
1418
1519def get_parser ():
@@ -26,7 +30,7 @@ def get_parser():
2630 parser .add_argument ('--beta' , type = float , default = 0.0 , help = 'Beta parameter. Default 0' )
2731 parser .add_argument ('--gamma' , type = float , default = 0.0 , help = 'Gamma parameter. Default 0' )
2832 parser .add_argument ('--delta' , type = float , default = 0.0 , help = 'Delta parameter. Default 0' )
29- parser .add_argument ('--logging ' , action = 'store_true' , help = 'Enable logging of fitness improvements' )
33+ parser .add_argument ('--log ' , action = 'store_true' , help = 'Enable logging of fitness improvements' )
3034 parser .add_argument ('--show-stats' , action = 'store_true' , help = 'Display stats about mined rules' )
3135
3236 return parser
@@ -41,28 +45,67 @@ def parameters_string(parameters):
4145 '# Save and exit to continue\n ' \
4246 '# WARNING: Do not edit parameter names\n '
4347 for parameter , value in parameters .items ():
44- params_txt += f'{ parameter } = { value } \n '
48+ if isinstance (value , tuple ):
49+ if callable (value [0 ]):
50+ value = tuple (v .__name__ for v in value )
51+ else :
52+ value = tuple (str (v ) for v in value )
53+ value = ', ' .join (value )
54+ params_txt += f'{ parameter } = { value .__name__ if callable (value ) else value } \n '
4555 return params_txt
4656
4757
48- def parse_parameters (text ):
58+ def functions (algorithm ):
59+ funcs = {}
60+ algorithm_funcs = dict (getmembers (getmodule (algorithm .__class__ ), isfunction ))
61+ repair_funcs = dict (getmembers (repair , isfunction ))
62+ distance_funcs = dict (getmembers (distances , isfunction ))
63+ de_funcs = dict (getmembers (de , isfunction ))
64+ mts_funcs = dict (getmembers (mts , isfunction ))
65+ funcs .update (algorithm_funcs )
66+ funcs .update (repair_funcs )
67+ funcs .update (distance_funcs )
68+ funcs .update (de_funcs )
69+ funcs .update (mts_funcs )
70+ return funcs
71+
72+
73+ def find_function (name , algorithm ):
74+ return functions (algorithm )[name ]
75+
76+
77+ def convert_string (string ):
78+ try :
79+ value = float (string )
80+ if value .is_integer ():
81+ value = int (value )
82+ except ValueError :
83+ return string
84+ return value
85+
86+
87+ def parse_parameters (text , algorithm ):
4988 lines : list [str ] = text .strip ().split ('\n ' )
5089 lines = [line .strip () for line in lines if line .strip () and not line .strip ().startswith ('#' )]
5190 parameters = {}
5291 for line in lines :
5392 key , value = line .split ('=' )
5493 key = key .strip ()
55- try :
56- value = float (value .strip ())
57- if value .is_integer ():
58- value = int (value )
59- except ValueError :
60- pass
94+ value = convert_string (value .strip ())
95+ if isinstance (value , str ):
96+ if len (value .split (', ' )) > 1 : # tuple
97+ value = list (map (str .strip , value .split (', ' )))
98+ value = tuple (map (convert_string , value ))
99+ value = tuple (find_function (v , algorithm ) for v in value if type (v ) == str )
100+ elif value .lower () == 'true' or value .lower () == 'false' : # boolean
101+ value = value .lower () == 'true'
102+ else : # probably a function
103+ value = find_function (value , algorithm )
61104 parameters [key ] = value
62105 return parameters
63106
64107
65- def edit_parameters (parameters ):
108+ def edit_parameters (parameters , algorithm ):
66109 parameters .pop ('individual_type' , None )
67110 parameters .pop ('initialization_function' , None )
68111 fd , filename = tempfile .mkstemp ()
@@ -75,7 +118,7 @@ def edit_parameters(parameters):
75118 command = f'{ text_editor ()} { filename } '
76119 subprocess .run (command , shell = True , check = True )
77120 params_txt = path .read_text ()
78- new_parameters = parse_parameters (params_txt )
121+ new_parameters = parse_parameters (params_txt , algorithm )
79122 finally :
80123 try :
81124 os .unlink (filename )
@@ -103,7 +146,7 @@ def main():
103146
104147 algorithm = get_algorithm (args .algorithm , seed = args .seed )
105148 params = algorithm .get_parameters ()
106- new_params = edit_parameters (params )
149+ new_params = edit_parameters (params , algorithm . __class__ )
107150 if new_params is None :
108151 print ('Invalid parameters' , file = sys .stderr )
109152 return 1
0 commit comments