66import subprocess
77import sys
88
9+ try :
10+ import tomllib
11+ except ImportError :
12+ import tomli as tomllib
13+
914import numpy as np
1015import niaarm
11- from niaarm import NiaARM , Dataset , get_rules
16+ from niaarm import NiaARM , Dataset , get_rules , squash
1217from niapy .util .factory import get_algorithm
1318from niapy .util import distances , repair
1419from niapy .algorithms .other import mts
1520from niapy .algorithms .basic import de
1621
# Baseline configuration. A TOML config file (see load_config) is merged
# over this via deep_update; otherwise the CLI arguments fill the same slots.
DEFAULT_CONFIG = {
    "input_file": None,   # path to the csv dataset
    "output_file": None,  # optional csv file the mined rules are written to
    "log": False,         # forwarded to get_rules as its log flag
    "stats": False,       # when true, print the rules and run time after mining
    "preprocessing": {
        # Squashing runs only when both "similarity" and "threshold" are set;
        # an empty dict disables it.
        "squashing": {},
    },
    "algorithm": {
        "name": None,        # niapy algorithm class name (e.g. DifferentialEvolution)
        "max_evals": np.inf,  # stopping criteria: at least one of max_evals /
        "max_iters": np.inf,  # max_iters must be finite or main() errors out
        "metrics": None,     # metric names used in the fitness function
        "weights": None,     # per-metric weights; defaults to all 1s when unset
        "seed": None,        # RNG seed passed to get_algorithm
        "parameters": {},    # algorithm-specific parameter overrides
    },
}
40+
1741
1842def get_parser ():
1943 parser = argparse .ArgumentParser (
@@ -26,21 +50,29 @@ def get_parser():
2650 action = "version" ,
2751 version = f"%(prog)s version { niaarm .__version__ } " ,
2852 )
53+ parser .add_argument ("-c" , "--config" , type = str , help = "Path to a TOML config file" )
2954 parser .add_argument (
3055 "-i" ,
3156 "--input-file" ,
3257 type = str ,
33- required = True ,
3458 help = "Input file containing a csv dataset" ,
3559 )
3660 parser .add_argument (
3761 "-o" , "--output-file" , type = str , help = "Output file for mined rules"
3862 )
63+ parser .add_argument (
64+ "--squashing-similarity" ,
65+ type = str ,
66+ choices = ("euclidean" , "cosine" ),
67+ help = "Similarity measure to use for squashing" ,
68+ )
69+ parser .add_argument (
70+ "--squashing-threshold" , type = float , help = "Threshold to use for squashing"
71+ )
3972 parser .add_argument (
4073 "-a" ,
4174 "--algorithm" ,
4275 type = str ,
43- required = True ,
4476 help = "Algorithm to use (niapy class name, e.g. DifferentialEvolution)" ,
4577 )
4678 parser .add_argument (
@@ -64,7 +96,6 @@ def get_parser():
6496 nargs = "+" ,
6597 action = "extend" ,
6698 choices = NiaARM .available_metrics ,
67- required = True ,
6899 metavar = "METRICS" ,
69100 help = "Metrics to use in the fitness function." ,
70101 )
@@ -85,6 +116,30 @@ def get_parser():
85116 return parser
86117
87118
def deep_update(dictionary, other):
    """Return a copy of ``dictionary`` updated with ``other``, recursively.

    Behaves like ``dict.update`` except that values which are dicts in
    both mappings are merged key-by-key instead of replaced wholesale,
    and the inputs are left unmodified (a new dict is returned).
    """
    merged = dictionary.copy()
    for key, value in other.items():
        current = merged.get(key)
        # Recurse only when both sides hold a dict; any other pairing
        # (missing key, non-dict on either side) is a plain overwrite.
        if isinstance(current, dict) and isinstance(value, dict):
            merged[key] = deep_update(current, value)
        else:
            merged[key] = value
    return merged
132+
133+
134+ def load_config (file ):
135+ with open (file , "rb" ) as f :
136+ return tomllib .load (f )
137+
138+
def validate_config(config):
    """Validate a merged configuration dict.

    Currently a no-op placeholder: it accepts any ``config`` and returns
    ``None``.  TODO: implement real checks (e.g. required keys such as
    ``input_file``/``algorithm.name``, known metric names) so bad configs
    fail early instead of downstream.
    """
    pass
141+
142+
88143def text_editor ():
89144 return (
90145 os .getenv ("VISUAL" )
@@ -193,42 +248,108 @@ def main():
193248
194249 if len (sys .argv ) == 1 :
195250 parser .print_help ()
196- if args .max_evals == np .inf and args .max_iters == np .inf :
197- print ("Error: --max-evals and/or --max-iters missing" , file = sys .stderr )
251+ return 0
252+
253+ config = DEFAULT_CONFIG .copy ()
254+ if args .config :
255+ try :
256+ config_from_file = load_config (args .config )
257+ config = deep_update (config , config_from_file )
258+ except tomllib .TOMLDecodeError :
259+ print ("Error: Invalid config file" , file = sys .stderr )
260+ else :
261+ config ["input_file" ] = args .input_file
262+ config ["output_file" ] = args .output_file
263+ config ["log" ] = args .log
264+ config ["stats" ] = args .stats
265+ config ["preprocessing" ]["squashing" ]["similarity" ] = args .squashing_similarity
266+ config ["preprocessing" ]["squashing" ]["threshold" ] = args .squashing_threshold
267+ config ["algorithm" ]["name" ] = args .algorithm
268+ config ["algorithm" ]["seed" ] = args .seed
269+ config ["algorithm" ]["max_evals" ] = args .max_evals
270+ config ["algorithm" ]["max_iters" ] = args .max_iters
271+ config ["algorithm" ]["metrics" ] = args .metrics
272+ config ["algorithm" ]["weights" ] = args .weights
273+
274+ if (
275+ config ["algorithm" ]["max_evals" ] == np .inf
276+ and config ["algorithm" ]["max_iters" ] == np .inf
277+ ):
278+ print ("Error: max_evals or max_iters missing" , file = sys .stderr )
198279 return 1
199- metrics = list (set (args .metrics ))
200280
201- if args .weights and len (args .weights ) != len (metrics ):
281+ metrics = list (set (config ["algorithm" ]["metrics" ]))
282+ weights = config ["algorithm" ]["weights" ]
283+
284+ if weights and len (weights ) != len (metrics ):
202285 print (
203- "Error: There must be the same amount of weights and metrics " ,
286+ "Error: Metrics and weights dimensions don't match " ,
204287 file = sys .stderr ,
205288 )
206289 return 1
207- weights = args .weights if args .weights else [1 ] * len (metrics )
290+
291+ weights = weights if weights else [1 ] * len (metrics )
208292 metrics = dict (zip (metrics , weights ))
209293
210294 try :
211- dataset = Dataset (args .input_file )
212- algorithm = get_algorithm (args .algorithm , seed = args .seed )
295+ dataset = Dataset (config ["input_file" ])
296+
297+ squash_config = config ["preprocessing" ]["squashing" ]
298+ if squash_config and squash_config ["similarity" ] and squash_config ["threshold" ]:
299+ num_transactions = len (dataset .transactions )
300+ dataset = squash (
301+ dataset , squash_config ["threshold" ], squash_config ["similarity" ]
302+ )
303+ print (
304+ f"Squashed dataset from"
305+ f" { num_transactions } to { len (dataset .transactions )} transactions"
306+ )
307+
308+ algorithm = get_algorithm (
309+ config ["algorithm" ]["name" ], seed = config ["algorithm" ]["seed" ]
310+ )
213311 params = algorithm .get_parameters ()
214- new_params = edit_parameters (params , algorithm .__class__ )
312+ if args .config :
313+ new_params = config ["algorithm" ]["parameters" ]
314+ for k , v in new_params .items ():
315+ if isinstance (v , str ):
316+ if len (v .split (", " )) > 1 : # tuple
317+ value = list (map (str .strip , v .split (", " )))
318+ value = tuple (map (convert_string , value ))
319+ value = tuple (
320+ find_function (val , algorithm .__class__ )
321+ for val in value
322+ if isinstance (v , str )
323+ )
324+ else :
325+ value = find_function (v , algorithm .__class__ )
326+
327+ new_params [k ] = value
328+ else :
329+ new_params = edit_parameters (params , algorithm .__class__ )
330+
215331 if new_params is None :
216- print ("Invalid parameters" , file = sys .stderr )
332+ print ("Error: Invalid parameters" , file = sys .stderr )
217333 return 1
218334 if not set (new_params ).issubset (params ):
219335 print (
220- f"Invalid parameters: { set (new_params ).difference (params )} " ,
336+ f"Error: Invalid parameters: { set (new_params ).difference (params )} " ,
221337 file = sys .stderr ,
222338 )
223339 return 1
224340
225341 algorithm .set_parameters (** new_params )
226342 rules , run_time = get_rules (
227- dataset , algorithm , metrics , args .max_evals , args .max_iters , args .log
343+ dataset ,
344+ algorithm ,
345+ metrics ,
346+ config ["algorithm" ]["max_evals" ],
347+ config ["algorithm" ]["max_iters" ],
348+ config ["log" ],
228349 )
229- if args . output_file :
230- rules .to_csv (args . output_file )
231- if args . stats :
350+ if config [ " output_file" ] :
351+ rules .to_csv (config [ " output_file" ] )
352+ if config [ " stats" ] :
232353 print (rules )
233354 print (f"Run Time: { run_time :.4f} s" )
234355
0 commit comments