Skip to content

Commit 7a22907

Browse files
committed
fix start from command
1 parent 0e4931e commit 7a22907

File tree

4 files changed

+31
-18
lines changed

4 files changed

+31
-18
lines changed

tools/infer.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
def parse_args():
5151
parser = argparse.ArgumentParser(description='paddle-rec run')
5252
parser.add_argument("-m", "--config_yaml", type=str)
53-
parser.add_argument("--device", type=str)
53+
parser.add_argument("-o", "--opt", nargs='*', type=str)
5454
args = parser.parse_args()
5555
args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
5656
args.config_yaml = get_abs_model(args.config_yaml)
@@ -63,14 +63,15 @@ def main(args):
6363
config = load_yaml(args.config_yaml)
6464
dy_model_class = load_dy_model_class(args.abs_dir)
6565
config["config_abs_dir"] = args.abs_dir
66-
# tools.vars
67-
if args.device is None:
68-
use_gpu = config.get("runner.use_gpu", True)
69-
elif args.device == "gpu":
70-
use_gpu = True
71-
else:
72-
use_gpu = False
66+
# modify config from command
67+
if args.opt:
68+
for parameter in args.opt:
69+
parameter = parameter.strip()
70+
key, value = parameter.split("=")
71+
config[key] = value
7372

73+
# tools.vars
74+
use_gpu = config.get("runner.use_gpu", True)
7475
use_visual = config.get("runner.use_visual", False)
7576
test_data_dir = config.get("runner.test_data_dir", None)
7677
print_interval = config.get("runner.print_interval", None)

tools/static_infer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
def parse_args():
3636
parser = argparse.ArgumentParser("PaddleRec train static script")
3737
parser.add_argument("-m", "--config_yaml", type=str)
38-
parser.add_argument("--device", type=str)
38+
parser.add_argument("-o", "--opt", nargs='*', type=str)
3939
args = parser.parse_args()
4040
args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
4141
args.config_yaml = get_abs_model(args.config_yaml)
@@ -48,6 +48,12 @@ def main(args):
4848
# load config
4949
config = load_yaml(args.config_yaml)
5050
config["config_abs_dir"] = args.abs_dir
51+
# modify config from command
52+
if args.opt:
53+
for parameter in args.opt:
54+
parameter = parameter.strip()
55+
key, value = parameter.split("=")
56+
config[key] = value
5157
# load static model class
5258
static_model_class = load_static_model_class(config)
5359

tools/static_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
def parse_args():
3737
parser = argparse.ArgumentParser("PaddleRec train static script")
3838
parser.add_argument("-m", "--config_yaml", type=str)
39-
parser.add_argument("--device", type=str)
39+
parser.add_argument("-o", "--opt", nargs='*', type=str)
4040
args = parser.parse_args()
4141
args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
4242
args.config_yaml = get_abs_model(args.config_yaml)
@@ -50,6 +50,12 @@ def main(args):
5050
config = load_yaml(args.config_yaml)
5151
config["yaml_path"] = args.config_yaml
5252
config["config_abs_dir"] = args.abs_dir
53+
# modify config from command
54+
if args.opt:
55+
for parameter in args.opt:
56+
parameter = parameter.strip()
57+
key, value = parameter.split("=")
58+
config[key] = value
5359
# load static model class
5460
static_model_class = load_static_model_class(config)
5561

tools/trainer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
def parse_args():
5151
parser = argparse.ArgumentParser(description='paddle-rec run')
5252
parser.add_argument("-m", "--config_yaml", type=str)
53-
parser.add_argument("--device", type=str)
53+
parser.add_argument("-o", "--opt", nargs='*', type=str)
5454
args = parser.parse_args()
5555
args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
5656
args.config_yaml = get_abs_model(args.config_yaml)
@@ -63,15 +63,15 @@ def main(args):
6363
config = load_yaml(args.config_yaml)
6464
dy_model_class = load_dy_model_class(args.abs_dir)
6565
config["config_abs_dir"] = args.abs_dir
66+
# modify config from command
67+
if args.opt:
68+
for parameter in args.opt:
69+
parameter = parameter.strip()
70+
key, value = parameter.split("=")
71+
config[key] = value
6672

6773
# tools.vars
68-
if args.device is None:
69-
use_gpu = config.get("runner.use_gpu", True)
70-
elif args.device == "gpu":
71-
use_gpu = True
72-
else:
73-
use_gpu = False
74-
74+
use_gpu = config.get("runner.use_gpu", True)
7575
use_visual = config.get("runner.use_visual", False)
7676
train_data_dir = config.get("runner.train_data_dir", None)
7777
epochs = config.get("runner.epochs", None)

0 commit comments

Comments
 (0)