
Commit 2db1b16

Fine-tuning the Llama 3.1 model to perform text-to-gloss and gloss-to-text translation
1 parent 43ac3ac commit 2db1b16

File tree

3 files changed: +368 -0 lines changed

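
All three added files import Translation from llama/utils.py, which is not part of this commit. As a reading aid, here is a minimal assumed sketch of that enum, reconstructed from how it is used in the files below (the member names are confirmed by the code; the values are placeholders):

# llama/utils.py is not included in this commit; minimal assumed sketch only.
from enum import Enum

class Translation(Enum):
    TextToGloss = "text_to_gloss"   # placeholder value; only the member name is known
    GlossToText = "gloss_to_text"   # placeholder value; only the member name is known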

llama/data_selection.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
import math
import torch
import torch.nn as nn
from collections import Counter
from torch import Tensor
import io
import time
import os
import pandas as pd
import json
from datetime import datetime
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from .utils import Translation

features_names = ["maingloss"]
mms_directories = [
    ("mms-subset91", 'latin-1'),
    ("modified/location/mms", 'utf-8'),
    ("modified/platform/mms", 'utf-8'),
    ("modified/time/mms", 'utf-8'),
    ("modified/train_name/mms", 'utf-8'),
]
text_directories = [
    ("annotations_full/annotations", 'latin-1'),
    ("modified/location/text", 'utf-8'),
    ("modified/platform/text", 'utf-8'),
    ("modified/time/text", 'utf-8'),
    ("modified/train_name/text", 'utf-8'),
]

def read(text_info, mms_info, translation):
    """Pair each text annotation with its gloss (.mms) file and build combined prompt lines."""
    data_list = []
    (text_directory, text_encoding) = text_info
    print("text_directory: ", text_directory)
    (mms_directory, mms_encoding) = mms_info
    for filenumber in os.listdir(text_directory):
        f = os.path.join(mms_directory, filenumber + ".mms")
        try:
            df = pd.read_csv(f, encoding=mms_encoding)
        except FileNotFoundError as e:
            print(f"WARNING: Text file exists while mms file does not, skipping: {e}")
            continue

        text_address = os.path.join(text_directory, filenumber, "gebaerdler.Text_Deutsch.annotation~")
        with open(text_address, encoding=text_encoding) as file:
            lines = file.readlines()
        # The German sentence is the third ';'-separated field of each annotation line.
        text_line = ""
        for i, text_data in enumerate(lines):
            if i > 0:
                text_line = text_line + " " + text_data.replace("\n", "").split(";")[2]
            else:
                text_line = text_line + text_data.replace("\n", "").split(";")[2]
        for feature in features_names:
            gloss_line = " ".join(df[feature].tolist())
            if translation == Translation.TextToGloss:
                combined_line = f"{text_line} ###> {gloss_line}"  # text to gloss
            elif translation == Translation.GlossToText:
                combined_line = f"{gloss_line} ###> {text_line}"  # gloss to text
            else:
                raise ValueError("Invalid translation")
            data_list.append({"text": combined_line})
    return data_list

def create_datasets(translation):
    data_list_only_original = []
    data_list_only_modified = []
    for i, text_info in enumerate(text_directories):
        mms_info = mms_directories[i]
        data_list_one = read(text_info, mms_info, translation)
        if i <= 0:
            data_list_only_original += data_list_one
        else:
            data_list_only_modified += data_list_one

    data_list_full = data_list_only_original + data_list_only_modified

    # 80% train; the remaining 20% is split 2:1 into validation and test.
    train_data, temp_data = train_test_split(data_list_full, test_size=0.2, random_state=42)
    val_data, test_data = train_test_split(temp_data, test_size=1/3, random_state=42)

    if translation == Translation.TextToGloss:
        translation_dir = "t2g_llama"
    elif translation == Translation.GlossToText:
        translation_dir = "g2t_llama"
    else:
        raise ValueError("Invalid translation")

    with open(f"train_data_{translation_dir}.json", "w") as f:
        json.dump(train_data, f)

    with open(f"val_data_{translation_dir}.json", "w") as f:
        json.dump(val_data, f)

    with open(f"test_data_{translation_dir}.json", "w") as f:
        json.dump(test_data, f)
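
Each entry written to the train/val/test JSON files is a single object whose "text" field joins source and target with the ###> separator. A minimal illustration of the text-to-gloss format; the sentences below are invented for illustration only, the real ones come from the annotation files:

# Hypothetical entries as they might appear in train_data_t2g_llama.json
# (invented examples; not taken from the dataset).
[
    {"text": "Der Zug nach Berlin faehrt von Gleis 3. ###> ZUG BERLIN GLEIS 3 ABFAHREN"},
    {"text": "Der Zug hat zehn Minuten Verspaetung. ###> ZUG ZEHN MINUTE VERSPAETUNG"}
]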

llama/fine_tune.py

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
import torch
import torch.nn as nn
import numpy as np
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import pickle
import os
import json
from sacrebleu.metrics import BLEU
from .data_selection import *
from pathlib import Path
from torch.utils.data import DataLoader
import time
from enum import Enum, verify, UNIQUE
from transformers import BitsAndBytesConfig
from huggingface_hub import login
from datasets import Dataset, load_dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer

hf_access_token = os.getenv("HF_ACCESS_TOKEN")
assert hf_access_token is not None, "You need to set the Hugging Face access token environment variable: export HF_ACCESS_TOKEN=hf_TODO"

login(token=hf_access_token)

def training(translation):
    create_datasets(translation)

    if translation == Translation.TextToGloss:
        translation_dir = "t2g_llama"
    elif translation == Translation.GlossToText:
        translation_dir = "g2t_llama"
    else:
        raise ValueError("Invalid translation")

    with open(f"train_data_{translation_dir}.json", "r") as f:
        train_data = json.load(f)

    with open(f"val_data_{translation_dir}.json", "r") as f:
        val_data = json.load(f)

    train_dataset = Dataset.from_list(train_data)
    val_dataset = Dataset.from_list(val_data)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()
    cache_dir = "/ds/videos/AVASAG/cache"
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_access_token, cache_dir=cache_dir, add_eos_token=True)
    # Set padding token
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # 4-bit NF4 quantization so the 8B model fits on a single GPU (QLoRA-style setup).
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    save_folder = os.path.join("/ds/videos/AVASAG/llama_finetune/", translation_dir)
    sft_model_name = os.path.join(save_folder, "llama-31-it-8b-sft")
    merged_model_name = os.path.join(save_folder, "llama-31-it-8b-sft-merged")

    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=bnb_config, token=hf_access_token, cache_dir=cache_dir)

    model = prepare_model_for_kbit_training(model)

    # LoRA adapters are attached to the MLP projections only.
    modules = ["down_proj", "up_proj", "gate_proj"]

    lora_config = LoraConfig(
        r=64,
        lora_alpha=32,
        target_modules=modules,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, lora_config)

    trainable, total = model.get_nb_trainable_parameters()
    print(f"Trainable: {trainable} | total: {total} | Percentage: {trainable/total*100:.4f}%")

    tokenizer.pad_token = tokenizer.eos_token
    torch.cuda.empty_cache()

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        dataset_text_field="text",
        peft_config=lora_config,
        args=transformers.TrainingArguments(
            report_to=[],  # Disable logging
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            warmup_ratio=0.03,
            max_steps=1000,
            learning_rate=2e-5,
            logging_steps=1,
            output_dir=f"/ds/videos/AVASAG/llama_finetune/outputs_{translation_dir}",
            optim="paged_adamw_8bit",
            save_strategy="epoch",
            ddp_find_unused_parameters=False,
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    model.config.use_cache = False
    trainer.train()

    trainer.model.save_pretrained(sft_model_name)

    # Reload the base model in half precision and merge the LoRA adapter into it.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    merged_model = PeftModel.from_pretrained(base_model, sft_model_name)
    merged_model = merged_model.merge_and_unload()

    merged_model.save_pretrained(merged_model_name, safe_serialization=True)
    tokenizer.save_pretrained(merged_model_name)


if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Usage: python fine_tune.py [--textTogloss|--glossTotext]")
        sys.exit(1)

    if sys.argv[1] == "--textTogloss":
        print("Translating from Text to Gloss")
        translation = Translation.TextToGloss
    elif sys.argv[1] == "--glossTotext":
        print("Translating from Gloss to Text")
        translation = Translation.GlossToText
    else:
        print("You have to specify either --textTogloss or --glossTotext as an argument.")
        sys.exit(1)

    training(translation)
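
For orientation, the training arguments above imply a small effective batch size; a quick back-of-the-envelope check under the assumption of a single-GPU run:

# Effective batch size implied by the TrainingArguments in fine_tune.py (single GPU assumed).
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
max_steps = 1000

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps   # 4 examples per optimizer step
examples_processed = effective_batch_size * max_steps                              # 4000 examples over the whole run
print(effective_batch_size, examples_processed)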

llama/inference.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import pickle
import os
from sacrebleu.metrics import BLEU
from pathlib import Path
from torch.utils.data import DataLoader
import time
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
import bitsandbytes as bnb
import transformers
import json
import pandas as pd
from datasets import Dataset, load_dataset
from .utils import Translation


def evaluation(translation):
    if translation == Translation.TextToGloss:
        translation_dir = "t2g_llama"
    elif translation == Translation.GlossToText:
        translation_dir = "g2t_llama"
    else:
        raise ValueError("Invalid translation")

    folder_path = os.path.join("/ds/videos/AVASAG/llama_finetune/", translation_dir)
    merged_model_name = os.path.join(folder_path, "llama-31-it-8b-sft-merged")
    cache_dir = "/ds/videos/AVASAG/cache"
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # Load the merged (base + LoRA) checkpoint produced by fine_tune.py.
    model_finetune = AutoModelForCausalLM.from_pretrained(
        merged_model_name,
        local_files_only=True,
        quantization_config=bnb_config,
        device_map="auto"
    )
    tokenizer_finetune = AutoTokenizer.from_pretrained(
        merged_model_name,
        local_files_only=True,
        add_eos_token=True)

    with open(f'test_data_{translation_dir}.json', 'r') as f:
        test_data = json.load(f)

    # Initialize BLEU metric
    bleu = BLEU()
    references = []
    predictions = []

    # Loop through the test data and generate translations
    for entry in test_data:
        # Extract the source (before ###>) and the reference target (after ###>)
        my_text = entry["text"].split("###>")[0].strip()
        prompt = my_text + " ###>"
        assert entry["text"].startswith(prompt), f"Prompt not found in the text: {entry['text']}"
        reference = entry["text"].split("###>")[1].strip()
        print("Input is:", my_text)
        print("Ground truth is:", reference)

        # Tokenize and generate the translation
        tokenized_input = tokenizer_finetune(prompt, return_tensors="pt")
        input_ids = tokenized_input["input_ids"].cuda()
        attention_mask = tokenized_input["attention_mask"].cuda()
        reference_length = len(tokenizer_finetune(reference)["input_ids"])  # Number of tokens in the reference

        # Generate the translation using beam search, capped at the reference length
        generation_output = model_finetune.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=6,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=reference_length
        )

        # Decode the generated output and keep only the part after the separator
        for seq in generation_output.sequences:
            output = tokenizer_finetune.decode(seq, skip_special_tokens=True).split("###>")[1].strip()
            predictions.append(output)
            print("Generated output:", output)
            print("\n")

        # Append the reference to the references list
        references.append([reference])

    # Calculate BLEU score
    bleu_score = bleu.corpus_score(predictions, references)

    # Print the BLEU score
    print(f"BLEU Score: {bleu_score.score}")


if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Usage: python inference.py [--textTogloss|--glossTotext]")
        sys.exit(1)

    if sys.argv[1] == "--textTogloss":
        print("Translating from Text to Gloss")
        translation = Translation.TextToGloss
    elif sys.argv[1] == "--glossTotext":
        print("Translating from Gloss to Text")
        translation = Translation.GlossToText
    else:
        print("You have to specify either --textTogloss or --glossTotext as an argument.")
        sys.exit(1)

    evaluation(translation)
