Skip to content

Commit edde6aa

Browse files
FrostML and guoshengCS authored
Static transformer support custom dataset (#757)
* fix static
* update

Co-authored-by: Guo Sheng <[email protected]>
1 parent cc7869e commit edde6aa

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

examples/machine_translation/transformer/static/predict.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ def parse_args():
4343
action="store_true",
4444
help="Whether to print logs on each cards and use benchmark vocab. Normally, not necessary to set --benchmark. "
4545
)
46+
parser.add_argument(
47+
"--test_file",
48+
default=None,
49+
type=str,
50+
help="The file for testing. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to process testing."
51+
)
4652
args = parser.parse_args()
4753
return args
4854

@@ -136,7 +142,8 @@ def do_predict(args):
136142
yaml_file = ARGS.config
137143
with open(yaml_file, 'rt') as f:
138144
args = AttrDict(yaml.safe_load(f))
139-
pprint(args)
140145
args.benchmark = ARGS.benchmark
146+
args.test_file = ARGS.test_file
147+
pprint(args)
141148

142149
do_predict(args)

examples/machine_translation/transformer/static/train.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,20 @@ def parse_args():
4545
default=None,
4646
type=int,
4747
help="The maximum iteration for training. ")
48+
parser.add_argument(
49+
"--train_file",
50+
nargs='+',
51+
default=None,
52+
type=str,
53+
help="The files for training, including [source language file, target language file]. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to train. "
54+
)
55+
parser.add_argument(
56+
"--dev_file",
57+
nargs='+',
58+
default=None,
59+
type=str,
60+
help="The files for validation, including [source language file, target language file]. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to do validation. "
61+
)
4862
args = parser.parse_args()
4963
return args
5064

@@ -274,10 +288,12 @@ def do_train(args):
274288
yaml_file = ARGS.config
275289
with open(yaml_file, 'rt') as f:
276290
args = AttrDict(yaml.safe_load(f))
277-
pprint(args)
278291
args.benchmark = ARGS.benchmark
279292
args.is_distributed = ARGS.distributed
280293
if ARGS.max_iter:
281294
args.max_iter = ARGS.max_iter
295+
args.train_file = ARGS.train_file
296+
args.dev_file = ARGS.dev_file
297+
pprint(args)
282298

283299
do_train(args)

0 commit comments

Comments
 (0)