Skip to content

Commit e53de3d

Browse files
authored
Improve entrypoint usage (#31)
* Switch to custom Logger * Plug run.py into entrypoint * Fix or disable tests
1 parent 1146cc9 commit e53de3d

File tree

16 files changed

+140
-142
lines changed

16 files changed

+140
-142
lines changed

src/o2tuner/__init__.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,21 @@
11
"""
2-
o2tuner module
2+
o2tuner main module
33
"""
44

55
import sys
66
from pkg_resources import require
77
from o2tuner.argumentparser import O2TunerArgumentParser
8-
from o2tuner.tuner import O2Tuner, O2TunerError
9-
from o2tuner.backends import OptunaHandler
8+
from o2tuner.tuner import O2TunerError
109
from o2tuner.log import Log
11-
12-
13-
def objective(trial):
14-
x_var = trial.suggest_float("x", -10, 10)
15-
return (x_var-2)**2
10+
from o2tuner.run import run
1611

1712

1813
LOG = Log()
1914

2015

2116
def entrypoint():
2217
arg_parser = O2TunerArgumentParser()
23-
arg_parser.gen_config_help(O2Tuner.get_default_conf())
18+
# arg_parser.gen_config_help(O2Tuner.get_default_conf())
2419
args = arg_parser.parse_args()
2520

2621
LOG.set_quiet(args.quiet)
@@ -40,20 +35,12 @@ def process_actions(args):
4035
print(f"{__package__} {ver}")
4136
return
4237

43-
optuna_handler = OptunaHandler()
44-
optuna_handler.set_objective(objective)
45-
46-
# Create and run the tuner
47-
tuner = O2Tuner(optuna_handler)
48-
4938
if args.action in ["run"]:
50-
process_run(tuner, args)
39+
process_run(args)
5140
else:
5241
assert False, "invalid action"
5342

5443

55-
def process_run(o2_tuner, args):
44+
def process_run(args):
5645
if args.action == "run":
57-
LOG.info("Running ...")
58-
o2_tuner.init(n_trials=50)
59-
o2_tuner.run()
46+
run(args)

src/o2tuner/argumentparser.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,7 @@
33
"""
44

55

6-
import os.path
76
import argparse
8-
from collections import namedtuple
9-
import yaml
10-
11-
O2TunerArg = namedtuple("O2TunerArg", "option config descr")
127

138

149
class O2TunerArgumentParser(argparse.ArgumentParser):
@@ -19,35 +14,39 @@ class O2TunerArgumentParser(argparse.ArgumentParser):
1914
"""
2015

2116
def __init__(self):
22-
self.args_normal = []
17+
# self.args_normal = []
2318
super().__init__(formatter_class=argparse.RawTextHelpFormatter)
2419
super().add_argument("-v", "--version", dest="version", default=False, action="store_true",
25-
help="Print current alidock version on stdout")
20+
help="Print current o2tuner version on stdout")
21+
super().add_argument("-w", "--workdir", dest="work_dir", help="Working directory to run in",
22+
required=True)
23+
super().add_argument("-c", "--config", help="your configuration", required=True)
2624
super().add_argument("-q", "--quiet", dest="quiet", default=False,
2725
action="store_true", help="Do not print any message")
2826
super().add_argument("-d", "--debug", dest="debug", default=None,
29-
action="store_true", help="Increase verbosity")
27+
action="store_true", help="Increase verbosity level")
28+
super().add_argument("-s", "--stages", nargs="*", help="Run until specified stage")
3029
super().add_argument("action", default="run",
31-
nargs="?", choices=["run"], help="Run optimiser")
32-
33-
def gen_config_help(self, default_conf):
34-
conf_file = os.path.join(os.getcwd(), ".o2tuner-config.yaml")
35-
epilog = f"It is possible to specify the most frequently used options in a YAML " \
36-
f"configuration file in your working directory\n" \
37-
f"Current expected path is: {conf_file}\n" \
38-
f"The following options (along with their default values) can be specified " \
39-
f"(please include `---` as first line):\n---\n"
40-
yaml_lines = {}
41-
longest = 0
42-
for opt in self.args_normal:
43-
if opt.config:
44-
assert opt.config in default_conf, f"option {opt.config} expected in default conf"
45-
optd = {opt.config: default_conf[opt.config]}
46-
yaml_lines[opt.option] = yaml.dump(
47-
optd, default_flow_style=False).rstrip()
48-
longest = max(longest, len(yaml_lines[opt.option]))
49-
fmt = f"%%-{longest}s # same as option %%s\n"
50-
for y_line in yaml_lines.items():
51-
epilog += fmt % (yaml_lines[y_line], y_line)
52-
53-
self.epilog = epilog
30+
nargs="?", choices=["run", "init"], help="Actions to be performed")
31+
32+
# def gen_config_help(self, default_conf):
33+
# conf_file = os.path.join(os.getcwd(), ".o2tuner-config.yaml")
34+
# epilog = f"It is possible to specify the most frequently used options in a YAML " \
35+
# f"configuration file in your working directory\n" \
36+
# f"Current expected path is: {conf_file}\n" \
37+
# f"The following options (along with their default values) can be specified " \
38+
# f"(please include `---` as first line):\n---\n"
39+
# yaml_lines = {}
40+
# longest = 0
41+
# for opt in self.args_normal:
42+
# if opt.config:
43+
# assert opt.config in default_conf, f"option {opt.config} expected in default conf"
44+
# optd = {opt.config: default_conf[opt.config]}
45+
# yaml_lines[opt.option] = yaml.dump(
46+
# optd, default_flow_style=False).rstrip()
47+
# longest = max(longest, len(yaml_lines[opt.option]))
48+
# fmt = f"%%-{longest}s # same as option %%s\n"
49+
# for y_line in yaml_lines.items():
50+
# epilog += fmt % (yaml_lines[y_line], y_line)
51+
52+
# self.epilog = epilog

src/o2tuner/backends.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
from o2tuner.io import make_dir, exists_file
1414
from o2tuner.utils import annotate_trial
15+
from o2tuner.log import Log
16+
17+
LOG = Log()
1518

1619

1720
def make_trial_directory(trial):
@@ -20,10 +23,10 @@ def make_trial_directory(trial):
2023
"""
2124
user_attributes = trial.user_attrs
2225
if "cwd" in user_attributes:
23-
print(f"ERROR: This trial has already a directory attached: {user_attributes['cwd']}")
26+
LOG.error(f"This trial has already a directory attached: {user_attributes['cwd']}")
2427
sys.exit(1)
2528
if "cwd" not in trial.study.user_attrs:
26-
print("ERROR: This optimisation was not configured to run inside a directory. Please define a working directory.")
29+
LOG.error("This optimisation was not configured to run inside a directory. Please define a working directory.")
2730
sys.exit(1)
2831
top_dir = trial.study.user_attrs["cwd"]
2932
timestamp = str(int(time() * 1000000))
@@ -44,14 +47,14 @@ def load_or_create_study(study_name=None, storage=None, sampler=None, workdir=No
4447
# we force a name to be given by the user for those cases
4548
try:
4649
study = optuna.load_study(study_name=study_name, storage=storage, sampler=sampler)
47-
print(f"Loading existing study {study_name} from storage {storage}")
50+
LOG.info(f"Loading existing study {study_name} from storage {storage}")
4851
except KeyError:
4952
study = optuna.create_study(study_name=study_name, storage=storage, sampler=sampler)
50-
print(f"Creating new study {study_name} at storage {storage}")
53+
LOG.info(f"Creating new study {study_name} at storage {storage}")
5154
except ImportError as exc:
5255
# Probably cannot import MySQL stuff
53-
print("Probably cannot import what is needed for database access. Will try to attempt a serial run.")
54-
print(exc)
56+
LOG.info("Probably cannot import what is needed for database access. Will try to attempt a serial run.")
57+
LOG.info(exc)
5558
else:
5659
return True, study
5760
# This is a "one-time" in-memory study so we don't care so much for the name honestly, could be None
@@ -61,7 +64,7 @@ def load_or_create_study(study_name=None, storage=None, sampler=None, workdir=No
6164
file_name = join(workdir, f"{study_name}.pkl")
6265
if exists_file(file_name):
6366
with open(file_name, "rb") as save_file:
64-
print(f"Loading existing study {study_name} from file {file_name}")
67+
LOG.info(f"Loading existing study {study_name} from file {file_name}")
6568
return False, pickle.load(save_file)
6669

6770
return False, optuna.create_study(study_name=study_name, sampler=sampler)
@@ -130,7 +133,7 @@ def initialise(self, n_trials=100):
130133

131134
def optimise(self):
132135
if not self._n_trials or not self._objective:
133-
print("ERROR: Not initialised: Number of trials and objective function need to be set")
136+
LOG.error("Not initialised: Number of trials and objective function need to be set")
134137
return
135138
self._study.optimize(self.objective_wrapper, n_trials=self._n_trials)
136139

@@ -145,7 +148,7 @@ def set_objective(self, objective):
145148
if hasattr(objective, "needs_cwd"):
146149
self._needs_cwd_per_trial = True
147150
if n_params > 2 or not n_params:
148-
print("Invalid signature of objective funtion. Need either 1 argument (only trial object) or 2 arguments (trial object and user_config)")
151+
LOG.error("Invalid signature of objective funtion. Need either 1 argument (only trial obj) or 2 arguments (trial object + user_config)")
149152
sys.exit(1)
150153
if n_params == 1:
151154
self._objective = objective

src/o2tuner/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""
2-
Configuration functinality and parsing
2+
Configuration functionality and parsing
33
"""
4+
import dataclasses
45
from os.path import join, basename, abspath
56
from glob import glob
67

78
from o2tuner.io import make_dir, parse_yaml, exists_dir
89

910

10-
class WorkDir: # pylint: disable=too-few-public-methods
11+
@dataclasses.dataclass
12+
class WorkDir:
1113
"""
1214
Use this object to set the working directory globally
1315
"""

src/o2tuner/graph.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
"""
44

55
import sys
6+
from o2tuner.log import Log
7+
8+
LOG = Log()
69

710

811
class GraphDAG: # pylint: disable=too-few-public-methods
@@ -45,7 +48,7 @@ def __init__(self, n_nodes, edges):
4548
self.graph[origin][target] = True
4649

4750
if origin > n_nodes or target > n_nodes or origin < 0 or target < 0:
48-
print(f"ERROR: Found edge ({origin}, {target}) but nodes must be >= 0 and < {n_nodes}")
51+
LOG.error(f"Found edge ({origin}, {target}) but nodes must be >= 0 and < {n_nodes}")
4952
sys.exit(1)
5053
self.from_nodes[target].append(origin)
5154
self.to_nodes[origin].append(target)
@@ -66,14 +69,14 @@ def make_topology(self):
6669
in_degree = self.in_degree.copy()
6770
queue = [i for i, v in enumerate(in_degree) if not v]
6871
if not queue:
69-
print("ERROR: There is no source node in the topology")
72+
LOG.error("There is no source node in the topology")
7073
return False
7174

7275
counter = 0
7376
while queue:
7477
current = queue.pop(0)
7578
if current >= self.n_nodes or current < 0:
76-
print(f"ERROR: Found an edge which node {current} but nodes are only valid from 0 to {self.n_nodes - 1}.")
79+
LOG.error(f"Found an edge which node {current} but nodes are only valid from 0 to {self.n_nodes - 1}.")
7780
return False
7881
self.topology.append(current)
7982
for target in self.to_nodes[current]:
@@ -82,7 +85,7 @@ def make_topology(self):
8285
queue.append(target)
8386
counter += 1
8487
if counter != self.n_nodes:
85-
print("ERROR: There is at least one cyclic dependency.")
88+
LOG.error("There is at least one cyclic dependency.")
8689
return False
8790
return True
8891

src/o2tuner/inspector.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from o2tuner.io import parse_yaml
2020
from o2tuner.backends import load_or_create_study
2121
from o2tuner.sampler import construct_sampler
22+
from o2tuner.log import Log
23+
24+
LOG = Log()
2225

2326

2427
class O2TunerInspector:
@@ -39,7 +42,7 @@ def load(self, opt_config=None, opt_work_dir=None, opt_user_config=None):
3942
Loading wrapper
4043
"""
4144
if not opt_config and not opt_work_dir:
42-
print("WARNING: Nothing to load, no configuration given")
45+
LOG.warning("Nothing to load, no configuration given")
4346
return False
4447
if isinstance(opt_config, str):
4548
opt_config = parse_yaml(opt_config)
@@ -63,7 +66,7 @@ def get_annotation_per_trial(self, key, accept_missing_annotation=True):
6366
for trial in self._study.trials:
6467
user_attrs = trial.user_attrs
6568
if key not in user_attrs:
66-
print(f"ERROR: Key {key} not in trial number {trial.number}.")
69+
LOG.error(f"Key {key} not in trial number {trial.number}.")
6770
sys.exit(1)
6871
ret_list.append(user_attrs[key])
6972
return ret_list
@@ -89,7 +92,7 @@ def plot_importance(self, *, n_most_important=50, map_params=None):
8992

9093
However, add some functionality we would like to have here
9194
"""
92-
print("Plotting importance")
95+
LOG.info("Plotting importance")
9396
param_names, importance_values = self.get_most_important(n_most_important)
9497

9598
if map_params:
@@ -108,7 +111,7 @@ def plot_parallel_coordinates(self, *, n_most_important=20, map_params=None):
108111
"""
109112
Plot parallel coordinates. Each horizontal line represents a trial, each vertical line a parameter
110113
"""
111-
print("Plotting parallel coordinates")
114+
LOG.info("Plotting parallel coordinates")
112115
params, _ = self.get_most_important(n_most_important)
113116

114117
curves = [[] for _ in self._study.trials]
@@ -159,7 +162,7 @@ def plot_parallel_coordinates(self, *, n_most_important=20, map_params=None):
159162
return figure, axes
160163

161164
def plot_slices(self, *, n_most_important=21, map_params=None):
162-
print("Plotting slices")
165+
LOG.info("Plotting slices")
163166
params, _ = self.get_most_important(n_most_important)
164167

165168
n_rows = ceil(sqrt(len(params)))
@@ -204,7 +207,7 @@ def plot_correlations(self, *, n_most_important=20, map_params=None):
204207
"""
205208
Plot correlation among parameters
206209
"""
207-
print("Plotting parameter correlations")
210+
LOG.info("Plotting parameter correlations")
208211
params, _ = self.get_most_important(n_most_important)
209212
params_labels = params
210213
if map_params:
@@ -247,7 +250,7 @@ def plot_pairwise_scatter(self, *, n_most_important=20, map_params=None):
247250
"""
248251
Plot correlation among parameters
249252
"""
250-
print("Plotting pair-wise scatter")
253+
LOG.info("Plotting pair-wise scatter")
251254
params, _ = self.get_most_important(n_most_important)
252255
params_labels = params
253256
if map_params:

src/o2tuner/io.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
import json
1010
import yaml
1111

12+
from o2tuner.log import Log
13+
14+
LOG = Log()
15+
1216

13-
############################
14-
# STANDARD FILE SYSTEM I/O #
15-
############################
1617
def exists_file(path):
1718
"""wrapper around python's os.path.isfile
1819

@@ -34,11 +35,11 @@ def make_dir(path):
3435
"""
3536
if exists(path):
3637
if exists_file(path):
37-
# if exists and if that is actually a file instead of a directory, fail here...
38-
print(f"Attempted to create directory {path}. However, a file seems to exist there, quitting")
38+
# if it exists and if that is actually a file instead of a directory, fail here...
39+
LOG.error(f"Attempted to create directory {path}. However, a file seems to exist there, quitting")
3940
sys.exit(1)
4041
# ...otherwise just warn.
41-
print(f"WARNING: The directory {path} already exists, not overwriting")
42+
LOG.warning(f"The directory {path} already exists, not overwriting")
4243
return
4344
# make the whole path structure
4445
makedirs(path)
@@ -105,7 +106,7 @@ def parse_yaml(path):
105106
with open(path, encoding="utf8") as in_file:
106107
return yaml.safe_load(in_file)
107108
except (OSError, IOError, yaml.YAMLError) as exc:
108-
print(f"ERROR: Cannot parse YAML from {path} due to\n{exc}")
109+
LOG.error(f"ERROR: Cannot parse YAML from {path} due to\n{exc}")
109110
sys.exit(1)
110111

111112

@@ -121,7 +122,7 @@ def dump_yaml(to_yaml, path, *, no_refs=False):
121122
else:
122123
yaml.safe_dump(to_yaml, out_file)
123124
except (OSError, IOError, yaml.YAMLError) as eexc:
124-
print(f"ERROR: Cannot write YAML to {path} due to\n{eexc}")
125+
LOG.error(f"ERROR: Cannot write YAML to {path} due to\n{eexc}")
125126
sys.exit(1)
126127

127128

@@ -134,7 +135,7 @@ def parse_json(filepath):
134135
"""
135136
filepath = expanduser(filepath)
136137
if not exists_file(filepath):
137-
print(f"ERROR: JSON file {filepath} does not exist.")
138+
LOG.error(f"ERROR: JSON file {filepath} does not exist.")
138139
sys.exit(1)
139140
with open(filepath, "r", encoding="utf8") as config_file:
140141
return json.load(config_file)

0 commit comments

Comments
 (0)