Skip to content

Commit 0fd152a

Browse files
committed
Fix CLI
1 parent 2a4a6dc commit 0fd152a

File tree

2 files changed

+57
-14
lines changed

2 files changed

+57
-14
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ For a full list of examples see the [examples folder](examples/).
7979
niaarm -h
8080
usage: niaarm [-h] -i INPUT_FILE [-o OUTPUT_FILE] -a ALGORITHM [-s SEED]
8181
[--max-evals MAX_EVALS] [--max-iters MAX_ITERS] [--alpha ALPHA]
82-
[--beta BETA] [--gamma GAMMA] [--delta DELTA] [--logging]
82+
[--beta BETA] [--gamma GAMMA] [--delta DELTA] [--log]
8383
[--show-stats]
8484
8585
Perform ARM, output mined rules as csv, get mined rules' statistics
@@ -102,7 +102,7 @@ options:
102102
--beta BETA Beta parameter. Default 0
103103
--gamma GAMMA Gamma parameter. Default 0
104104
--delta DELTA Delta parameter. Default 0
105-
--logging Enable logging of fitness improvements
105+
--log Enable logging of fitness improvements
106106
--show-stats Display stats about mined rules
107107
```
108108
Note: The CLI script can also run as a python module (`python -m niaarm ...`)

niaarm/cli.py

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
from inspect import getmodule, getmembers, isfunction
23
import os
34
from pathlib import Path
45
import platform
@@ -10,6 +11,9 @@
1011
from niaarm import NiaARM, Dataset, Stats
1112
from niapy.task import OptimizationType, Task
1213
from niapy.util.factory import get_algorithm
14+
from niapy.util import distances, repair
15+
from niapy.algorithms.other import mts
16+
from niapy.algorithms.basic import de
1317

1418

1519
def get_parser():
@@ -26,7 +30,7 @@ def get_parser():
2630
parser.add_argument('--beta', type=float, default=0.0, help='Beta parameter. Default 0')
2731
parser.add_argument('--gamma', type=float, default=0.0, help='Gamma parameter. Default 0')
2832
parser.add_argument('--delta', type=float, default=0.0, help='Delta parameter. Default 0')
29-
parser.add_argument('--logging', action='store_true', help='Enable logging of fitness improvements')
33+
parser.add_argument('--log', action='store_true', help='Enable logging of fitness improvements')
3034
parser.add_argument('--show-stats', action='store_true', help='Display stats about mined rules')
3135

3236
return parser
@@ -41,28 +45,67 @@ def parameters_string(parameters):
4145
'# Save and exit to continue\n' \
4246
'# WARNING: Do not edit parameter names\n'
4347
for parameter, value in parameters.items():
44-
params_txt += f'{parameter} = {value}\n'
48+
if isinstance(value, tuple):
49+
if callable(value[0]):
50+
value = tuple(v.__name__ for v in value)
51+
else:
52+
value = tuple(str(v) for v in value)
53+
value = ', '.join(value)
54+
params_txt += f'{parameter} = {value.__name__ if callable(value) else value}\n'
4555
return params_txt
4656

4757

48-
def parse_parameters(text):
58+
def functions(algorithm):
59+
funcs = {}
60+
algorithm_funcs = dict(getmembers(getmodule(algorithm.__class__), isfunction))
61+
repair_funcs = dict(getmembers(repair, isfunction))
62+
distance_funcs = dict(getmembers(distances, isfunction))
63+
de_funcs = dict(getmembers(de, isfunction))
64+
mts_funcs = dict(getmembers(mts, isfunction))
65+
funcs.update(algorithm_funcs)
66+
funcs.update(repair_funcs)
67+
funcs.update(distance_funcs)
68+
funcs.update(de_funcs)
69+
funcs.update(mts_funcs)
70+
return funcs
71+
72+
73+
def find_function(name, algorithm):
74+
return functions(algorithm)[name]
75+
76+
77+
def convert_string(string):
78+
try:
79+
value = float(string)
80+
if value.is_integer():
81+
value = int(value)
82+
except ValueError:
83+
return string
84+
return value
85+
86+
87+
def parse_parameters(text, algorithm):
4988
lines: list[str] = text.strip().split('\n')
5089
lines = [line.strip() for line in lines if line.strip() and not line.strip().startswith('#')]
5190
parameters = {}
5291
for line in lines:
5392
key, value = line.split('=')
5493
key = key.strip()
55-
try:
56-
value = float(value.strip())
57-
if value.is_integer():
58-
value = int(value)
59-
except ValueError:
60-
pass
94+
value = convert_string(value.strip())
95+
if isinstance(value, str):
96+
if len(value.split(', ')) > 1: # tuple
97+
value = list(map(str.strip, value.split(', ')))
98+
value = tuple(map(convert_string, value))
99+
value = tuple(find_function(v, algorithm) for v in value if type(v) == str)
100+
elif value.lower() == 'true' or value.lower() == 'false': # boolean
101+
value = value.lower() == 'true'
102+
else: # probably a function
103+
value = find_function(value, algorithm)
61104
parameters[key] = value
62105
return parameters
63106

64107

65-
def edit_parameters(parameters):
108+
def edit_parameters(parameters, algorithm):
66109
parameters.pop('individual_type', None)
67110
parameters.pop('initialization_function', None)
68111
fd, filename = tempfile.mkstemp()
@@ -75,7 +118,7 @@ def edit_parameters(parameters):
75118
command = f'{text_editor()} {filename}'
76119
subprocess.run(command, shell=True, check=True)
77120
params_txt = path.read_text()
78-
new_parameters = parse_parameters(params_txt)
121+
new_parameters = parse_parameters(params_txt, algorithm)
79122
finally:
80123
try:
81124
os.unlink(filename)
@@ -103,7 +146,7 @@ def main():
103146

104147
algorithm = get_algorithm(args.algorithm, seed=args.seed)
105148
params = algorithm.get_parameters()
106-
new_params = edit_parameters(params)
149+
new_params = edit_parameters(params, algorithm.__class__)
107150
if new_params is None:
108151
print('Invalid parameters', file=sys.stderr)
109152
return 1

0 commit comments

Comments
 (0)