Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
data/
exps/
results_modified/
venv/
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "data/gpt-MT"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we please keep the data ignored and use any other directory?

finetune-experiments$ du -hs data/
11G     data/

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you suggest a name?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eval_data?

path = data/gpt-MT
url = https://github.com/microsoft/gpt-MT
1 change: 1 addition & 0 deletions data/gpt-MT
Submodule gpt-MT added at d64e21
1 change: 1 addition & 0 deletions data/wmt22
60 changes: 43 additions & 17 deletions decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path

import evaluate
from datasets import load_dataset
from datasets import load_dataset, concatenate_datasets
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.generation import BeamSearchDecoderOnlyOutput
from peft import PeftModel
Expand Down Expand Up @@ -71,7 +71,7 @@ def register(cls, parser):
type=str,
help="experiment directory: where to save the model",
)
parser.add_argument("--decode_subset", default="dev", type=str, help="Dataset subset of FLORES to decode.")
parser.add_argument("--decode_subset", default=["dev", "devtest", "wmt22test"], type=str, help="Dataset decode: dev for FLORES dev, devtest for FLORES devtest, test for FLORES test, wmt22test for WMT22 test.")
parser.add_argument("--decode_batch_size", default=1, type=int, help="Decoding batch size, should be small enough to accommodate all beams.")
parser.add_argument("--decode_beams", default=10, type=int, help="Number of beams to use during decoding. Set to 0 to avoid decoding in the training loop.")
parser.add_argument("--prompt", default="gracious", choices=["gracious", "basic"], type=str, help="Prompt style. Gracious uses a lot of words, basic uses [INST] [/INST].")
Expand Down Expand Up @@ -172,6 +172,42 @@ def __call__(self, ids, sources, references):
result['bleu'].append(sacrebleu.compute(predictions=[output], references=[ref])['score'])
return result

def report(self, output_path, dataset):
    """Persist decoded hypotheses and score the top-ranked ones with BLEU.

    Writes the full *dataset* to *output_path* as JSON lines, computes
    sacreBLEU over the rank-0 (best-beam) hypothesis of every example,
    stores the score dict next to the hypotheses under a ``.results``
    suffix, prints it, and returns it.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    dataset.to_json(output_path, force_ascii=False)

    # Only the best hypothesis per example (rank == 0) is scored.
    top_ranked = dataset.filter(lambda ex: ex["rank"] == 0, load_from_cache_file=False)
    results = sacrebleu.compute(
        predictions=top_ranked["hyp"],
        references=top_ranked["ref"],
    )

    results_path = output_path.with_suffix('.results')
    results_path.write_text(json.dumps(results, ensure_ascii=False))
    print(results)
    return results

def decode_wmt22(self, exp: str, data_dir: str = "data/wmt22"):
    """Decode the WMT22 en-uk test set and report top-1 BLEU.

    Loads the parallel plain-text files ``test.en-uk.en`` / ``test.en-uk.uk``
    from *data_dir* (default keeps the previous hard-coded location),
    translates every source line, and writes hypotheses + scores under
    ``{exp}/beam{decode_beams}.wmt22test.jsonl``.

    Args:
        exp: experiment directory where outputs are written.
        data_dir: directory containing the WMT22 en-uk test files.

    Returns:
        The sacreBLEU results dict produced by ``report``.
    """
    # The plain "text" loader preserves line order, so the two files stay
    # aligned; see https://github.com/huggingface/datasets/issues/4709
    data_root = Path(data_dir)
    dataset = load_dataset("text", data_files={
        "en": str(data_root / "test.en-uk.en"),
        "uk": str(data_root / "test.en-uk.uk"),
    })

    # Column-wise concatenation pairs each English line with its
    # Ukrainian reference.
    dataset = concatenate_datasets([dataset["en"].rename_column("text", "source"),
                                    dataset["uk"].rename_column("text", "target")], axis=1)
    dataset = dataset.add_column("id", list(range(len(dataset))))

    columns = ["id", "source", "target"]
    dataset = dataset.select_columns(columns)
    # self is the batch-translation callable (__call__); caching is disabled
    # so a changed model/adapter is always re-decoded.
    dataset = dataset.map(
        self,
        batched=True,
        batch_size=self.decode_batch_size,
        input_columns=columns,
        remove_columns=columns,
        load_from_cache_file=False,
    )

    return self.report(Path(exp) / f"beam{self.decode_beams}.wmt22test.jsonl", dataset)

def decode_flores(self, exp: str, decode_subset: str, indices=None):
dataset = load_dataset(
"facebook/flores", "eng_Latn-ukr_Cyrl", trust_remote_code=True
Expand All @@ -189,25 +225,15 @@ def decode_flores(self, exp: str, decode_subset: str, indices=None):
load_from_cache_file=False,
)

exp = Path(exp)
exp.mkdir(parents=True, exist_ok=True)
output_path = exp / f"beam{self.decode_beams}.{decode_subset}.jsonl"

dataset.to_json(output_path, force_ascii=False)

# measure top-1 bleu
dataset_top1 = dataset.filter(lambda x: x["rank"] == 0, load_from_cache_file=False)
results = sacrebleu.compute(predictions=dataset_top1["hyp"], references=dataset_top1["ref"])
output_path.with_suffix('.results').write_text(json.dumps(results, ensure_ascii=False))
print(results)

return results

return self.report(Path(exp) / f"beam{self.decode_beams}.{decode_subset}.jsonl", dataset)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    BatchTranslator.register(parser)
    args = parser.parse_args()

    translator = BatchTranslator.from_args(args)
    # --decode_subset defaults to a *list* of subsets but a user-supplied
    # value arrives as a single string (type=str); normalize to a list so
    # both cases are handled, then dispatch each subset to the right decoder.
    subsets = args.decode_subset if isinstance(args.decode_subset, list) else [args.decode_subset]
    for subset in subsets:
        if subset == "wmt22test":
            translator.decode_wmt22(exp=args.exp)
        else:
            translator.decode_flores(exp=args.exp, decode_subset=subset)