|
| 1 | +import argparse |
| 2 | +from inspect import getmodule, getmembers, isfunction |
| 3 | +import os |
| 4 | +from pathlib import Path |
| 5 | +import platform |
| 6 | +import subprocess |
| 7 | +import sys |
| 8 | +import tempfile |
| 9 | + |
| 10 | +import numpy as np |
| 11 | +from niaarm import NiaARM, Dataset, Stats |
| 12 | +from niapy.task import OptimizationType, Task |
| 13 | +from niapy.util.factory import get_algorithm |
| 14 | +from niapy.util import distances, repair |
| 15 | +from niapy.algorithms.other import mts |
| 16 | +from niapy.algorithms.basic import de |
| 17 | + |
| 18 | + |
def get_parser():
    """Build the command line argument parser for the ``niaarm`` program.

    Returns:
        argparse.ArgumentParser: Parser with the input/output file options,
        the algorithm selection, the stopping criteria, the four weight
        options (alpha, beta, gamma, delta) and the logging/statistics flags.
    """
    cli = argparse.ArgumentParser(
        prog="niaarm",
        description="Perform ARM, output mined rules as csv, get mined rules' statistics",
    )
    add = cli.add_argument

    add("-i", "--input-file", type=str, required=True, help="Input file containing a csv dataset")
    add("-o", "--output-file", type=str, help="Output file for mined rules")
    add(
        "-a",
        "--algorithm",
        type=str,
        required=True,
        help="Algorithm to use (niapy class name, e.g. DifferentialEvolution)",
    )
    add("-s", "--seed", type=int, help="Seed for the algorithm's random number generator")

    # Stopping criteria default to infinity, i.e. "not given"
    add("--max-evals", type=int, default=np.inf, help="Maximum number of fitness function evaluations")
    add("--max-iters", type=int, default=np.inf, help="Maximum number of iterations")

    # The four weight options share the same type, default and help layout
    for weight in ("alpha", "beta", "gamma", "delta"):
        add(f"--{weight}", type=float, default=0.0, help=f"{weight.capitalize()} parameter. Default 0")

    add("--log", action="store_true", help="Enable logging of fitness improvements")
    add("--show-stats", action="store_true", help="Display stats about mined rules")

    return cli
| 37 | + |
| 38 | + |
def text_editor():
    """Return the command used to launch the user's text editor.

    The ``VISUAL`` environment variable is consulted first, then ``EDITOR``;
    unset or empty values are skipped. When neither is usable the result is
    ``'notepad'`` on Windows and ``'vi'`` everywhere else.

    Returns:
        str: Editor command.
    """
    for variable in ('VISUAL', 'EDITOR'):
        configured = os.getenv(variable)
        if configured:
            return configured

    if platform.system() == 'Windows':
        return 'notepad'
    return 'vi'
| 41 | + |
| 42 | + |
def _parameter_value_text(value):
    """Return the editable text for a single (non-tuple) parameter value."""
    if callable(value):
        # Callables are written by name; fall back to str() for callables
        # that carry no __name__ (e.g. functools.partial objects).
        return getattr(value, '__name__', str(value))
    return str(value)


def parameters_string(parameters):
    """Render algorithm parameters as editable ``name = value`` text.

    The text starts with three ``#`` comment lines of instructions, followed
    by one line per parameter. Callables are written by their ``__name__``.
    Tuples are written as their items joined with ``', '``; every item is
    formatted on its own, so empty tuples and tuples mixing callables with
    plain values are rendered instead of raising.

    Args:
        parameters (dict): Mapping of parameter names to values.

    Returns:
        str: The text to be presented to the user.
    """
    header = '# You can edit the algorithm\'s parameter values here\n' \
             '# Save and exit to continue\n' \
             '# WARNING: Do not edit parameter names\n'
    lines = []
    for parameter, value in parameters.items():
        if isinstance(value, tuple):
            text = ', '.join(_parameter_value_text(item) for item in value)
        else:
            text = _parameter_value_text(value)
        lines.append(f'{parameter} = {text}\n')
    # Join once instead of growing the string inside the loop
    return header + ''.join(lines)
| 56 | + |
| 57 | + |
def functions(algorithm):
    """Collect the functions that can be referenced by name in parameter values.

    Functions are gathered from the module that defines the algorithm's class
    and from the ``repair``, ``distances``, ``de`` and ``mts`` modules, in
    that order; on a name clash the module listed later wins.

    Args:
        algorithm: An algorithm instance or an algorithm class.

    Returns:
        dict: Mapping of function names to function objects.
    """
    # Accept both an instance and the class itself. Taking ``.__class__`` of
    # a class yields ``type``, whose module is ``builtins``, so the
    # algorithm's own module would never be searched.
    algorithm_class = algorithm if isinstance(algorithm, type) else algorithm.__class__

    funcs = {}
    for module in (getmodule(algorithm_class), repair, distances, de, mts):
        # getmodule() returns None when the module cannot be determined
        if module is not None:
            funcs.update(getmembers(module, isfunction))
    return funcs
| 71 | + |
| 72 | + |
def find_function(name, algorithm):
    """Look up a function by name among those returned by ``functions``.

    Args:
        name (str): Name of the function.
        algorithm: Passed through to ``functions``.

    Returns:
        The function registered under ``name``.

    Raises:
        KeyError: If no function with that name is available.
    """
    available = functions(algorithm)
    return available[name]
| 75 | + |
| 76 | + |
def convert_string(string):
    """Convert text to a number when it represents one.

    Text in plain integer notation becomes an exact ``int``. Other numeric
    text is parsed as ``float`` and collapsed to ``int`` when the value is
    integral (e.g. ``'2.0'`` or ``'1e3'``). Anything else is returned
    unchanged.

    Args:
        string (str): Text to convert.

    Returns:
        int | float | str: The converted number, or ``string`` itself when it
        is not numeric.
    """
    if isinstance(string, str):
        try:
            # Parse integers directly: going through float would silently
            # lose precision for values above 2**53.
            return int(string)
        except ValueError:
            pass

    try:
        value = float(string)
    except ValueError:
        return string
    return int(value) if value.is_integer() else value
| 85 | + |
| 86 | + |
def _parse_value(token, algorithm):
    """Convert one textual token into a number, boolean, None or function."""
    value = convert_string(token)
    if not isinstance(value, str):
        return value
    lowered = value.lower()
    if lowered in ('true', 'false'):
        return lowered == 'true'
    if value == 'None':
        return None
    # Anything else is expected to be the name of a function
    return find_function(value, algorithm)


def parse_parameters(text, algorithm):
    """Parse edited ``name = value`` text back into a parameter dictionary.

    Blank lines and lines starting with ``#`` are ignored. A value containing
    a comma is parsed as a tuple whose items are converted one by one (empty
    items are ignored); numeric items are kept rather than discarded. Each
    item or single value is interpreted, in order, as a number, a boolean
    (``true``/``false`` in any letter case), ``None``, or the name of a
    function resolved with ``find_function``.

    Args:
        text (str): Contents of the edited parameter file.
        algorithm: Passed to ``find_function`` to resolve function names.

    Returns:
        dict: Mapping of parameter names to parsed values.

    Raises:
        ValueError: If a line has no ``=``, no parameter name or no value.
        KeyError: If a value names a function that cannot be found.
    """
    parameters = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith('#'):
            continue

        # Split on the first '=' only, so a malformed line produces a clear
        # message instead of an unpacking error
        key, separator, raw_value = line.partition('=')
        key = key.strip()
        raw_value = raw_value.strip()
        if not separator or not key:
            raise ValueError(f'Invalid parameter line: {line}')
        if not raw_value:
            raise ValueError(f'Missing value for parameter: {key}')

        if ',' in raw_value:
            tokens = (token.strip() for token in raw_value.split(','))
            parameters[key] = tuple(_parse_value(token, algorithm) for token in tokens if token)
        else:
            parameters[key] = _parse_value(raw_value, algorithm)
    return parameters
| 106 | + |
| 107 | + |
def edit_parameters(parameters, algorithm):
    """Let the user edit algorithm parameters in a text editor.

    The parameters are written to a temporary file, the editor returned by
    ``text_editor`` is run on it through the shell, and the saved text is
    parsed with ``parse_parameters``. The temporary file is removed
    afterwards, whether or not editing succeeded.

    Args:
        parameters (dict): Current parameters. ``individual_type`` and
            ``initialization_function`` are removed from this dictionary in
            place and are not offered for editing.
        algorithm: Passed to ``parse_parameters`` to resolve function names.

    Returns:
        dict: The parameters parsed from the edited file.

    Raises:
        subprocess.CalledProcessError: If the editor exits with a non-zero
            status.
    """
    import shlex  # local import: only needed to quote the temp file path

    parameters.pop('individual_type', None)
    parameters.pop('initialization_function', None)
    fd, filename = tempfile.mkstemp()
    os.close(fd)

    try:
        path = Path(filename)
        path.write_text(parameters_string(parameters))
        # The command goes through the shell so that an editor setting with
        # arguments keeps working; the path is quoted because temp
        # directories may contain spaces.
        quoted_filename = f'"{filename}"' if os.name == 'nt' else shlex.quote(filename)
        command = f'{text_editor()} {quoted_filename}'
        subprocess.run(command, shell=True, check=True)
        return parse_parameters(path.read_text(), algorithm)
    finally:
        # Best-effort cleanup: report a failed removal without masking the
        # outcome of the edit itself
        try:
            os.unlink(filename)
        except OSError as e:
            print('Error:', e, file=sys.stderr)
| 128 | + |
| 129 | + |
def main():
    """Run the niaarm command line interface.

    Loads the dataset, lets the user edit the chosen algorithm's parameters,
    runs the algorithm and optionally exports the mined rules and prints
    statistics about them.

    Returns:
        int: 0 on success or when only the help text was shown, 1 when the
        stopping criteria are missing, the edited parameters are invalid or
        any error occurs while running.
    """
    parser = get_parser()

    # This must happen before parse_args(): with the required options
    # missing, argparse exits with a usage error and the help text would
    # never be printed.
    if len(sys.argv) == 1:
        parser.print_help()
        return 0

    args = parser.parse_args()

    if args.max_evals == np.inf and args.max_iters == np.inf:
        print('--max-evals and/or --max-iters missing', file=sys.stderr)
        return 1

    try:
        dataset = Dataset(args.input_file)
        problem = NiaARM(dataset.dimension, dataset.features, dataset.transactions, args.alpha, args.beta, args.gamma,
                         args.delta, args.log)
        task = Task(problem, max_iters=args.max_iters, max_evals=args.max_evals,
                    optimization_type=OptimizationType.MAXIMIZATION)

        algorithm = get_algorithm(args.algorithm, seed=args.seed)
        params = algorithm.get_parameters()
        # Pass the instance: the function lookup derives the defining module
        # from the algorithm's class.
        new_params = edit_parameters(params, algorithm)
        if new_params is None:
            print('Invalid parameters', file=sys.stderr)
            return 1

        # Reject names that the algorithm did not offer for editing
        for param in new_params:
            if param not in params:
                print(f'Invalid parameter: {param}', file=sys.stderr)
                return 1

        algorithm.set_parameters(**new_params)

        algorithm.run(task)

        if args.output_file:
            problem.sort_rules()
            problem.export_rules(args.output_file)

        if args.show_stats:
            stats = Stats(problem.rules)
            print('\nSTATS:')
            print(f'Total rules: {stats.total_rules}')
            print(f'Average fitness: {stats.avg_fitness}')
            print(f'Average support: {stats.avg_support}')
            print(f'Average confidence: {stats.avg_confidence}')
            print(f'Average coverage: {stats.avg_coverage}')
            print(f'Average shrinkage: {stats.avg_shrinkage}')
            print(f'Average length of antecedent: {stats.avg_ant_len}')
            print(f'Average length of consequent: {stats.avg_con_len}')

    except Exception as e:
        # Top-level boundary: report the error and signal failure
        print('Error:', e, file=sys.stderr)
        return 1

    return 0
0 commit comments