Skip to content

Commit ef9ea2c

Browse files
authored
Merge pull request #102 from zStupan/main
Added support for config files to CLI
2 parents fc7ae85 + 62c011a commit ef9ea2c

File tree

5 files changed

+371
-185
lines changed

5 files changed

+371
-185
lines changed

docs/cli.rst

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,26 @@ Usage
1414
1515
.. code-block:: text
1616
17-
usage: niaarm [-h] [-v] -i INPUT_FILE [-o OUTPUT_FILE] -a ALGORITHM [-s SEED]
18-
[--max-evals MAX_EVALS] [--max-iters MAX_ITERS] --metrics
19-
METRICS [METRICS ...] [--weights WEIGHTS [WEIGHTS ...]] [--log]
20-
[--show-stats]
17+
usage: niaarm [-h] [-v] [-c CONFIG] [-i INPUT_FILE] [-o OUTPUT_FILE] [--squashing-similarity {euclidean,cosine}] [--squashing-threshold SQUASHING_THRESHOLD] [-a ALGORITHM] [-s SEED] [--max-evals MAX_EVALS] [--max-iters MAX_ITERS]
18+
[--metrics METRICS [METRICS ...]] [--weights WEIGHTS [WEIGHTS ...]] [--log] [--stats]
2119
2220
Perform ARM, output mined rules as csv, get mined rules' statistics
2321
2422
options:
2523
-h, --help show this help message and exit
2624
-v, --version show program's version number and exit
25+
-c CONFIG, --config CONFIG
26+
Path to a TOML config file
2727
-i INPUT_FILE, --input-file INPUT_FILE
2828
Input file containing a csv dataset
2929
-o OUTPUT_FILE, --output-file OUTPUT_FILE
3030
Output file for mined rules
31+
--squashing-similarity {euclidean,cosine}
32+
Similarity measure to use for squashing
33+
--squashing-threshold SQUASHING_THRESHOLD
34+
Threshold to use for squashing
3135
-a ALGORITHM, --algorithm ALGORITHM
32-
Algorithm to use (niapy class name, e.g.
33-
DifferentialEvolution)
36+
Algorithm to use (niapy class name, e.g. DifferentialEvolution)
3437
-s SEED, --seed SEED Seed for the algorithm's random number generator
3538
--max-evals MAX_EVALS
3639
Maximum number of fitness function evaluations
@@ -110,3 +113,58 @@ E.g. (for the above run):
110113
Average length of antecedent: 1.97723292469352
111114
Average length of consequent: 1.5604203152364273
112115
Run Time: 6.4538s
116+
117+
Using a config file
118+
~~~~~~~~~~~~~~~~~~~
119+
120+
Instead of setting all the options as command-line arguments, you can put them in a TOML
121+
file and run:
122+
123+
.. code-block:: shell
124+
125+
niaarm -c config.toml
126+
127+
Below is an example of a config file with all the available options:
128+
129+
.. code-block:: toml
130+
131+
# dataset to load
132+
input_file = "datasets/Abalone.csv"
133+
134+
# file to export rules to (optional)
135+
output_file = "output.csv"
136+
137+
# log fitness improvements (optional)
138+
log = true
139+
140+
# print stats of the mined rules (optional)
141+
stats = true
142+
143+
# Data squashing settings (optional)
144+
[preprocessing.squashing]
145+
similarity = "euclidean" # or "cosine"
146+
threshold = 0.99
147+
148+
# algorithm settings
149+
[algorithm]
150+
# name of NiaPy class
151+
name = "DifferentialEvolution"
152+
153+
# metrics to compute fitness with
154+
metrics = ["support", "confidence"]
155+
# weights of each metric (optional)
156+
weights = [0.5, 0.5]
157+
158+
# algorithm stopping criteria; at least one of max_evals or max_iters is required
159+
max_evals = 10000
160+
max_iters = 1000
161+
162+
# random seed (optional)
163+
seed = 12345
164+
165+
# algorithm parameters (optional), the names need to be the same as NiaPy parameters
166+
[algorithm.parameters]
167+
population_size = 50
168+
differential_weight = 0.5
169+
crossover_probability = 0.9
170+
strategy = "cross_rand1"

examples/config.toml

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# dataset to load
2+
input_file = "datasets/Abalone.csv"
3+
4+
# file to export rules to (optional)
5+
output_file = "output.csv"
6+
7+
# log fitness improvements (optional)
8+
log = true
9+
10+
# print stats of the mined rules (optional)
11+
stats = true
12+
13+
# Data squashing settings (optional)
14+
[preprocessing.squashing]
15+
similarity = "euclidean" # or "cosine"
16+
threshold = 0.99
17+
18+
# algorithm settings
19+
[algorithm]
20+
# name of NiaPy class
21+
name = "DifferentialEvolution"
22+
23+
# metrics to compute fitness with
24+
metrics = ["support", "confidence"]
25+
# weights of each metric (optional)
26+
weights = [0.5, 0.5]
27+
28+
# algorithm stopping criteria; at least one of max_evals or max_iters is required
29+
max_evals = 10000
30+
max_iters = 1000
31+
32+
# random seed (optional)
33+
seed = 12345
34+
35+
# algorithm parameters (optional), the names need to be the same as NiaPy parameters
36+
[algorithm.parameters]
37+
population_size = 50
38+
differential_weight = 0.5
39+
crossover_probability = 0.9
40+
strategy = "cross_rand1"

niaarm/cli.py

Lines changed: 140 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,38 @@
66
import subprocess
77
import sys
88

9+
try:
10+
import tomllib
11+
except ImportError:
12+
import tomli as tomllib
13+
914
import numpy as np
1015
import niaarm
11-
from niaarm import NiaARM, Dataset, get_rules
16+
from niaarm import NiaARM, Dataset, get_rules, squash
1217
from niapy.util.factory import get_algorithm
1318
from niapy.util import distances, repair
1419
from niapy.algorithms.other import mts
1520
from niapy.algorithms.basic import de
1621

22+
DEFAULT_CONFIG = {
23+
"input_file": None,
24+
"output_file": None,
25+
"log": False,
26+
"stats": False,
27+
"preprocessing": {
28+
"squashing": {},
29+
},
30+
"algorithm": {
31+
"name": None,
32+
"max_evals": np.inf,
33+
"max_iters": np.inf,
34+
"metrics": None,
35+
"weights": None,
36+
"seed": None,
37+
"parameters": {},
38+
},
39+
}
40+
1741

1842
def get_parser():
1943
parser = argparse.ArgumentParser(
@@ -26,21 +50,29 @@ def get_parser():
2650
action="version",
2751
version=f"%(prog)s version {niaarm.__version__}",
2852
)
53+
parser.add_argument("-c", "--config", type=str, help="Path to a TOML config file")
2954
parser.add_argument(
3055
"-i",
3156
"--input-file",
3257
type=str,
33-
required=True,
3458
help="Input file containing a csv dataset",
3559
)
3660
parser.add_argument(
3761
"-o", "--output-file", type=str, help="Output file for mined rules"
3862
)
63+
parser.add_argument(
64+
"--squashing-similarity",
65+
type=str,
66+
choices=("euclidean", "cosine"),
67+
help="Similarity measure to use for squashing",
68+
)
69+
parser.add_argument(
70+
"--squashing-threshold", type=float, help="Threshold to use for squashing"
71+
)
3972
parser.add_argument(
4073
"-a",
4174
"--algorithm",
4275
type=str,
43-
required=True,
4476
help="Algorithm to use (niapy class name, e.g. DifferentialEvolution)",
4577
)
4678
parser.add_argument(
@@ -64,7 +96,6 @@ def get_parser():
6496
nargs="+",
6597
action="extend",
6698
choices=NiaARM.available_metrics,
67-
required=True,
6899
metavar="METRICS",
69100
help="Metrics to use in the fitness function.",
70101
)
@@ -85,6 +116,30 @@ def get_parser():
85116
return parser
86117

87118

119+
def deep_update(dictionary, other):
120+
"""Same as `dict.update` but for nested dictionaries."""
121+
updated_dict = dictionary.copy()
122+
for k, v in other.items():
123+
if (
124+
k in updated_dict
125+
and isinstance(updated_dict[k], dict)
126+
and isinstance(v, dict)
127+
):
128+
updated_dict[k] = deep_update(updated_dict[k], v)
129+
else:
130+
updated_dict[k] = v
131+
return updated_dict
132+
133+
134+
def load_config(file):
135+
with open(file, "rb") as f:
136+
return tomllib.load(f)
137+
138+
139+
def validate_config(config):
140+
pass
141+
142+
88143
def text_editor():
89144
return (
90145
os.getenv("VISUAL")
@@ -193,42 +248,108 @@ def main():
193248

194249
if len(sys.argv) == 1:
195250
parser.print_help()
196-
if args.max_evals == np.inf and args.max_iters == np.inf:
197-
print("Error: --max-evals and/or --max-iters missing", file=sys.stderr)
251+
return 0
252+
253+
config = DEFAULT_CONFIG.copy()
254+
if args.config:
255+
try:
256+
config_from_file = load_config(args.config)
257+
config = deep_update(config, config_from_file)
258+
except tomllib.TOMLDecodeError:
259+
print("Error: Invalid config file", file=sys.stderr)
260+
else:
261+
config["input_file"] = args.input_file
262+
config["output_file"] = args.output_file
263+
config["log"] = args.log
264+
config["stats"] = args.stats
265+
config["preprocessing"]["squashing"]["similarity"] = args.squashing_similarity
266+
config["preprocessing"]["squashing"]["threshold"] = args.squashing_threshold
267+
config["algorithm"]["name"] = args.algorithm
268+
config["algorithm"]["seed"] = args.seed
269+
config["algorithm"]["max_evals"] = args.max_evals
270+
config["algorithm"]["max_iters"] = args.max_iters
271+
config["algorithm"]["metrics"] = args.metrics
272+
config["algorithm"]["weights"] = args.weights
273+
274+
if (
275+
config["algorithm"]["max_evals"] == np.inf
276+
and config["algorithm"]["max_iters"] == np.inf
277+
):
278+
print("Error: max_evals or max_iters missing", file=sys.stderr)
198279
return 1
199-
metrics = list(set(args.metrics))
200280

201-
if args.weights and len(args.weights) != len(metrics):
281+
metrics = list(set(config["algorithm"]["metrics"]))
282+
weights = config["algorithm"]["weights"]
283+
284+
if weights and len(weights) != len(metrics):
202285
print(
203-
"Error: There must be the same amount of weights and metrics",
286+
"Error: Metrics and weights dimensions don't match",
204287
file=sys.stderr,
205288
)
206289
return 1
207-
weights = args.weights if args.weights else [1] * len(metrics)
290+
291+
weights = weights if weights else [1] * len(metrics)
208292
metrics = dict(zip(metrics, weights))
209293

210294
try:
211-
dataset = Dataset(args.input_file)
212-
algorithm = get_algorithm(args.algorithm, seed=args.seed)
295+
dataset = Dataset(config["input_file"])
296+
297+
squash_config = config["preprocessing"]["squashing"]
298+
if squash_config and squash_config["similarity"] and squash_config["threshold"]:
299+
num_transactions = len(dataset.transactions)
300+
dataset = squash(
301+
dataset, squash_config["threshold"], squash_config["similarity"]
302+
)
303+
print(
304+
f"Squashed dataset from"
305+
f" {num_transactions} to {len(dataset.transactions)} transactions"
306+
)
307+
308+
algorithm = get_algorithm(
309+
config["algorithm"]["name"], seed=config["algorithm"]["seed"]
310+
)
213311
params = algorithm.get_parameters()
214-
new_params = edit_parameters(params, algorithm.__class__)
312+
if args.config:
313+
new_params = config["algorithm"]["parameters"]
314+
for k, v in new_params.items():
315+
if isinstance(v, str):
316+
if len(v.split(", ")) > 1: # tuple
317+
value = list(map(str.strip, v.split(", ")))
318+
value = tuple(map(convert_string, value))
319+
value = tuple(
320+
find_function(val, algorithm.__class__)
321+
for val in value
322+
if isinstance(v, str)
323+
)
324+
else:
325+
value = find_function(v, algorithm.__class__)
326+
327+
new_params[k] = value
328+
else:
329+
new_params = edit_parameters(params, algorithm.__class__)
330+
215331
if new_params is None:
216-
print("Invalid parameters", file=sys.stderr)
332+
print("Error: Invalid parameters", file=sys.stderr)
217333
return 1
218334
if not set(new_params).issubset(params):
219335
print(
220-
f"Invalid parameters: {set(new_params).difference(params)}",
336+
f"Error: Invalid parameters: {set(new_params).difference(params)}",
221337
file=sys.stderr,
222338
)
223339
return 1
224340

225341
algorithm.set_parameters(**new_params)
226342
rules, run_time = get_rules(
227-
dataset, algorithm, metrics, args.max_evals, args.max_iters, args.log
343+
dataset,
344+
algorithm,
345+
metrics,
346+
config["algorithm"]["max_evals"],
347+
config["algorithm"]["max_iters"],
348+
config["log"],
228349
)
229-
if args.output_file:
230-
rules.to_csv(args.output_file)
231-
if args.stats:
350+
if config["output_file"]:
351+
rules.to_csv(config["output_file"])
352+
if config["stats"]:
232353
print(rules)
233354
print(f"Run Time: {run_time:.4f}s")
234355

0 commit comments

Comments
 (0)