Skip to content

Commit 1795f68

Browse files
author
gongel
committed
fix: add more args
1 parent 18b9ed6 commit 1795f68

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

examples/machine_translation/transformer/predict_beamsearch_v2.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@ def parse_args():
1616
default="./configs/transformer.base.yaml",
1717
type=str,
1818
help="Path of the config file. ")
19+
parser.add_argument(
20+
"--benchmark",
21+
action="store_true",
22+
help="Whether to print logs on each cards and use benchmark vocab. Normally, not necessary to set --benchmark. "
23+
)
24+
parser.add_argument(
25+
"--test_file",
26+
default=None,
27+
type=str,
28+
help="The file for testing. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to process testing."
29+
)
1930
args = parser.parse_args()
2031
return args
2132

@@ -111,6 +122,8 @@ def do_predict(args):
111122
yaml_file = ARGS.config
112123
with open(yaml_file, 'rt') as f:
113124
args = AttrDict(yaml.safe_load(f))
125+
args.benchmark = ARGS.benchmark
126+
args.test_file = ARGS.test_file
114127
pprint(args)
115128

116129
do_predict(args)

0 commit comments

Comments
 (0)