Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit c33e62e

Browse files
hutao965 and Hu
authored
[FEATURE]Horovod support for training transformer + add mirror data for wmt (PART 1) (#1284)
* set default shuffle=True for boundedbudgetsampler * fix * fix log condition * use horovod to train transformer * fix * add mirror wmt dataset * fix * rename wmt.txt to wmt.json and remove part of urls * fix * tuning params * use get_repo_url() * update average checkpoint cli * paste result of transformer large * fix * fix logging in train_transformer * fix * fix * fix * add transformer base config Co-authored-by: Hu <[email protected]>
1 parent ded0f99 commit c33e62e

File tree

8 files changed

+263
-68
lines changed

8 files changed

+263
-68
lines changed

scripts/datasets/machine_translation/prepare_wmt.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import functools
88
import tarfile
99
import gzip
10+
import json
1011
from xml.etree import ElementTree
1112
from gluonnlp.data.filtering import ProfanityFilter
1213
from gluonnlp.utils.misc import file_line_number, download, load_checksum_stats
13-
from gluonnlp.base import get_data_home_dir
14+
from gluonnlp.base import get_data_home_dir, get_repo_url
1415
from gluonnlp.registry import DATA_PARSER_REGISTRY, DATA_MAIN_REGISTRY
1516

1617
# The datasets are provided by WMT2014-WMT2019 and can be freely used for research purposes.
@@ -336,6 +337,15 @@
336337
}
337338
}
338339

340+
with open(os.path.join(_CURR_DIR, '..', 'url_checksums', 'mirror', 'wmt.json')) as wmt_mirror_map_f:
341+
_WMT_MIRROR_URL_MAP = json.load(wmt_mirror_map_f)
342+
343+
def _download_with_mirror(url, path, sha1_hash):
344+
return download(
345+
get_repo_url() + _WMT_MIRROR_URL_MAP[url] if url in _WMT_MIRROR_URL_MAP else url,
346+
path=path,
347+
sha1_hash=sha1_hash
348+
)
339349

340350
def _clean_space(s: str):
341351
"""Removes trailing and leading spaces and collapses multiple consecutive internal spaces to a single one.
@@ -626,7 +636,11 @@ def fetch_mono_dataset(selection: Union[str, List[str], List[List[str]]],
626636
save_path_l = [path] + selection + [matched_lang, original_filename]
627637
else:
628638
save_path_l = [path] + selection + [original_filename]
629-
download_fname = download(url, path=os.path.join(*save_path_l), sha1_hash=sha1_hash)
639+
download_fname = _download_with_mirror(
640+
url,
641+
path=os.path.join(*save_path_l),
642+
sha1_hash=sha1_hash
643+
)
630644
download_fname_l.append(download_fname)
631645
if len(download_fname_l) > 1:
632646
data_path = concatenate_files(download_fname_l)
@@ -792,7 +806,11 @@ def fetch_wmt_parallel_dataset(selection: Union[str, List[str], List[List[str]]]
792806
save_path_l = [path] + selection + [matched_pair, original_filename]
793807
else:
794808
save_path_l = [path] + selection + [original_filename]
795-
download_fname = download(url, path=os.path.join(*save_path_l), sha1_hash=sha1_hash)
809+
download_fname = _download_with_mirror(
810+
url,
811+
path=os.path.join(*save_path_l),
812+
sha1_hash=sha1_hash
813+
)
796814
download_fname_l.append(download_fname)
797815
if len(download_fname_l) > 1:
798816
data_path = concatenate_files(download_fname_l)

scripts/datasets/machine_translation/wmt2014_ende.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
3434
--tgt-corpus dev.raw.${TGT} \
3535
--min-num-words 1 \
3636
--max-num-words 100 \
37-
--max-ratio 1.5 \
3837
--src-save-path dev.tok.${SRC} \
3938
--tgt-save-path dev.tok.${TGT}
4039

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
{
2+
"http://www.statmt.org/europarl/v7/cs-en.tgz" : "datasets/third_party_mirror/cs-en-28bad3e096923694fb776b6cd6ba1079546a9e58.tgz",
3+
"http://www.statmt.org/europarl/v7/de-en.tgz" : "datasets/third_party_mirror/de-en-53bb5408d22977c89284bd755717e6bbb5b12bc5.tgz",
4+
"http://data.statmt.org/wmt18/translation-task/training-parallel-ep-v8.tgz" : "datasets/third_party_mirror/training-parallel-ep-v8-2f5c2c2c98b72921474a3f1837dc5b61dd44ba88.tgz",
5+
"http://www.statmt.org/europarl/v9/training/europarl-v9.cs-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.cs-en.tsv-e36a1bfe634379ec813b399b57a38093df2349ef.gz",
6+
"http://www.statmt.org/europarl/v9/training/europarl-v9.de-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.de-en.tsv-d553d0c8189642c1c7ae6ed3c265c847e432057c.gz",
7+
"http://www.statmt.org/europarl/v9/training/europarl-v9.fi-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.fi-en.tsv-c5d2f6aad04e88dda6ad11a110f4ca24150edca3.gz",
8+
"http://www.statmt.org/europarl/v9/training/europarl-v9.lt-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.lt-en.tsv-a6343d8fc158f44714ea7d01c0eb65b34640841d.gz",
9+
"http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz" : "datasets/third_party_mirror/training-parallel-commoncrawl-1c0ad85f0ebaf1d543acb009607205f5dae6627d.tgz",
10+
"http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz" : "datasets/third_party_mirror/training-parallel-nc-v9-c7ae7f50cd45c2f3014d78ddba25a4a8a851e27a.tgz",
11+
"http://www.statmt.org/wmt15/training-parallel-nc-v10.tgz" : "datasets/third_party_mirror/training-parallel-nc-v10-6c3c45b0f34d5e84a4d0b75a5edcca226ba7d6c2.tgz",
12+
"http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz" : "datasets/third_party_mirror/training-parallel-nc-v11-f51a1f03908e790d23d10001e92e09ce9555a790.tgz",
13+
"http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz" : "datasets/third_party_mirror/training-parallel-nc-v12-d98afc59e1d753485530b377ff65f1f891d3bced.tgz",
14+
"http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz" : "datasets/third_party_mirror/training-parallel-nc-v13-cbaa7834e58d36f228336e3caee6a9056029ff5d.tgz",
15+
"http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.de-en.tsv.gz" : "datasets/third_party_mirror/news-commentary-v14.de-en.tsv-c1fd94c7c9ff222968cbd45100bdd8dbeb5ab2aa.gz",
16+
"http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.en-zh.tsv.gz" : "datasets/third_party_mirror/news-commentary-v14.en-zh.tsv-4ca5c01deeba5425646d42f9598d081cd662908b.gz",
17+
"http://data.statmt.org/wikititles/v1/wikititles-v1.cs-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.cs-en.tsv-6e094d218dfd8f987fa1a18ea7b4cb127cfb1763.gz",
18+
"http://data.statmt.org/wikititles/v1/wikititles-v1.cs-pl.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.cs-pl.tsv-dc93d346d151bf73e4165d6db425b903fc21a5b0.gz",
19+
"http://data.statmt.org/wikititles/v1/wikititles-v1.de-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.de-en.tsv-e141c55c43a474e06c259c3fa401288b39cd4315.gz",
20+
"http://data.statmt.org/wikititles/v1/wikititles-v1.es-pt.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.es-pt.tsv-c3bd398d57471ee4ab33323393977b8d475a368c.gz",
21+
"http://data.statmt.org/wikititles/v1/wikititles-v1.fi-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.fi-en.tsv-5668b004567ca286d1aad9c2b45862a441d79667.gz",
22+
"http://data.statmt.org/wikititles/v1/wikititles-v1.gu-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.gu-en.tsv-95b9f15b6a86bfed6dc9bc91597368fd334f436e.gz",
23+
"http://data.statmt.org/wikititles/v1/wikititles-v1.hi-ne.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.hi-ne.tsv-6d63908950c72bc8cc69ca470deccff11354afc2.gz",
24+
"http://data.statmt.org/wikititles/v1/wikititles-v1.kk-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.kk-en.tsv-56ee1e450ef98fe92ea2116c3ce7acc7c7c42b39.gz",
25+
"http://data.statmt.org/wikititles/v1/wikititles-v1.lt-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.lt-en.tsv-b8829928686727165eec6c591d2875d12d7c0cfe.gz",
26+
"http://data.statmt.org/wikititles/v1/wikititles-v1.ru-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.ru-en.tsv-16d8d231fdf6347b4cc7834654adec80153ff7a4.gz",
27+
"http://data.statmt.org/wikititles/v1/wikititles-v1.zh-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.zh-en.tsv-5829097ff7dd61752f29fb306b04d790a1a1cfd7.gz",
28+
"https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.00" : "datasets/third_party_mirror/UNv1.0.en-ru-98c4e01e16070567d27da0ab4fe401f309dd3678.tar.gz.00",
29+
"https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.01" : "datasets/third_party_mirror/UNv1.0.en-ru-86c6013dc88f353d2d6e591928e7549060fcb949.tar.gz.01",
30+
"https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.02" : "datasets/third_party_mirror/UNv1.0.en-ru-bf6b18a33c8cafa6889fd463fa8a2850d8877d35.tar.gz.02",
31+
"https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.00" : "datasets/third_party_mirror/UNv1.0.en-zh-1bec5f10297512183e483fdd4984d207700657d1.tar.gz.00",
32+
"https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.01" : "datasets/third_party_mirror/UNv1.0.en-zh-15df2968bc69ef7662cf3029282bbb62cbf107b1.tar.gz.01",
33+
"http://data.statmt.org/wmt17/translation-task/rapid2016.tgz" : "datasets/third_party_mirror/rapid2016-8b173ce0bc77f2a1a57c8134143e3b5ae228a6e2.tgz",
34+
"http://data.statmt.org/wmt19/translation-task/dev.tgz" : "datasets/third_party_mirror/dev-451ce2cae815c8392212ccb3f54f5dcddb9b2b9e.tgz",
35+
"http://data.statmt.org/wmt19/translation-task/test.tgz" : "datasets/third_party_mirror/test-ce02a36fb2cd41abfa19d36eb8c8d50241ed3346.tgz",
36+
"http://data.statmt.org/news-crawl/de/news.2007.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2007.de.shuffled.deduped-9d746b9df345f764e6e615119113c70e3fb0858c.gz",
37+
"http://data.statmt.org/news-crawl/de/news.2008.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2008.de.shuffled.deduped-185a24e8833844486aee16cb5decf9a64da1c101.gz",
38+
"http://data.statmt.org/news-crawl/de/news.2009.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2009.de.shuffled.deduped-9f7645fc6467de88f4205d94f483194838bad8ce.gz",
39+
"http://data.statmt.org/news-crawl/de/news.2010.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2010.de.shuffled.deduped-f29b761194e9606f086102cfac12813931575818.gz",
40+
"http://data.statmt.org/news-crawl/de/news.2011.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2011.de.shuffled.deduped-613b16e7a1cb8559dd428525a4c3b42c8a4dc278.gz",
41+
"http://data.statmt.org/news-crawl/de/news.2012.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2012.de.shuffled.deduped-1bc419364ea3fe2f9ba4236947c012d4198d9282.gz",
42+
"http://data.statmt.org/news-crawl/de/news.2013.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2013.de.shuffled.deduped-3edd84a7f105907608371c81babc7a9078f40aac.gz",
43+
"http://data.statmt.org/news-crawl/de/news.2014.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2014.de.shuffled.deduped-1466c67b330c08ab5ab7d48e666c1d3a0bb4e479.gz",
44+
"http://data.statmt.org/news-crawl/de/news.2015.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2015.de.shuffled.deduped-2c6d5ec9f8fe51e9eb762be8ff7107c6116c00c4.gz",
45+
"http://data.statmt.org/news-crawl/de/news.2016.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2016.de.shuffled.deduped-e7d235c5d28e36dcf6382f1aa12c6ff37d4529bb.gz",
46+
"http://data.statmt.org/news-crawl/de/news.2017.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2017.de.shuffled.deduped-f70b4a67bc04c0fdc2ec955b737fa22681e8c038.gz",
47+
"http://data.statmt.org/news-crawl/de/news.2018.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2018.de.shuffled.deduped-43f8237de1e219276c0682255def13aa2cb80e35.gz"
48+
}

scripts/machine_translation/README.md

Lines changed: 96 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,38 +10,113 @@ You may first run the following command in [datasets/machine_translation](../dat
1010
bash wmt2014_ende.sh yttm
1111
```
1212

13-
Then, you can run the experiment, we use the
14-
"transformer_base" configuration.
13+
Then, you can run the experiment.
14+
For "transformer_base" configuration
1515

16+
# TODO
1617
```bash
1718
SUBWORD_ALGO=yttm
19+
SRC=en
20+
TGT=de
21+
datapath=../datasets/machine_translation
1822
python train_transformer.py \
19-
--train_src_corpus ../datasets/machine_translation/wmt2014_ende/train.tok.${SUBWORD_MODEL}.en \
20-
--train_tgt_corpus ../datasets/machine_translation/wmt2014_ende/train.tok.${SUBWORD_MODEL}.de \
21-
--dev_src_corpus ../datasets/machine_translation/wmt2014_ende/dev.tok.${SUBWORD_MODEL}.en \
22-
--dev_tgt_corpus ../datasets/machine_translation/wmt2014_ende/dev.tok.${SUBWORD_MODEL}.de \
23+
--train_src_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${SRC} \
24+
--train_tgt_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${TGT} \
25+
--dev_src_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${SRC} \
26+
--dev_tgt_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${TGT} \
27+
--src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
28+
--src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
29+
--tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
30+
--tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
31+
--save_dir transformer_base_wmt2014_en_de_${SUBWORD_ALGO} \
32+
--cfg transformer_base \
33+
--lr 0.002 \
34+
--batch_size 2700 \
35+
--num_averages 5 \
36+
--warmup_steps 4000 \
37+
--warmup_init_lr 0.0 \
38+
--seed 123 \
39+
--gpus 0,1,2,3
40+
```
41+
42+
Use the average_checkpoint cli to average the last 10 checkpoints
43+
44+
```bash
45+
gluon_average_checkpoint --checkpoints transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch*.params \
46+
--begin 21 \
47+
--end 30 \
48+
--save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params
49+
```
50+
51+
52+
Use the following command to inference/evaluate the Transformer model:
53+
54+
```bash
55+
SUBWORD_MODEL=yttm
56+
python evaluate_transformer.py \
57+
--param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/avg_21_30.params \
58+
--src_lang en \
59+
--tgt_lang de \
60+
--cfg transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \
61+
--src_tokenizer ${SUBWORD_MODEL} \
62+
--tgt_tokenizer ${SUBWORD_MODEL} \
2363
--src_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \
24-
--src_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \
2564
--tgt_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \
65+
--src_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \
2666
--tgt_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \
27-
--save_dir transformer_wmt2014_ende_${SUBWORD_MODEL} \
28-
--cfg transformer_base \
29-
--lr 0.002 \
67+
--src_corpus ../datasets/machine_translation/wmt2014_ende/test.raw.en \
68+
--tgt_corpus ../datasets/machine_translation/wmt2014_ende/test.raw.de
69+
```
70+
71+
72+
73+
For "transformer_wmt_en_de_big" configuration
74+
75+
```bash
76+
SUBWORD_ALGO=yttm
77+
SRC=en
78+
TGT=de
79+
datapath=../datasets/machine_translation
80+
python train_transformer.py \
81+
--train_src_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${SRC} \
82+
--train_tgt_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${TGT} \
83+
--dev_src_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${SRC} \
84+
--dev_tgt_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${TGT} \
85+
--src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
86+
--src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
87+
--tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
88+
--tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
89+
--save_dir transformer_big_wmt2014_en_de_${SUBWORD_ALGO} \
90+
--cfg transformer_wmt_en_de_big \
91+
--lr 0.001 \
92+
--sampler BoundedBudgetSampler \
93+
--max_num_tokens 3584 \
94+
--max_update 15000 \
3095
--warmup_steps 4000 \
3196
--warmup_init_lr 0.0 \
3297
--seed 123 \
3398
--gpus 0,1,2,3
3499
```
35100

101+
Use the average_checkpoint cli to average the last 10 checkpoints
102+
103+
```bash
104+
gluon_average_checkpoint --checkpoints transformer_big_wmt2014_en_de_${SUBWORD_ALGO}/update*.params \
105+
--begin 21 \
106+
--end 30 \
107+
--save-path transformer_big_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params
108+
```
109+
110+
36111
Use the following command to inference/evaluate the Transformer model:
37112

38113
```bash
39114
SUBWORD_MODEL=yttm
40115
python evaluate_transformer.py \
41-
--param_path transformer_wmt2014_ende_${SUBWORD_MODEL}/average.params \
116+
--param_path transformer_big_wmt2014_en_de_${SUBWORD_MODEL}/avg_21_30.params \
42117
--src_lang en \
43118
--tgt_lang de \
44-
--cfg transformer_wmt2014_ende_${SUBWORD_MODEL}/config.yml \
119+
--cfg transformer_big_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \
45120
--src_tokenizer ${SUBWORD_MODEL} \
46121
--tgt_tokenizer ${SUBWORD_MODEL} \
47122
--src_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \
@@ -59,6 +134,14 @@ Test BLEU score with 3 seeds (evaluated via sacre BLEU):
59134

60135
| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
61136
|---------------|------------|-------------|-------------|--------------|-------------|
62-
| yttm | | 26.63 | 26.73 | | - |
137+
| yttm | | - | - | - | - |
138+
| hf_bpe | | - | - | - | - |
139+
| spm | | - | - | - | - |
140+
141+
- transformer_wmt_en_de_big
142+
143+
| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
144+
|---------------|------------|-------------|-------------|--------------|-------------|
145+
| yttm | | 27.99 | - | - | - |
63146
| hf_bpe | | - | - | - | - |
64147
| spm | | - | - | - | - |

0 commit comments

Comments (0)