Skip to content

Commit 05e22ee

Browse files
esythanroot
authored andcommitted
static ps online trainer
1 parent 4845406 commit 05e22ee

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

tools/static_ps_online_trainer.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,24 @@
3939
logger = logging.getLogger(__name__)
4040

4141

42+
def parse_args():
43+
parser = argparse.ArgumentParser("PaddleRec train script")
44+
parser.add_argument(
45+
'-m',
46+
'--config_yaml',
47+
type=str,
48+
required=True,
49+
help='config file path')
50+
args = parser.parse_args()
51+
args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
52+
yaml_helper = YamlHelper()
53+
config = yaml_helper.load_yaml(args.config_yaml)
54+
config["yaml_path"] = args.config_yaml
55+
config["config_abs_dir"] = args.abs_dir
56+
yaml_helper.print_yaml(config)
57+
return config
58+
59+
4260
class Main(object):
4361
def __init__(self, config):
4462
self.metrics = {}
@@ -461,19 +479,7 @@ def dataset_infer_loop(self, cur_dataset, day, pass_index,
461479

462480
if __name__ == "__main__":
463481
paddle.enable_static()
464-
parser = argparse.ArgumentParser("PaddleRec train script")
465-
parser.add_argument(
466-
'-m',
467-
'--config_yaml',
468-
type=str,
469-
required=True,
470-
help='config file path')
471-
args = parser.parse_args()
472-
args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
473-
yaml_helper = YamlHelper()
474-
config = yaml_helper.load_yaml(args.config_yaml)
475-
config["yaml_path"] = args.config_yaml
476-
config["config_abs_dir"] = args.abs_dir
482+
config = parse_args()
477483
# os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
478484
benchmark_main = Main(config)
479485
benchmark_main.run()

0 commit comments

Comments
 (0)