Skip to content

Commit 1314882

Browse files
committed
Add -i, --input_file optional flag for overriding conf.yaml location
1 parent d5bbebb commit 1314882

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

examples/mpi_learn.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,36 @@
11
import plasma.global_vars as g
22
g.init_MPI()
3+
4+
import os.path
5+
6+
7+
# TODO(KGF): replace this workaround for the "from plasma.conf import conf"
8+
def is_valid_file(parser, arg):
9+
if not (os.path.exists(arg) and os.path.isfile(arg)):
10+
parser.error("The file %s does not exist!" % arg)
11+
else:
12+
return arg
13+
14+
import argparse
15+
parser = argparse.ArgumentParser(prog='mpi_learn', description='FusionDL TensorFlow 1.x + mpi4py')
16+
parser.add_argument("--input_file", "-i", # type=str,
17+
required=False, dest="conf_file",
18+
help="input YAML file for configuration", metavar="YAML_FILE",
19+
type=lambda x: is_valid_file(parser, x))
20+
args = parser.parse_args()
21+
22+
g.conf_file = args.conf_file
23+
24+
25+
from plasma.conf import conf
26+
327
from plasma.models.mpi_runner import (
428
mpi_train, mpi_make_predictions_and_evaluate
529
)
30+
631
from plasma.preprocessor.preprocess import guarantee_preprocessed
732
from plasma.models.loader import Loader
8-
from plasma.conf import conf
33+
934
'''
1035
#########################################################
1136
This file trains a deep learning model to predict

plasma/conf.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
from plasma.conf_parser import parameters
22
import os
33
import errno
4+
import plasma.global_vars as g
5+
46

57
# TODO(KGF): this conf.py feels like an unnecessary level of indirection
6-
if os.path.exists(os.path.join(os.path.abspath(os.path.dirname(__file__)),
8+
if g.conf_file is not None:
9+
g.print_unique(f"Loading configuration from {g.conf_file}")
10+
conf = parameters(g.conf_file)
11+
elif os.path.exists(os.path.join(os.path.abspath(os.path.dirname(__file__)),
712
'../examples/conf.yaml')):
813
conf = parameters(os.path.join(os.path.abspath(os.path.dirname(__file__)),
914
'../examples/conf.yaml'))

plasma/global_vars.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# TODO(KGF): remove this (and all?) references to Keras backend
1111
backend = ''
1212
tf_ver = None
13+
conf_file = None
1314

1415

1516
def init_MPI():

0 commit comments

Comments
 (0)