Skip to content
This repository was archived by the owner on Nov 8, 2022. It is now read-only.

Commit 0eec65d

Browse files
author
Peter Izsak
authored
Refactor GLUE data loaders (#138)
Refactor GLUE data loaders and misc utils.
1 parent 4541ae5 commit 0eec65d

File tree

4 files changed

+68
-30
lines changed

4 files changed

+68
-30
lines changed

nlp_architect/data/glue_tasks.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
import logging
1717
import os
1818

19+
from sklearn.metrics import matthews_corrcoef
20+
1921
from nlp_architect.data.sequence_classification import SequenceClsInputExample
2022
from nlp_architect.data.utils import DataProcessor, Task, read_tsv
23+
from nlp_architect.utils.metrics import acc_and_f1, pearson_and_spearman, simple_accuracy
2124

2225
logger = logging.getLogger(__name__)
2326

@@ -539,6 +542,31 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
539542
}
540543

541544

545+
# GLUE task metrics
546+
def get_metric_fn(task_name):
547+
if task_name == "cola":
548+
return lambda p, l: {"mcc": matthews_corrcoef(p, l)}
549+
if task_name == "sst-2":
550+
return lambda p, l: {"acc": simple_accuracy(p, l)}
551+
if task_name == "mrpc":
552+
return acc_and_f1
553+
if task_name == "sts-b":
554+
return pearson_and_spearman
555+
if task_name == "qqp":
556+
return acc_and_f1
557+
if task_name == "mnli":
558+
return lambda p, l: {"acc": simple_accuracy(p, l)}
559+
if task_name == "mnli-mm":
560+
return lambda p, l: {"acc": simple_accuracy(p, l)}
561+
if task_name == "qnli":
562+
return lambda p, l: {"acc": simple_accuracy(p, l)}
563+
if task_name == "rte":
564+
return lambda p, l: {"acc": simple_accuracy(p, l)}
565+
if task_name == "wnli":
566+
return lambda p, l: {"acc": simple_accuracy(p, l)}
567+
raise KeyError(task_name)
568+
569+
542570
def get_glue_task(task_name: str, data_dir: str = None):
543571
"""Return a GLUE task object
544572
Args:
@@ -551,6 +579,9 @@ def get_glue_task(task_name: str, data_dir: str = None):
551579
raise ValueError("Task not found: {}".format(task_name))
552580
task_processor = processors[task_name]()
553581
if data_dir is None:
554-
data_dir = os.path.join(os.environ["GLUE_DIR"], DEFAULT_FOLDER_NAMES[task_name])
582+
try:
583+
data_dir = os.path.join(os.environ["GLUE_DIR"], DEFAULT_FOLDER_NAMES[task_name])
584+
except Exception:
585+
data_dir = None
555586
task_type = output_modes[task_name]
556587
return Task(task_name, task_processor, data_dir, task_type)

nlp_architect/data/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class InputExample(ABC):
3030
def __init__(self, guid: str, text, label=None):
3131
self.guid = guid
3232
self.text = text
33+
self.text_a = text # for compatibility with trasformer library
3334
self.label = label
3435

3536

@@ -181,3 +182,30 @@ def split_column_dataset(
181182
second_data = selected_lines[first_count:]
182183
write_column_tagged_file(out_folder + os.sep + first_filename, first_data)
183184
write_column_tagged_file(out_folder + os.sep + second_filename, second_data)
185+
186+
187+
def get_cached_filepath(data_dir, model_name, seq_length, task_name, set_type="train"):
188+
"""get cached file name
189+
190+
Arguments:
191+
data_dir {str} -- data directory string
192+
model_name {str} -- model name
193+
seq_length {int} -- max sequence length
194+
task_name {str} -- name of task
195+
196+
Keyword Arguments:
197+
set_type {str} -- set type (choose from train/dev/test) (default: {"train"})
198+
199+
Returns:
200+
str -- cached filename
201+
"""
202+
cached_features_file = os.path.join(
203+
data_dir,
204+
"cached_{}_{}_{}_{}".format(
205+
set_type,
206+
list(filter(None, model_name.split("/"))).pop(),
207+
str(seq_length),
208+
str(task_name),
209+
),
210+
)
211+
return cached_features_file

nlp_architect/models/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,29 @@ class TrainableModel(ABC):
2626
def convert_to_tensors(self, *args, **kwargs):
2727
"""convert any chosen input to valid model format of tensors
2828
"""
29+
raise NotImplementedError
2930

3031
def get_logits(self, *args, **kwargs):
3132
"""get model logits from given input
3233
"""
34+
raise NotImplementedError
3335

3436
def train(self, *args, **kwargs):
3537
"""train the model
3638
"""
39+
raise NotImplementedError
3740

3841
def inference(self, *args, **kwargs):
3942
"""run inference
4043
"""
44+
raise NotImplementedError
4145

4246
def save_model(self, *args, **kwargs):
4347
"""save the model
4448
"""
49+
...
4550

4651
def load_model(self, *args, **kwargs):
4752
"""load a model
4853
"""
54+
...

nlp_architect/procedures/transformers/glue.py

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,15 @@
1818
import logging
1919
import os
2020

21-
from sklearn.metrics import matthews_corrcoef
2221
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
2322

24-
from nlp_architect.data.glue_tasks import get_glue_task, processors
23+
from nlp_architect.data.glue_tasks import get_glue_task, get_metric_fn, processors
2524
from nlp_architect.models.transformers import TransformerSequenceClassifier
26-
from nlp_architect.nn.torch import setup_backend, set_seed
25+
from nlp_architect.nn.torch import set_seed, setup_backend
2726
from nlp_architect.procedures.procedure import Procedure
2827
from nlp_architect.procedures.registry import register_inference_cmd, register_train_cmd
2928
from nlp_architect.procedures.transformers.base import create_base_args, inference_args, train_args
3029
from nlp_architect.utils.io import prepare_output_path
31-
from nlp_architect.utils.metrics import acc_and_f1, pearson_and_spearman, simple_accuracy
3230

3331
logger = logging.getLogger(__name__)
3432

@@ -168,28 +166,3 @@ def do_inference(args):
168166
with io.open(os.path.join(args.output_dir, "output.txt"), "w", encoding="utf-8") as fw:
169167
for p in preds:
170168
fw.write("{}\n".format(p))
171-
172-
173-
# GLUE task metrics
174-
def get_metric_fn(task_name):
175-
if task_name == "cola":
176-
return lambda p, l: {"mcc": matthews_corrcoef(p, l)}
177-
if task_name == "sst-2":
178-
return lambda p, l: {"acc": simple_accuracy(p, l)}
179-
if task_name == "mrpc":
180-
return acc_and_f1
181-
if task_name == "sts-b":
182-
return pearson_and_spearman
183-
if task_name == "qqp":
184-
return acc_and_f1
185-
if task_name == "mnli":
186-
return lambda p, l: {"acc": simple_accuracy(p, l)}
187-
if task_name == "mnli-mm":
188-
return lambda p, l: {"acc": simple_accuracy(p, l)}
189-
if task_name == "qnli":
190-
return lambda p, l: {"acc": simple_accuracy(p, l)}
191-
if task_name == "rte":
192-
return lambda p, l: {"acc": simple_accuracy(p, l)}
193-
if task_name == "wnli":
194-
return lambda p, l: {"acc": simple_accuracy(p, l)}
195-
raise KeyError(task_name)

0 commit comments

Comments
 (0)