Skip to content

Commit 62a86ce

Browse files
authored
Merge pull request #23 from RTIInternational/backtranslation
Add a data augmentation method based on backtranslation along with a benchmark.
2 parents 981a614 + 78aa36b commit 62a86ce

File tree

13 files changed

+573
-9
lines changed

13 files changed

+573
-9
lines changed

benchmark/BENCHMARK_SPECS.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,13 @@
644644
params: {}
645645
- augment_name: "BERTMaskedLM"
646646
params: {}
647+
- augment_name: "MarianMT"
648+
params:
649+
# Need as many languages as the largest multiplier used above
650+
# Top 10 available in MarianMT by descending popularity on Wikipedia as a
651+
# rough proxy for best-supported languages
652+
# https://en.wikipedia.org/wiki/List_of_Wikipedias
653+
target_languages: ["french", "german", "japanese", "russian", "italian", "portugese", "dutch", "indonesian", "ukrainian", "swedish"]
647654

648655
- scenario: "document_windowing"
649656
params:
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Results: MarianMT
2+
| | percent | multiplier | Weighted F1 Score | Weighted Precision Score | Weighted Recall Score | Accuracy |
3+
|---:|----------:|-------------:|--------------------:|---------------------------:|------------------------:|-----------:|
4+
| 0 | 0.005 | 0 | 0.623019 | 0.651554 | 0.63396 | 0.63396 |
5+
| 1 | 0.005 | 1 | 0.333422 | 0.75001 | 0.50004 | 0.50004 |
6+
| 2 | 0.005 | 5 | 0.658066 | 0.658107 | 0.65808 | 0.65808 |
7+
| 3 | 0.005 | 10 | 0.646764 | 0.67772 | 0.65704 | 0.65704 |
8+
| 4 | 0.05 | 0 | 0.798386 | 0.798483 | 0.7984 | 0.7984 |
9+
| 5 | 0.05 | 1 | 0.794299 | 0.794979 | 0.7944 | 0.7944 |
10+
| 6 | 0.05 | 5 | 0.808468 | 0.808556 | 0.80848 | 0.80848 |
11+
| 7 | 0.05 | 10 | 0.807779 | 0.807932 | 0.8078 | 0.8078 |
12+
| 8 | 0.33 | 0 | 0.85779 | 0.857904 | 0.8578 | 0.8578 |
13+
| 9 | 0.33 | 1 | 0.853155 | 0.85363 | 0.8532 | 0.8532 |
14+
| 10 | 0.33 | 5 | 0.856863 | 0.857054 | 0.85688 | 0.85688 |
15+
| 11 | 0.75 | 0 | 0.876512 | 0.876623 | 0.87652 | 0.87652 |
16+
| 12 | 0.75 | 1 | 0.871752 | 0.872361 | 0.8718 | 0.8718 |
17+
| 13 | 0.75 | 5 | 0.871178 | 0.871456 | 0.8712 | 0.8712 |
18+
19+
![Results](MarianMT/plot.png)
20+
---
47.5 KB
Loading
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"percent_multipliers": [[0.005, 0], [0.005, 1], [0.005, 5], [0.005, 10], [0.05, 0], [0.05, 1], [0.05, 5], [0.05, 10], [0.33, 0], [0.33, 1], [0.33, 5], [0.75, 0], [0.75, 1], [0.75, 5]], "model_name": "FastText", "param_grid": {"word_ngrams": [1], "autotune_duration": [120]}, "preprocess_func": "fasttext_preprocess", "augment_probability": 0.15, "augment_name": "MarianMT", "params": {"target_languages": ["french", "german", "japanese", "russian", "italian", "portugese", "dutch", "indonesian", "ukrainian", "swedish"]}}

benchmark/benchmark_output/data_augmentation/data_augmentation.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,23 @@
5858

5959
![Results](BERTMaskedLM/plot.png)
6060
---
61+
# Results: MarianMT
62+
| | percent | multiplier | Weighted F1 Score | Weighted Precision Score | Weighted Recall Score | Accuracy |
63+
|---:|----------:|-------------:|--------------------:|---------------------------:|------------------------:|-----------:|
64+
| 0 | 0.005 | 0 | 0.623019 | 0.651554 | 0.63396 | 0.63396 |
65+
| 1 | 0.005 | 1 | 0.333422 | 0.75001 | 0.50004 | 0.50004 |
66+
| 2 | 0.005 | 5 | 0.658066 | 0.658107 | 0.65808 | 0.65808 |
67+
| 3 | 0.005 | 10 | 0.646764 | 0.67772 | 0.65704 | 0.65704 |
68+
| 4 | 0.05 | 0 | 0.798386 | 0.798483 | 0.7984 | 0.7984 |
69+
| 5 | 0.05 | 1 | 0.794299 | 0.794979 | 0.7944 | 0.7944 |
70+
| 6 | 0.05 | 5 | 0.808468 | 0.808556 | 0.80848 | 0.80848 |
71+
| 7 | 0.05 | 10 | 0.807779 | 0.807932 | 0.8078 | 0.8078 |
72+
| 8 | 0.33 | 0 | 0.85779 | 0.857904 | 0.8578 | 0.8578 |
73+
| 9 | 0.33 | 1 | 0.853155 | 0.85363 | 0.8532 | 0.8532 |
74+
| 10 | 0.33 | 5 | 0.856863 | 0.857054 | 0.85688 | 0.85688 |
75+
| 11 | 0.75 | 0 | 0.876512 | 0.876623 | 0.87652 | 0.87652 |
76+
| 12 | 0.75 | 1 | 0.871752 | 0.872361 | 0.8718 | 0.8718 |
77+
| 13 | 0.75 | 5 | 0.871178 | 0.871456 | 0.8712 | 0.8712 |
78+
79+
![Results](MarianMT/plot.png)
80+
---

benchmark/run_benchmarks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import logging
2+
import os
23
from pathlib import Path
34
from typing import Any, Dict, List
45

56
import click
67
import yaml
78

89
import gobbli
10+
from benchmark_util import BENCHMARK_DATA_DIR
911
from scenario import (
1012
ClassImbalanceScenario,
1113
DataAugmentationScenario,
@@ -104,6 +106,10 @@ def run(
104106
debug: bool,
105107
raise_exceptions: bool,
106108
):
109+
# Make sure all models run outside of experiments create their data under the
110+
# assigned benchmark directory
111+
os.environ["GOBBLI_DIR"] = str(BENCHMARK_DATA_DIR)
112+
107113
logging.basicConfig(
108114
level=log_level, format="[%(asctime)s] %(levelname)s - %(name)s: %(message)s"
109115
)

benchmark/scenario.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import gobbli.model
2121
from benchmark_util import (
22+
BENCHMARK_DATA_DIR,
2223
PREPROCESS_FUNCS,
2324
StdoutCatcher,
2425
assert_param_required,
@@ -41,11 +42,27 @@
4142
make_document_windows,
4243
pool_document_windows,
4344
)
45+
from gobbli.model.base import BaseModel
4446
from gobbli.util import TokenizeMethod, assert_in, assert_type, pred_prob_to_pred_label
4547

4648
LOGGER = logging.getLogger(__name__)
4749

4850

51+
def get_model_run_params() -> Dict[str, Any]:
52+
"""
53+
See also :func:`run_benchmark_experiment`, since there's some duplication between there
54+
and here (that function initializes its own models indirectly via
55+
:class:`gobbli.experiment.ClassificationExperiment`
56+
57+
Returns:
58+
Parameters that should be passed to any gobbli model used as part of benchmarks.
59+
"""
60+
return {
61+
"use_gpu": os.getenv("GOBBLI_USE_GPU") is not None,
62+
"nvidia_visible_devices": os.getenv("NVIDIA_VISIBLE_DEVICES", ""),
63+
}
64+
65+
4966
class BaseRun(ABC):
5067
"""
5168
Base class for a single run within a benchmark scenario.
@@ -426,13 +443,10 @@ def _do_run(self, run: ModelEmbeddingRun, run_output_dir: Path) -> str:
426443
stdout_catcher = StdoutCatcher()
427444
with stdout_catcher:
428445
# Construct the dict of kwargs up-front so each run can override the "use_gpu"
429-
# option if necessary -- ex. for models like spaCy which have trouble controlling
430-
# memory usage on the GPU and don't gain much benefit from it
431-
model_kwargs = {
432-
"use_gpu": os.getenv("GOBBLI_USE_GPU") is not None,
433-
"nvidia_visible_devices": os.getenv("NVIDIA_VISIBLE_DEVICES", ""),
434-
**run.model_params,
435-
}
446+
# option if necessary using its model params -- ex. for models like spaCy
447+
# which have trouble controlling memory usage on the GPU and don't gain
448+
# much benefit from it
449+
model_kwargs = {**get_model_run_params(), **run.model_params}
436450
model = model_cls(**model_kwargs)
437451
model.build()
438452

@@ -778,7 +792,20 @@ def _do_run(self, run: AugmentRun, run_output_dir: Path) -> str:
778792

779793
assert_valid_augment(run.augment_name)
780794
augment_cls = getattr(gobbli.augment, run.augment_name)
781-
augment_obj = augment_cls(**run.params)
795+
796+
model_run_params: Dict[str, Any] = {}
797+
if issubclass(augment_cls, BaseModel):
798+
# If the augment method is also a gobbli model (and will be mounting files back-
799+
# and-forth with Docker), we need to make sure it has the proper params
800+
# applied ex. to store data in the correct place and use GPU(s)
801+
model_run_params = get_model_run_params()
802+
803+
augment_obj = augment_cls(**run.params, **model_run_params)
804+
805+
# Some augmentation methods are also models, which need to be built
806+
# beforehand
807+
if isinstance(augment_obj, BaseModel):
808+
augment_obj.build()
782809

783810
all_results = []
784811

gobbli/augment/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from gobbli.augment.bert import BERTMaskedLM
2+
from gobbli.augment.marian import MarianMT
23
from gobbli.augment.word2vec import Word2Vec
34
from gobbli.augment.wordnet import WordNet
45

5-
__all__ = ["BERTMaskedLM", "Word2Vec", "WordNet"]
6+
__all__ = ["BERTMaskedLM", "Word2Vec", "WordNet", "MarianMT"]

gobbli/augment/marian/Dockerfile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
FROM pytorch/pytorch:1.3-cuda10.1-cudnn7-runtime
2+
3+
RUN pip install transformers==2.9.1
4+
5+
COPY ./src /code/marian
6+
WORKDIR /code/marian

gobbli/augment/marian/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .model import MarianMT
2+
3+
__all__ = ["MarianMT"]

0 commit comments

Comments (0)